mirror of
https://github.com/langgenius/dify.git
synced 2026-05-19 16:27:17 +08:00
Compare commits
52 Commits
deploy/dev
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
| e96b18acf2 | |||
| b85cdc37f5 | |||
| 1925d58369 | |||
| b79fc5d6b4 | |||
| 6649e4025e | |||
| b96f372f45 | |||
| 127fbf2c9a | |||
| 3c70d28064 | |||
| cd4d6f8a22 | |||
| 9d0906c684 | |||
| 41b6f894c0 | |||
| e7e6fe8813 | |||
| c0bdd6792f | |||
| 27b084c4d4 | |||
| 3f7a68fc77 | |||
| a252fbddfa | |||
| ff02636a4b | |||
| 63946d829e | |||
| cdcfd2ef2c | |||
| b04a3851cc | |||
| b41338cd08 | |||
| 28153df4d3 | |||
| 3bc3386535 | |||
| 7654f14241 | |||
| 194b54bae4 | |||
| 0e16d36edb | |||
| 432a6412a3 | |||
| 55d05fe52d | |||
| 0d500e6965 | |||
| 5798610f27 | |||
| a35b28dbef | |||
| 1a4288c811 | |||
| 9dc32f2318 | |||
| 7210f856c9 | |||
| ebcc1200a3 | |||
| e660d7af38 | |||
| d9ccfcbc6e | |||
| a9bcec013f | |||
| aeb7687e2c | |||
| 9355d36718 | |||
| a03ee828a3 | |||
| 7066372892 | |||
| 55f95dbc36 | |||
| 8b40de3c4e | |||
| af4b9bfa8f | |||
| b9e3130388 | |||
| 12d33652b6 | |||
| fe8cf2aff4 | |||
| d1d190374d | |||
| e1be4e6aa8 | |||
| 301a470e7a | |||
| 91251ad5a5 |
709
.github/scripts/reset-test-env.sh
vendored
709
.github/scripts/reset-test-env.sh
vendored
@ -1,709 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -Eeuo pipefail
|
||||
|
||||
SCRIPT_NAME="$(basename "$0")"
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)"
|
||||
DEFAULT_REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd -P)"
|
||||
|
||||
REPO_ROOT="${DIFY_REPO_ROOT:-$DEFAULT_REPO_ROOT}"
|
||||
YES=false
|
||||
DRY_RUN=true
|
||||
SKIP_SMOKE=false
|
||||
SKIP_MIGRATION=false
|
||||
TIMEOUT_SECONDS="${DIFY_RESET_TIMEOUT_SECONDS:-300}"
|
||||
SMOKE_URL="${DIFY_RESET_SMOKE_URL:-}"
|
||||
LOCK_DIR=""
|
||||
CURRENT_PHASE="init"
|
||||
|
||||
DELETED_PATHS=()
|
||||
SKIPPED_PATHS=()
|
||||
DELETED_NAMED_VOLUMES=()
|
||||
SKIPPED_NAMED_VOLUMES=()
|
||||
PRESERVED_PATHS=()
|
||||
HEALTH_RESULTS=()
|
||||
SMOKE_RESULT="not-run"
|
||||
START_TIME="$(date +%s)"
|
||||
|
||||
RUNTIME_PATHS=(
|
||||
"volumes/db/data"
|
||||
"volumes/mysql/data"
|
||||
"volumes/redis/data"
|
||||
"volumes/app/storage"
|
||||
"volumes/plugin_daemon"
|
||||
"volumes/weaviate"
|
||||
"volumes/qdrant"
|
||||
"volumes/pgvector"
|
||||
"volumes/pgvecto_rs"
|
||||
"volumes/chroma"
|
||||
"volumes/milvus"
|
||||
"volumes/opensearch/data"
|
||||
)
|
||||
|
||||
NAMED_VOLUMES=(
|
||||
"dify_es01_data"
|
||||
)
|
||||
|
||||
PRESERVE_PATHS=(
|
||||
".env"
|
||||
"middleware.env"
|
||||
"docker-compose.yaml"
|
||||
"docker-compose.middleware.yaml"
|
||||
"nginx"
|
||||
"ssrf_proxy"
|
||||
"volumes/certbot"
|
||||
"volumes/opensearch/opensearch_dashboards.yml"
|
||||
"nginx/ssl"
|
||||
)
|
||||
|
||||
usage() {
|
||||
cat <<EOF
|
||||
Usage: $SCRIPT_NAME [options]
|
||||
|
||||
Safely reset a Dify test environment in place. The command defaults to dry-run.
|
||||
|
||||
Options:
|
||||
--yes Perform destructive reset. Required to delete data.
|
||||
--dry-run Print planned actions without changing services or data.
|
||||
--repo-root PATH Repository root. Defaults to auto-detected Dify root.
|
||||
--smoke-url URL Public URL to verify after restart.
|
||||
--skip-smoke Skip public-domain smoke verification.
|
||||
--skip-migration Skip explicit migration gate.
|
||||
--timeout SECONDS Health-check timeout. Default: $TIMEOUT_SECONDS.
|
||||
-h, --help Show this help.
|
||||
|
||||
Required for destructive reset:
|
||||
ALLOW_DIFY_TEST_RESET=true
|
||||
DIFY_ENV_NAME=test
|
||||
|
||||
Optional:
|
||||
DIFY_RESET_SMOKE_URL=https://test.example.com
|
||||
RESET_TARGET_DOMAIN=test.example.com
|
||||
EOF
|
||||
}
|
||||
|
||||
log() {
|
||||
printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"
|
||||
}
|
||||
|
||||
fail() {
|
||||
local message="$1"
|
||||
print_report "failure"
|
||||
printf 'ERROR: %s\n' "$message" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
run_cmd() {
|
||||
printf '+'
|
||||
printf ' %q' "$@"
|
||||
printf '\n'
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
set +e
|
||||
"$@"
|
||||
local status=$?
|
||||
set -e
|
||||
if [ "$status" -ne 0 ]; then
|
||||
fail "Command failed with exit code $status: $(command_string "$@")"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
command_string() {
|
||||
local arg
|
||||
local result=""
|
||||
for arg in "$@"; do
|
||||
result="$result $(printf '%q' "$arg")"
|
||||
done
|
||||
printf '%s' "${result# }"
|
||||
}
|
||||
|
||||
read_env_value() {
|
||||
local key="$1"
|
||||
local default_value="$2"
|
||||
local env_file="$DOCKER_DIR/.env"
|
||||
local value=""
|
||||
|
||||
if [ -f "$env_file" ]; then
|
||||
value="$(awk -F= -v key="$key" '
|
||||
$0 !~ /^[[:space:]]*#/ && $1 == key {
|
||||
sub(/^[^=]*=/, "")
|
||||
print
|
||||
}
|
||||
' "$env_file" | tail -n 1)"
|
||||
fi
|
||||
|
||||
if [ -z "$value" ]; then
|
||||
printf '%s' "$default_value"
|
||||
return
|
||||
fi
|
||||
|
||||
value="${value%\"}"
|
||||
value="${value#\"}"
|
||||
value="${value%\'}"
|
||||
value="${value#\'}"
|
||||
printf '%s' "$value"
|
||||
}
|
||||
|
||||
parse_args() {
|
||||
while [ "$#" -gt 0 ]; do
|
||||
case "$1" in
|
||||
--yes)
|
||||
YES=true
|
||||
DRY_RUN=false
|
||||
;;
|
||||
--dry-run)
|
||||
DRY_RUN=true
|
||||
;;
|
||||
--repo-root)
|
||||
[ "$#" -ge 2 ] || fail "--repo-root requires a path"
|
||||
REPO_ROOT="$2"
|
||||
shift
|
||||
;;
|
||||
--smoke-url)
|
||||
[ "$#" -ge 2 ] || fail "--smoke-url requires a URL"
|
||||
SMOKE_URL="$2"
|
||||
shift
|
||||
;;
|
||||
--skip-smoke)
|
||||
SKIP_SMOKE=true
|
||||
;;
|
||||
--skip-migration)
|
||||
SKIP_MIGRATION=true
|
||||
;;
|
||||
--timeout)
|
||||
[ "$#" -ge 2 ] || fail "--timeout requires seconds"
|
||||
TIMEOUT_SECONDS="$2"
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
fail "Unknown option: $1"
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
}
|
||||
|
||||
validate_number() {
|
||||
case "$TIMEOUT_SECONDS" in
|
||||
''|*[!0-9]*)
|
||||
fail "--timeout must be a positive integer"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
require_docker() {
|
||||
command -v docker >/dev/null 2>&1 || fail "docker command not found"
|
||||
docker compose version >/dev/null 2>&1 || fail "docker compose is not available"
|
||||
}
|
||||
|
||||
validate_environment() {
|
||||
CURRENT_PHASE="validate"
|
||||
REPO_ROOT="$(cd "$REPO_ROOT" && pwd -P)"
|
||||
DOCKER_DIR="$REPO_ROOT/docker"
|
||||
|
||||
[ -d "$DOCKER_DIR" ] || fail "Docker directory not found: $DOCKER_DIR"
|
||||
[ -f "$DOCKER_DIR/docker-compose.yaml" ] || fail "docker-compose.yaml not found in $DOCKER_DIR"
|
||||
[ -f "$DOCKER_DIR/.env" ] || fail ".env not found in $DOCKER_DIR"
|
||||
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
[ "$YES" = true ] || fail "Destructive reset requires --yes"
|
||||
[ "${ALLOW_DIFY_TEST_RESET:-}" = "true" ] || fail "ALLOW_DIFY_TEST_RESET=true is required"
|
||||
[ "${DIFY_ENV_NAME:-}" = "test" ] || fail "DIFY_ENV_NAME=test is required"
|
||||
require_docker
|
||||
fi
|
||||
}
|
||||
|
||||
acquire_lock() {
|
||||
CURRENT_PHASE="lock"
|
||||
local env_name="${DIFY_ENV_NAME:-dry-run}"
|
||||
LOCK_DIR="${TMPDIR:-/tmp}/dify-test-reset-${env_name}.lock"
|
||||
|
||||
if ! mkdir "$LOCK_DIR" 2>/dev/null; then
|
||||
fail "Reset lock is already held: $LOCK_DIR"
|
||||
fi
|
||||
|
||||
printf '%s\n' "$$" > "$LOCK_DIR/pid"
|
||||
trap cleanup EXIT
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
if [ -n "$LOCK_DIR" ] && [ -d "$LOCK_DIR" ]; then
|
||||
rm -rf "$LOCK_DIR"
|
||||
fi
|
||||
}
|
||||
|
||||
compose() {
|
||||
local args=(compose --env-file "$DOCKER_DIR/.env" -f "$DOCKER_DIR/docker-compose.yaml")
|
||||
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
|
||||
args+=(-p "$DIFY_COMPOSE_PROJECT")
|
||||
fi
|
||||
|
||||
docker "${args[@]}" "$@"
|
||||
}
|
||||
|
||||
compose_project_name() {
|
||||
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
|
||||
printf '%s' "$DIFY_COMPOSE_PROJECT"
|
||||
return
|
||||
fi
|
||||
|
||||
if [ -n "${COMPOSE_PROJECT_NAME:-}" ]; then
|
||||
printf '%s' "$COMPOSE_PROJECT_NAME"
|
||||
return
|
||||
fi
|
||||
|
||||
local env_project
|
||||
env_project="$(read_env_value COMPOSE_PROJECT_NAME "")"
|
||||
if [ -n "$env_project" ]; then
|
||||
printf '%s' "$env_project"
|
||||
return
|
||||
fi
|
||||
|
||||
basename "$DOCKER_DIR"
|
||||
}
|
||||
|
||||
active_db_service() {
|
||||
local db_type
|
||||
db_type="$(read_env_value DB_TYPE postgresql)"
|
||||
case "$db_type" in
|
||||
postgresql|'')
|
||||
printf '%s\n' "db_postgres"
|
||||
;;
|
||||
mysql)
|
||||
printf '%s\n' "db_mysql"
|
||||
;;
|
||||
oceanbase)
|
||||
printf '%s\n' "oceanbase"
|
||||
;;
|
||||
*)
|
||||
printf '%s\n' "$db_type"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
active_vector_service() {
|
||||
local vector_store
|
||||
vector_store="$(read_env_value VECTOR_STORE weaviate)"
|
||||
case "$vector_store" in
|
||||
''|none|external)
|
||||
return 0
|
||||
;;
|
||||
pgvecto-rs|pgvecto_rs)
|
||||
printf '%s\n' "pgvecto-rs"
|
||||
;;
|
||||
milvus)
|
||||
printf '%s\n' "milvus-standalone"
|
||||
;;
|
||||
elasticsearch|opensearch|weaviate|qdrant|pgvector|chroma|oceanbase|seekdb|couchbase-server|iris)
|
||||
printf '%s\n' "$vector_store"
|
||||
;;
|
||||
couchbase)
|
||||
printf '%s\n' "couchbase-server"
|
||||
;;
|
||||
*)
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
safe_runtime_path() {
|
||||
local rel_path="$1"
|
||||
case "$rel_path" in
|
||||
""|"/"| "." | ".." | *".."* | /*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
case "$rel_path" in
|
||||
volumes/*)
|
||||
return 0
|
||||
;;
|
||||
*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
safe_named_volume() {
|
||||
local volume="$1"
|
||||
case "$volume" in
|
||||
""|*"/"*|*" "*|*$'\t'*|*$'\n'*|*$'\r'*)
|
||||
return 1
|
||||
;;
|
||||
*[!a-zA-Z0-9_.-]*)
|
||||
return 1
|
||||
;;
|
||||
*)
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
volume_exists() {
|
||||
docker volume inspect "$1" >/dev/null 2>&1
|
||||
}
|
||||
|
||||
append_unique_volume() {
|
||||
local candidate="$1"
|
||||
local existing
|
||||
[ -n "$candidate" ] || return 0
|
||||
|
||||
for existing in "${RESOLVED_VOLUME_NAMES[@]}"; do
|
||||
if [ "$existing" = "$candidate" ]; then
|
||||
return
|
||||
fi
|
||||
done
|
||||
|
||||
RESOLVED_VOLUME_NAMES+=("$candidate")
|
||||
}
|
||||
|
||||
resolve_named_volume_names() {
|
||||
local logical_name="$1"
|
||||
local project_name
|
||||
local candidate
|
||||
local volume_list
|
||||
local status
|
||||
RESOLVED_VOLUME_NAMES=()
|
||||
|
||||
project_name="$(compose_project_name)"
|
||||
|
||||
set +e
|
||||
volume_list="$(docker volume ls -q \
|
||||
--filter "label=com.docker.compose.project=$project_name" \
|
||||
--filter "label=com.docker.compose.volume=$logical_name" 2>/dev/null)"
|
||||
status=$?
|
||||
set -e
|
||||
|
||||
if [ "$status" -ne 0 ]; then
|
||||
fail "Failed to list Docker volumes for Compose project $project_name"
|
||||
fi
|
||||
|
||||
while IFS= read -r candidate; do
|
||||
append_unique_volume "$candidate"
|
||||
done <<< "$volume_list"
|
||||
|
||||
for candidate in "${project_name}_${logical_name}" "$logical_name"; do
|
||||
if volume_exists "$candidate"; then
|
||||
append_unique_volume "$candidate"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
collect_preserved_paths() {
|
||||
PRESERVED_PATHS=()
|
||||
local rel_path
|
||||
for rel_path in "${PRESERVE_PATHS[@]}"; do
|
||||
if [ -e "$DOCKER_DIR/$rel_path" ]; then
|
||||
PRESERVED_PATHS+=("$rel_path")
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
print_plan() {
|
||||
CURRENT_PHASE="plan"
|
||||
local db_service
|
||||
local vector_service
|
||||
db_service="$(active_db_service)"
|
||||
vector_service="$(active_vector_service || true)"
|
||||
|
||||
collect_preserved_paths
|
||||
|
||||
log "Reset mode: $([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
|
||||
log "Repository root: $REPO_ROOT"
|
||||
log "Docker directory: $DOCKER_DIR"
|
||||
log "Compose project: $(compose_project_name)"
|
||||
log "Database service: $db_service"
|
||||
log "Vector service: ${vector_service:-<external-or-none>}"
|
||||
log "Timeout: ${TIMEOUT_SECONDS}s"
|
||||
|
||||
printf '\nPlanned runtime path deletions:\n'
|
||||
local rel_path
|
||||
for rel_path in "${RUNTIME_PATHS[@]}"; do
|
||||
printf ' - %s\n' "$rel_path"
|
||||
done
|
||||
|
||||
printf '\nPlanned named volume deletions:\n'
|
||||
local volume
|
||||
for volume in "${NAMED_VOLUMES[@]}"; do
|
||||
printf ' - %s (Compose project: %s)\n' "$volume" "$(compose_project_name)"
|
||||
done
|
||||
|
||||
printf '\nPreserved configuration paths found:\n'
|
||||
for rel_path in "${PRESERVED_PATHS[@]}"; do
|
||||
printf ' - %s\n' "$rel_path"
|
||||
done
|
||||
|
||||
printf '\nCommands:\n'
|
||||
printf ' - docker compose down --remove-orphans\n'
|
||||
printf ' - delete allowlisted runtime paths and named volumes\n'
|
||||
printf ' - docker compose up -d %s redis%s\n' "$db_service" "${vector_service:+ $vector_service}"
|
||||
printf ' - docker compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api\n'
|
||||
printf ' - docker compose up -d\n'
|
||||
printf ' - health checks and smoke check\n\n'
|
||||
}
|
||||
|
||||
delete_runtime_paths() {
|
||||
CURRENT_PHASE="delete-runtime-data"
|
||||
local rel_path
|
||||
local abs_path
|
||||
|
||||
for rel_path in "${RUNTIME_PATHS[@]}"; do
|
||||
safe_runtime_path "$rel_path" || fail "Unsafe runtime path in allowlist: $rel_path"
|
||||
abs_path="$DOCKER_DIR/$rel_path"
|
||||
|
||||
if [ ! -e "$abs_path" ]; then
|
||||
SKIPPED_PATHS+=("$rel_path (absent)")
|
||||
continue
|
||||
fi
|
||||
|
||||
DELETED_PATHS+=("$rel_path")
|
||||
run_cmd rm -rf -- "$abs_path"
|
||||
done
|
||||
}
|
||||
|
||||
delete_named_volumes() {
|
||||
CURRENT_PHASE="delete-runtime-volumes"
|
||||
local logical_name
|
||||
local actual_name
|
||||
|
||||
for logical_name in "${NAMED_VOLUMES[@]}"; do
|
||||
safe_named_volume "$logical_name" || fail "Unsafe named volume in allowlist: $logical_name"
|
||||
resolve_named_volume_names "$logical_name"
|
||||
|
||||
if [ "${#RESOLVED_VOLUME_NAMES[@]}" -eq 0 ]; then
|
||||
SKIPPED_NAMED_VOLUMES+=("$logical_name (absent)")
|
||||
continue
|
||||
fi
|
||||
|
||||
for actual_name in "${RESOLVED_VOLUME_NAMES[@]}"; do
|
||||
DELETED_NAMED_VOLUMES+=("$actual_name")
|
||||
run_cmd docker volume rm "$actual_name"
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
stop_stack() {
|
||||
CURRENT_PHASE="stop-stack"
|
||||
run_cmd compose down --remove-orphans
|
||||
}
|
||||
|
||||
start_middleware() {
|
||||
CURRENT_PHASE="start-middleware"
|
||||
local db_service
|
||||
local vector_service
|
||||
local services=()
|
||||
|
||||
db_service="$(active_db_service)"
|
||||
vector_service="$(active_vector_service || true)"
|
||||
services+=("$db_service" "redis")
|
||||
|
||||
if [ -n "$vector_service" ]; then
|
||||
services+=("$vector_service")
|
||||
fi
|
||||
|
||||
run_cmd compose up -d "${services[@]}"
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
wait_for_services "${services[@]}"
|
||||
fi
|
||||
}
|
||||
|
||||
run_migration() {
|
||||
CURRENT_PHASE="migration"
|
||||
if [ "$SKIP_MIGRATION" = true ]; then
|
||||
HEALTH_RESULTS+=("migration:skipped")
|
||||
return
|
||||
fi
|
||||
|
||||
run_cmd compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api
|
||||
HEALTH_RESULTS+=("migration:ok")
|
||||
}
|
||||
|
||||
start_full_stack() {
|
||||
CURRENT_PHASE="start-full-stack"
|
||||
run_cmd compose up -d
|
||||
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
wait_for_services api web worker nginx
|
||||
wait_if_service_exists plugin_daemon
|
||||
fi
|
||||
}
|
||||
|
||||
container_status() {
|
||||
local service="$1"
|
||||
local container_id
|
||||
container_id="$(compose ps -q "$service" 2>/dev/null || true)"
|
||||
[ -n "$container_id" ] || return 1
|
||||
|
||||
local health
|
||||
health="$(docker inspect --format '{{if .State.Health}}{{.State.Health.Status}}{{else}}{{.State.Status}}{{end}}' "$container_id" 2>/dev/null || true)"
|
||||
printf '%s' "$health"
|
||||
}
|
||||
|
||||
wait_if_service_exists() {
|
||||
local service="$1"
|
||||
if [ -n "$(compose ps -q "$service" 2>/dev/null || true)" ]; then
|
||||
wait_for_services "$service"
|
||||
fi
|
||||
}
|
||||
|
||||
wait_for_services() {
|
||||
local service
|
||||
for service in "$@"; do
|
||||
wait_for_service "$service"
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_service() {
|
||||
local service="$1"
|
||||
local deadline=$(( $(date +%s) + TIMEOUT_SECONDS ))
|
||||
local status=""
|
||||
|
||||
log "Waiting for service: $service"
|
||||
while [ "$(date +%s)" -le "$deadline" ]; do
|
||||
status="$(container_status "$service" || true)"
|
||||
case "$status" in
|
||||
healthy|running)
|
||||
HEALTH_RESULTS+=("$service:$status")
|
||||
return 0
|
||||
;;
|
||||
unhealthy|exited|dead)
|
||||
HEALTH_RESULTS+=("$service:$status")
|
||||
fail "Service $service reached failure status: $status"
|
||||
;;
|
||||
esac
|
||||
sleep 3
|
||||
done
|
||||
|
||||
HEALTH_RESULTS+=("$service:timeout")
|
||||
fail "Timed out waiting for service: $service"
|
||||
}
|
||||
|
||||
default_smoke_url() {
|
||||
if [ -n "$SMOKE_URL" ]; then
|
||||
printf '%s' "$SMOKE_URL"
|
||||
return
|
||||
fi
|
||||
|
||||
if [ -n "${RESET_TARGET_DOMAIN:-}" ]; then
|
||||
local https_enabled
|
||||
https_enabled="$(read_env_value NGINX_HTTPS_ENABLED false)"
|
||||
if [ "$https_enabled" = "true" ]; then
|
||||
printf 'https://%s' "$RESET_TARGET_DOMAIN"
|
||||
else
|
||||
printf 'http://%s' "$RESET_TARGET_DOMAIN"
|
||||
fi
|
||||
return
|
||||
fi
|
||||
|
||||
local port
|
||||
port="$(read_env_value EXPOSE_NGINX_PORT 80)"
|
||||
printf 'http://localhost:%s' "$port"
|
||||
}
|
||||
|
||||
run_smoke_check() {
|
||||
CURRENT_PHASE="smoke"
|
||||
if [ "$SKIP_SMOKE" = true ]; then
|
||||
SMOKE_RESULT="skipped"
|
||||
return
|
||||
fi
|
||||
|
||||
local url
|
||||
url="$(default_smoke_url)"
|
||||
if [ "$DRY_RUN" = true ]; then
|
||||
SMOKE_RESULT="planned:$url"
|
||||
printf '+ curl -fsS --max-time 10 %q\n' "$url"
|
||||
return
|
||||
fi
|
||||
|
||||
curl -fsS --max-time 10 "$url" >/dev/null || fail "Smoke check failed: $url"
|
||||
SMOKE_RESULT="ok:$url"
|
||||
}
|
||||
|
||||
print_report() {
|
||||
local status="${1:-success}"
|
||||
local end_time
|
||||
end_time="$(date +%s)"
|
||||
|
||||
printf '\nReset report\n'
|
||||
printf '============\n'
|
||||
printf 'status: %s\n' "$status"
|
||||
printf 'environment: %s\n' "${DIFY_ENV_NAME:-<unset>}"
|
||||
printf 'repo_root: %s\n' "${REPO_ROOT:-<unset>}"
|
||||
printf 'phase: %s\n' "$CURRENT_PHASE"
|
||||
printf 'duration_seconds: %s\n' "$(( end_time - START_TIME ))"
|
||||
printf 'mode: %s\n' "$([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
|
||||
|
||||
printf '\ndeleted_runtime_paths:\n'
|
||||
if [ "${#DELETED_PATHS[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${DELETED_PATHS[@]}"
|
||||
fi
|
||||
|
||||
printf '\nskipped_runtime_paths:\n'
|
||||
if [ "${#SKIPPED_PATHS[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${SKIPPED_PATHS[@]}"
|
||||
fi
|
||||
|
||||
printf '\ndeleted_named_volumes:\n'
|
||||
if [ "${#DELETED_NAMED_VOLUMES[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${DELETED_NAMED_VOLUMES[@]}"
|
||||
fi
|
||||
|
||||
printf '\nskipped_named_volumes:\n'
|
||||
if [ "${#SKIPPED_NAMED_VOLUMES[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${SKIPPED_NAMED_VOLUMES[@]}"
|
||||
fi
|
||||
|
||||
printf '\npreserved_paths:\n'
|
||||
if [ "${#PRESERVED_PATHS[@]}" -eq 0 ]; then
|
||||
printf ' - <none found>\n'
|
||||
else
|
||||
printf ' - %s\n' "${PRESERVED_PATHS[@]}"
|
||||
fi
|
||||
|
||||
printf '\nhealth_results:\n'
|
||||
if [ "${#HEALTH_RESULTS[@]}" -eq 0 ]; then
|
||||
printf ' - <not run>\n'
|
||||
else
|
||||
printf ' - %s\n' "${HEALTH_RESULTS[@]}"
|
||||
fi
|
||||
|
||||
printf '\nsmoke_result: %s\n' "$SMOKE_RESULT"
|
||||
}
|
||||
|
||||
main() {
|
||||
parse_args "$@"
|
||||
validate_number
|
||||
validate_environment
|
||||
acquire_lock
|
||||
print_plan
|
||||
|
||||
if [ "$DRY_RUN" = true ]; then
|
||||
run_smoke_check
|
||||
print_report "dry-run"
|
||||
return 0
|
||||
fi
|
||||
|
||||
stop_stack
|
||||
delete_runtime_paths
|
||||
delete_named_volumes
|
||||
start_middleware
|
||||
run_migration
|
||||
start_full_stack
|
||||
run_smoke_check
|
||||
CURRENT_PHASE="complete"
|
||||
print_report "success"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@ -120,7 +120,11 @@ jobs:
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd api
|
||||
uv run dev/generate_swagger_markdown_docs.py --swagger-dir openapi --markdown-dir openapi/markdown
|
||||
uv run dev/generate_swagger_markdown_docs.py --swagger-dir ../packages/contracts/openapi --markdown-dir openapi/markdown --keep-swagger-json
|
||||
|
||||
- name: Generate frontend contracts
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: cd packages/contracts && vp run gen-api-contract-from-openapi
|
||||
|
||||
- name: ESLint autofix
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
|
||||
14
.github/workflows/style.yml
vendored
14
.github/workflows/style.yml
vendored
@ -77,6 +77,8 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
e2e/**
|
||||
sdks/nodejs-client/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
@ -94,14 +96,14 @@ jobs:
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
path: .eslintcache
|
||||
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
working-directory: .
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
@ -113,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
working-directory: .
|
||||
run: vp run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
@ -125,7 +127,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
path: .eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
superlinter:
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/*
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/settings.example.json
|
||||
!.vscode/README.md
|
||||
api/.vscode
|
||||
# vscode Code History Extension
|
||||
@ -249,3 +250,5 @@ scripts/stress-test/reports/
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.context/*
|
||||
.eslintcache
|
||||
|
||||
@ -56,44 +56,9 @@ if $api_modified; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if $web_modified; then
|
||||
if $skip_web_checks; then
|
||||
echo "Git operation in progress, skipping web checks"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Running ESLint on web module"
|
||||
|
||||
if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then
|
||||
web_ts_modified=false
|
||||
else
|
||||
ts_diff_status=$?
|
||||
if [ $ts_diff_status -eq 1 ]; then
|
||||
web_ts_modified=true
|
||||
else
|
||||
echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)."
|
||||
exit $ts_diff_status
|
||||
fi
|
||||
fi
|
||||
|
||||
cd ./web || exit 1
|
||||
pnpm exec vp staged
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
if ! npm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
||||
fi
|
||||
|
||||
echo "Running knip"
|
||||
if ! npm run knip; then
|
||||
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd ../
|
||||
if $skip_web_checks; then
|
||||
echo "Git operation in progress, skipping web checks"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
vp staged
|
||||
|
||||
@ -9,6 +9,7 @@ The codebase is split into:
|
||||
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
|
||||
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
- **Dify Agent Backend** (`/dify-agent`): Backend services for managing and executing agent
|
||||
|
||||
## Backend Workflow
|
||||
|
||||
|
||||
17
Makefile
17
Makefile
@ -83,16 +83,15 @@ lint:
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@./dev/pyrefly-check-local
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "📝 Running core type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
@ -153,8 +152,8 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
|
||||
@echo " make type-check - Run type checks (pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
|
||||
@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
GRAPH_ENGINE_MIN_WORKERS=3
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
|
||||
@ -17,14 +17,15 @@ FROM base AS packages
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
git g++ \
|
||||
g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
RUN uv sync --locked --no-dev --group evaluation
|
||||
# Trust the checked-in lock during image builds; dev-only path sources live outside the api/ context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
@ -77,7 +78,6 @@ RUN \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
git \
|
||||
nodejs=${NODE_PACKAGE_VERSION} \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
|
||||
@ -99,7 +99,7 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
uv run pyrefly check # Type checking
|
||||
```
|
||||
|
||||
## Generate TS stub
|
||||
|
||||
@ -117,7 +117,7 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
logger.warning("Failed to add trace headers to response", exc_info=True)
|
||||
return response
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
# Capture the decorator return values so static checkers do not treat the hooks as unused.
|
||||
_ = before_request
|
||||
_ = add_trace_headers
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ CLI command modules extracted from `commands.py`.
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .plugin import (
|
||||
backfill_plugin_auto_upgrade,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
install_plugins,
|
||||
@ -15,7 +14,6 @@ from .plugin import (
|
||||
setup_system_trigger_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
)
|
||||
from .rbac import migrate_member_roles_to_rbac
|
||||
from .retention import (
|
||||
archive_workflow_runs,
|
||||
clean_expired_messages,
|
||||
@ -39,7 +37,6 @@ from .vector import (
|
||||
__all__ = [
|
||||
"add_qdrant_index",
|
||||
"archive_workflow_runs",
|
||||
"backfill_plugin_auto_upgrade",
|
||||
"clean_expired_messages",
|
||||
"clean_workflow_runs",
|
||||
"cleanup_orphaned_draft_variables",
|
||||
@ -58,7 +55,6 @@ __all__ = [
|
||||
"migrate_annotation_vector_database",
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_member_roles_to_rbac",
|
||||
"migrate_oss",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
@ -15,13 +14,11 @@ from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
@ -188,9 +185,9 @@ def transform_datasource_credentials(environment: str):
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
@ -405,110 +402,6 @@ def migrate_data_for_plugin():
|
||||
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
|
||||
|
||||
|
||||
def _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit: int | None = None):
|
||||
category_count = len(TenantPluginAutoUpgradeStrategy.PluginCategory)
|
||||
stmt = (
|
||||
select(TenantPluginAutoUpgradeStrategy.tenant_id)
|
||||
.group_by(TenantPluginAutoUpgradeStrategy.tenant_id)
|
||||
.having(func.count(func.distinct(TenantPluginAutoUpgradeStrategy.category)) < category_count)
|
||||
.order_by(TenantPluginAutoUpgradeStrategy.tenant_id)
|
||||
)
|
||||
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def _count_auto_upgrade_strategy_tenant_ids(limit: int | None) -> int:
|
||||
candidate_stmt = _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit).subquery()
|
||||
return db.session.scalar(select(func.count()).select_from(candidate_stmt)) or 0
|
||||
|
||||
|
||||
def _iter_auto_upgrade_strategy_tenant_ids(limit: int | None):
|
||||
stmt = _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit).execution_options(yield_per=1000)
|
||||
yield from db.session.scalars(stmt)
|
||||
|
||||
|
||||
@click.command(
|
||||
"backfill-plugin-auto-upgrade",
|
||||
help="Backfill category-scoped plugin auto-upgrade strategies and normalize plugin lists.",
|
||||
)
|
||||
@click.option("--tenant-id", multiple=True, help="Tenant ID to backfill. Can be passed multiple times.")
|
||||
@click.option("--limit", type=int, default=None, help="Maximum number of candidate tenants to process.")
|
||||
@click.option("--batch-size", type=int, default=500, show_default=True, help="Progress reporting batch size.")
|
||||
@click.option("--dry-run", is_flag=True, help="Only print candidate tenant count.")
|
||||
def backfill_plugin_auto_upgrade(
|
||||
tenant_id: tuple[str, ...],
|
||||
limit: int | None,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Backfill historical auto-upgrade strategies after the category column exists.
|
||||
|
||||
Missing category rows are created from the tenant's tool/default row. Pure default
|
||||
strategies become latest for model plugins and fix-only for all other categories.
|
||||
Tenants with include/exclude plugin IDs are split
|
||||
by installed plugin category using plugin daemon metadata.
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
candidate_count = len(tenant_id) if tenant_id else _count_auto_upgrade_strategy_tenant_ids(limit)
|
||||
click.echo(click.style(f"Found {candidate_count} candidate tenants.", fg="yellow"))
|
||||
|
||||
if dry_run:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(click.style(f"Dry run completed. elapsed={elapsed:.2f}s", fg="green"))
|
||||
return
|
||||
|
||||
tenant_ids = list(tenant_id) if tenant_id else _iter_auto_upgrade_strategy_tenant_ids(limit)
|
||||
|
||||
backfilled_count = 0
|
||||
created_count = 0
|
||||
normalized_count = 0
|
||||
skipped_count = 0
|
||||
failed_count = 0
|
||||
for index, current_tenant_id in enumerate(tenant_ids, start=1):
|
||||
try:
|
||||
result = PluginAutoUpgradeService.backfill_strategy_categories(
|
||||
current_tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
click.echo(click.style(f"Failed tenant {current_tenant_id}: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
if result.created_count > 0:
|
||||
backfilled_count += 1
|
||||
created_count += result.created_count
|
||||
elif not result.normalized:
|
||||
skipped_count += 1
|
||||
if result.normalized:
|
||||
normalized_count += 1
|
||||
|
||||
if batch_size > 0 and index % batch_size == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Processed {index}/{candidate_count} tenants. "
|
||||
f"backfilled={backfilled_count}, created_rows={created_count}, "
|
||||
f"normalized={normalized_count}, skipped={skipped_count}, failed={failed_count}, "
|
||||
f"elapsed={time.perf_counter() - start_at:.2f}s",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Backfill plugin auto-upgrade strategy categories completed. "
|
||||
f"backfilled={backfilled_count}, created_rows={created_count}, "
|
||||
f"normalized={normalized_count}, skipped={skipped_count}, failed={failed_count}, "
|
||||
f"elapsed={elapsed:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("extract-plugins", help="Extract plugins.")
|
||||
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
|
||||
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
|
||||
|
||||
@ -1,109 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import TenantAccountJoin, TenantAccountRole
|
||||
from services.enterprise.rbac_service import ListOption, RBACService
|
||||
|
||||
|
||||
def _resolve_builtin_role_id(tenant_id: str, operator_account_id: str, legacy_role: str) -> str:
|
||||
"""Resolve a legacy workspace role to the current tenant's builtin RBAC role id.
|
||||
|
||||
The migration replays the old `TenantAccountJoin.role` values onto the
|
||||
RBAC member-role binding API. Builtin RBAC roles are tenant-scoped and
|
||||
identified by runtime ids, so the command must look them up per tenant.
|
||||
"""
|
||||
expected_builtin_name = {
|
||||
TenantAccountRole.OWNER.value: "所有者",
|
||||
TenantAccountRole.ADMIN.value: "管理者",
|
||||
TenantAccountRole.EDITOR.value: "编辑者",
|
||||
TenantAccountRole.NORMAL.value: "普通用户",
|
||||
TenantAccountRole.DATASET_OPERATOR.value: "知识库操作员",
|
||||
}.get(legacy_role)
|
||||
if not expected_builtin_name:
|
||||
raise ValueError(f"Unsupported legacy workspace role: {legacy_role}")
|
||||
|
||||
roles = RBACService.Roles.list(
|
||||
tenant_id=tenant_id,
|
||||
account_id=operator_account_id,
|
||||
options=ListOption(page_number=1, results_per_page=100),
|
||||
).data
|
||||
for role in roles:
|
||||
if role.is_builtin and role.category == "global_system_default" and role.name == expected_builtin_name:
|
||||
return str(role.id)
|
||||
|
||||
raise ValueError(f"Builtin RBAC role not found for tenant={tenant_id}, legacy_role={legacy_role}")
|
||||
|
||||
|
||||
@click.command("rbac-migrate-member-roles", help="Migrate legacy workspace member roles into RBAC member-role bindings.")
|
||||
@click.option("--tenant-id", help="Only migrate a single workspace.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Preview the migration without writing RBAC bindings.")
|
||||
def migrate_member_roles_to_rbac(tenant_id: str | None, dry_run: bool) -> None:
|
||||
"""Backfill RBAC member-role bindings from legacy `TenantAccountJoin.role` data.
|
||||
|
||||
This is an offline migration command for workspaces that already have
|
||||
members in the legacy role model but need matching records in the RBAC
|
||||
member-role binding store.
|
||||
"""
|
||||
click.echo(click.style("Starting RBAC member-role migration.", fg="green"))
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(TenantAccountJoin).order_by(TenantAccountJoin.tenant_id.asc(), TenantAccountJoin.id.asc())
|
||||
if tenant_id:
|
||||
stmt = stmt.where(TenantAccountJoin.tenant_id == tenant_id)
|
||||
|
||||
joins = list(session.scalars(stmt).all())
|
||||
|
||||
if not joins:
|
||||
click.echo(click.style("No workspace members found for migration.", fg="yellow"))
|
||||
return
|
||||
|
||||
owner_account_by_tenant: dict[str, str] = {}
|
||||
resolved_role_ids: dict[tuple[str, str], str] = {}
|
||||
migrated_count = 0
|
||||
|
||||
for join in joins:
|
||||
workspace_id = str(join.tenant_id)
|
||||
member_account_id = str(join.account_id)
|
||||
legacy_role = str(join.role)
|
||||
|
||||
if workspace_id not in owner_account_by_tenant:
|
||||
owner_join = next(
|
||||
(
|
||||
item
|
||||
for item in joins
|
||||
if str(item.tenant_id) == workspace_id and str(item.role) == TenantAccountRole.OWNER.value
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not owner_join:
|
||||
raise ValueError(f"Workspace owner not found for tenant={workspace_id}")
|
||||
owner_account_by_tenant[workspace_id] = str(owner_join.account_id)
|
||||
|
||||
operator_account_id = owner_account_by_tenant[workspace_id]
|
||||
cache_key = (workspace_id, legacy_role)
|
||||
if cache_key not in resolved_role_ids:
|
||||
resolved_role_ids[cache_key] = _resolve_builtin_role_id(workspace_id, operator_account_id, legacy_role)
|
||||
|
||||
resolved_role_id = resolved_role_ids[cache_key]
|
||||
click.echo(
|
||||
f"tenant={workspace_id} member={member_account_id} legacy_role={legacy_role} -> rbac_role_id={resolved_role_id}"
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
continue
|
||||
|
||||
RBACService.MemberRoles.replace(
|
||||
tenant_id=workspace_id,
|
||||
account_id=operator_account_id,
|
||||
member_account_id=member_account_id,
|
||||
role_ids=[resolved_role_id],
|
||||
)
|
||||
migrated_count += 1
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry run completed. No RBAC bindings were written.", fg="yellow"))
|
||||
else:
|
||||
click.echo(click.style(f"RBAC member-role migration completed. Migrated {migrated_count} members.", fg="green"))
|
||||
@ -14,6 +14,7 @@ from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.model import App, AppMode, Conversation
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,13 +24,16 @@ DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
@click.command(
|
||||
"reset-encrypt-key-pair",
|
||||
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
|
||||
"After the reset, all LLM credentials will become invalid, "
|
||||
"requiring re-entry."
|
||||
"After the reset, all LLM credentials and tool provider credentials "
|
||||
"(builtin / API / MCP) will be purged, requiring re-entry. "
|
||||
"Only support SELF_HOSTED mode.",
|
||||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
|
||||
"Are you sure you want to reset encrypt key pair? "
|
||||
"This will also purge builtin / API / MCP tool provider records for every tenant. "
|
||||
"This operation cannot be rolled back!",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
@ -53,6 +57,13 @@ def reset_encrypt_key_pair():
|
||||
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
|
||||
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
|
||||
|
||||
# Purge tool provider records that hold credentials encrypted under the
|
||||
# tenant key. Leaving them in place causes /console/api/workspaces/current/
|
||||
# tool-providers to 500 because decryption fails on stale ciphertext (#35396).
|
||||
session.execute(delete(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant.id))
|
||||
session.execute(delete(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant.id))
|
||||
session.execute(delete(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant.id))
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
|
||||
@ -23,9 +23,10 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
|
||||
)
|
||||
|
||||
RBAC_ENABLED: bool = Field(
|
||||
description="Enable enterprise RBAC APIs. When disabled, compatibility responses fall back to legacy roles.",
|
||||
ENTERPRISE_DISABLE_RUNTIME_CREDENTIAL_CHECK: bool = Field(
|
||||
default=False,
|
||||
description="If disabled, credential policy check is only performed when saving workflows."
|
||||
"This helps gain runtime performance by trading off consistency.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings):
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
@ -1406,32 +1406,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class EvaluationConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for evaluation runtime
|
||||
"""
|
||||
|
||||
EVALUATION_FRAMEWORK: str = Field(
|
||||
description="Evaluation framework to use (ragas/deepeval/none)",
|
||||
default="none",
|
||||
)
|
||||
|
||||
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent evaluation runs per tenant",
|
||||
default=3,
|
||||
)
|
||||
|
||||
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
|
||||
description="Maximum number of rows allowed in an evaluation dataset",
|
||||
default=500,
|
||||
)
|
||||
|
||||
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
|
||||
description="Timeout in seconds for a single evaluation task",
|
||||
default=3600,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -1445,7 +1419,6 @@ class FeatureConfig(
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
EvaluationConfig,
|
||||
FileAccessConfig,
|
||||
FileUploadConfig,
|
||||
HttpConfig,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
from urllib.parse import parse_qsl, quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
@ -50,28 +50,30 @@ from .vdb.vastbase_vector_config import VastbaseVectorConfig
|
||||
from .vdb.vikingdb_config import VikingDBConfig
|
||||
from .vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
_VALID_STORAGE_TYPE = Literal[
|
||||
"opendal",
|
||||
"s3",
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
"tencent-cos",
|
||||
"volcengine-tos",
|
||||
"supabase",
|
||||
"local",
|
||||
]
|
||||
|
||||
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: Literal[
|
||||
"opendal",
|
||||
"s3",
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
"tencent-cos",
|
||||
"volcengine-tos",
|
||||
"supabase",
|
||||
"local",
|
||||
] = Field(
|
||||
STORAGE_TYPE: _VALID_STORAGE_TYPE = Field(
|
||||
description="Type of storage to use."
|
||||
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
|
||||
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
|
||||
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
|
||||
default="opendal",
|
||||
default=cast(_VALID_STORAGE_TYPE, "opendal"),
|
||||
)
|
||||
|
||||
STORAGE_LOCAL_PATH: str = Field(
|
||||
|
||||
@ -1,36 +1,21 @@
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
import json
|
||||
|
||||
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
|
||||
"decision": "approve",
|
||||
"attachment": {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
|
||||
"type": "document",
|
||||
},
|
||||
"attachments": [
|
||||
{
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
|
||||
"type": "document",
|
||||
},
|
||||
{
|
||||
"transfer_method": "remote_url",
|
||||
"url": "https://example.com/report.pdf",
|
||||
"type": "document",
|
||||
},
|
||||
],
|
||||
}
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue] = Field(
|
||||
description=(
|
||||
"Submitted human input values keyed by output variable name. "
|
||||
"Use a string for paragraph or select input values, a file mapping for file inputs, "
|
||||
"and a list of file mappings for file-list inputs. Local file mappings use "
|
||||
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
|
||||
"`transfer_method=remote_url` with `url` or `remote_url`."
|
||||
),
|
||||
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
|
||||
)
|
||||
inputs: dict[str, JsonValue]
|
||||
action: str
|
||||
|
||||
|
||||
def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
"""Serialize default values into strings expected by human-input form clients."""
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
@ -39,6 +39,7 @@ QueryParamDoc = TypedDict(
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
|
||||
schema = _swagger_2_compatible_schema(schema)
|
||||
nested_definitions = schema.get("$defs")
|
||||
schema_to_register = dict(schema)
|
||||
if isinstance(nested_definitions, dict):
|
||||
@ -65,6 +66,35 @@ def _register_schema_model(namespace: Namespace, model: type[BaseModel], *, mode
|
||||
)
|
||||
|
||||
|
||||
def _swagger_2_compatible_schema(value: Any) -> Any:
|
||||
if isinstance(value, list):
|
||||
return [_swagger_2_compatible_schema(item) for item in value]
|
||||
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
converted = {key: _swagger_2_compatible_schema(child) for key, child in value.items()}
|
||||
any_of = value.get("anyOf")
|
||||
if not isinstance(any_of, list):
|
||||
return converted
|
||||
|
||||
non_null_candidates = [
|
||||
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
|
||||
]
|
||||
has_null_candidate = any(isinstance(candidate, Mapping) and candidate.get("type") == "null" for candidate in any_of)
|
||||
if not has_null_candidate or len(non_null_candidates) != 1:
|
||||
return converted
|
||||
|
||||
non_null_schema = _swagger_2_compatible_schema(dict(non_null_candidates[0]))
|
||||
if not isinstance(non_null_schema, dict):
|
||||
return converted
|
||||
|
||||
converted.pop("anyOf", None)
|
||||
converted.update(non_null_schema)
|
||||
converted["x-nullable"] = True
|
||||
return converted
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""
|
||||
|
||||
|
||||
@ -33,7 +33,6 @@ for module_name in RESOURCE_MODULES:
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
from . import (
|
||||
admin,
|
||||
apikey,
|
||||
extension,
|
||||
feature,
|
||||
@ -108,9 +107,6 @@ from .datasets.rag_pipeline import (
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import evaluation controllers
|
||||
from .evaluation import evaluation
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
@ -120,12 +116,8 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .socketio import workflow as socketio_workflow
|
||||
|
||||
# Import snippet controllers
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import snippet controllers
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
@ -139,8 +131,6 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
rbac,
|
||||
snippets,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@ -151,7 +141,6 @@ api.add_namespace(console_ns)
|
||||
__all__ = [
|
||||
"account",
|
||||
"activate",
|
||||
"admin",
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_providers",
|
||||
@ -178,7 +167,6 @@ __all__ = [
|
||||
"datasource_content_preview",
|
||||
"email_register",
|
||||
"endpoint",
|
||||
"evaluation",
|
||||
"extension",
|
||||
"external",
|
||||
"feature",
|
||||
@ -209,17 +197,10 @@ __all__ = [
|
||||
"rag_pipeline_draft_variable",
|
||||
"rag_pipeline_import",
|
||||
"rag_pipeline_workflow",
|
||||
"rbac",
|
||||
"recommended_app",
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"spec",
|
||||
"statistic",
|
||||
|
||||
@ -1,64 +1,11 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService, LangContentDict
|
||||
|
||||
|
||||
class InsertExploreAppPayload(BaseModel):
|
||||
app_id: str = Field(...)
|
||||
desc: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
can_trial: bool = Field(default=False)
|
||||
trial_limit: int = Field(default=0)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
class InsertExploreBannerPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
title: str = Field(...)
|
||||
description: str = Field(...)
|
||||
img_src: str = Field(..., alias="img-src")
|
||||
language: str = Field(default="en-US")
|
||||
link: str = Field(...)
|
||||
sort: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
register_schema_models(console_ns, InsertExploreAppPayload, InsertExploreBannerPayload)
|
||||
|
||||
|
||||
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@ -76,353 +23,3 @@ def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps")
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = payload.desc or ""
|
||||
copy_right = payload.copyright or ""
|
||||
privacy_policy = payload.privacy_policy or ""
|
||||
custom_disclaimer = payload.custom_disclaimer or ""
|
||||
else:
|
||||
desc = site.description or payload.desc or ""
|
||||
copy_right = site.copyright or payload.copyright or ""
|
||||
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
recommended_app = RecommendedApp(
|
||||
app_id=app.id,
|
||||
description=desc,
|
||||
copyright=copy_right,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
language=payload.language,
|
||||
category=payload.category,
|
||||
position=payload.position,
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
else:
|
||||
recommended_app.description = desc
|
||||
recommended_app.copyright = copy_right
|
||||
recommended_app.privacy_policy = privacy_policy
|
||||
recommended_app.custom_disclaimer = custom_disclaimer
|
||||
recommended_app.language = payload.language
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
||||
class InsertExploreAppApi(Resource):
|
||||
@console_ns.doc("delete_explore_app")
|
||||
@console_ns.doc(description="Remove an app from the explore list")
|
||||
@console_ns.doc(params={"app_id": "Application ID to remove"})
|
||||
@console_ns.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id: UUID):
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBannerApi(Resource):
|
||||
@console_ns.doc("insert_explore_banner")
|
||||
@console_ns.doc(description="Insert an explore banner")
|
||||
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
|
||||
@console_ns.response(201, "Banner inserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||
|
||||
banner = ExporleBanner(
|
||||
content={
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
},
|
||||
link=payload.link,
|
||||
sort=payload.sort,
|
||||
language=payload.language,
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBannerApi(Resource):
|
||||
@console_ns.doc("delete_explore_banner")
|
||||
@console_ns.doc(description="Delete an explore banner")
|
||||
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
|
||||
@console_ns.response(204, "Banner deleted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
register_schema_models(console_ns, UpsertNotificationPayload, BatchAddNotificationAccountsPayload)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.stream.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.stream.seek(0)
|
||||
content = file.stream.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
@ -11,6 +11,7 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
@ -21,12 +22,6 @@ from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
@ -37,7 +32,7 @@ class ApiKeyItem(ResponseModel):
|
||||
@field_validator("last_used_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ApiKeyList(ResponseModel):
|
||||
|
||||
@ -3,7 +3,6 @@ import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -31,19 +30,16 @@ from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from models.workflow import resolve_workflow_kind
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
@ -181,12 +177,6 @@ class AppTracePayload(BaseModel):
|
||||
type JSONValue = Any
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class Tag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -203,7 +193,7 @@ class WorkflowPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ModelConfigPartial(ResponseModel):
|
||||
@ -217,7 +207,7 @@ class ModelConfigPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
@ -278,7 +268,7 @@ class ModelConfig(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class Site(ResponseModel):
|
||||
@ -321,7 +311,7 @@ class Site(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DeletedTool(ResponseModel):
|
||||
@ -355,9 +345,6 @@ class AppPartial(ResponseModel):
|
||||
create_user_name: str | None = None
|
||||
author_name: str | None = None
|
||||
has_draft_trigger: bool | None = None
|
||||
workflow_type: str | None = None
|
||||
workflow_kind: str | None = None
|
||||
permission_keys: list[str] = Field(default_factory=list)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -367,7 +354,7 @@ class AppPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetail(ResponseModel):
|
||||
@ -392,15 +379,12 @@ class AppDetail(ResponseModel):
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
access_mode: str | None = None
|
||||
workflow_type: str | None = None
|
||||
workflow_kind: str | None = None
|
||||
tags: list[Tag] = Field(default_factory=list)
|
||||
permission_keys: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetailWithSite(AppDetail):
|
||||
@ -428,22 +412,6 @@ class AppExportResponse(ResponseModel):
|
||||
data: str
|
||||
|
||||
|
||||
def _collect_app_access_permission_keys(access_matrix: enterprise_rbac_service.AppAccessMatrix) -> list[str]:
|
||||
permission_keys: list[str] = []
|
||||
seen_permission_keys: set[str] = set()
|
||||
|
||||
for item in access_matrix.items:
|
||||
if not item.policy:
|
||||
continue
|
||||
for permission_key in item.policy.permission_keys:
|
||||
if permission_key in seen_permission_keys:
|
||||
continue
|
||||
seen_permission_keys.add(permission_key)
|
||||
permission_keys.append(permission_key)
|
||||
|
||||
return permission_keys
|
||||
|
||||
|
||||
register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
|
||||
|
||||
register_schema_models(
|
||||
@ -529,20 +497,6 @@ class AppListApi(Resource):
|
||||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
if app_pagination.items:
|
||||
if dify_config.RBAC_ENABLED:
|
||||
app_ids = [str(app.id) for app in app_pagination.items]
|
||||
permission_keys_map = enterprise_rbac_service.RBACService.AppPermissions.batch_get(
|
||||
str(current_tenant_id),
|
||||
current_user.id,
|
||||
app_ids,
|
||||
)
|
||||
for app in app_pagination.items:
|
||||
app.permission_keys = permission_keys_map.get(str(app.id), [])
|
||||
else:
|
||||
for app in app_pagination.items:
|
||||
app.permission_keys = []
|
||||
|
||||
workflow_capable_app_ids = [
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
@ -574,25 +528,6 @@ class AppListApi(Resource):
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id]
|
||||
workflow_info_map: dict[str, tuple[str, str]] = {}
|
||||
if workflow_ids:
|
||||
rows = db.session.execute(
|
||||
select(Workflow.id, Workflow.type, Workflow.kind).where(Workflow.id.in_(workflow_ids))
|
||||
).all()
|
||||
workflow_info_map = {
|
||||
str(row.id): (
|
||||
row.type.value if hasattr(row.type, "value") else str(row.type),
|
||||
resolve_workflow_kind(row.kind).value,
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
for app in app_pagination.items:
|
||||
workflow_info = workflow_info_map.get(str(app.workflow_id)) if app.workflow_id else None
|
||||
app.workflow_type = workflow_info[0] if workflow_info else None
|
||||
app.workflow_kind = workflow_info[1] if workflow_info else None
|
||||
|
||||
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
|
||||
return pagination_model.model_dump(mode="json"), 200
|
||||
|
||||
@ -639,7 +574,6 @@ class AppApi(Resource):
|
||||
@get_app_model(mode=None)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app_service = AppService()
|
||||
|
||||
app_model = app_service.get_app(app_model)
|
||||
@ -648,29 +582,6 @@ class AppApi(Resource):
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
|
||||
|
||||
if app_model.workflow_id:
|
||||
row = db.session.execute(
|
||||
select(Workflow.type, Workflow.kind).where(Workflow.id == app_model.workflow_id)
|
||||
).first()
|
||||
app_model.workflow_type = (
|
||||
(row.type.value if hasattr(row.type, "value") else str(row.type)) if row else None
|
||||
)
|
||||
app_model.workflow_kind = resolve_workflow_kind(row.kind).value if row else None
|
||||
else:
|
||||
app_model.workflow_type = None
|
||||
app_model.workflow_kind = None
|
||||
|
||||
if dify_config.RBAC_ENABLED:
|
||||
app_access_matrix = enterprise_rbac_service.RBACService.AppAccess.matrix(
|
||||
str(current_tenant_id),
|
||||
current_user.id,
|
||||
str(app_model.id),
|
||||
)
|
||||
app_model.permission_keys = _collect_app_access_permission_keys(app_access_matrix)
|
||||
else:
|
||||
app_model.permission_keys = []
|
||||
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
@ -938,10 +849,11 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
"""Get app trace"""
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session)
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@ -955,12 +867,13 @@ class AppTraceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=str(app_id),
|
||||
app_id=app_model.id,
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from extensions.ext_database import db
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
@ -25,12 +26,6 @@ class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -65,7 +60,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class PaginatedConversationVariableResponse(ResponseModel):
|
||||
|
||||
@ -13,17 +13,12 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
@ -41,14 +36,14 @@ class AppMCPServerResponse(ResponseModel):
|
||||
name: str
|
||||
server_code: str
|
||||
description: str
|
||||
status: str
|
||||
status: AppMCPServerStatus
|
||||
parameters: dict[str, Any] | list[Any] | str
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def _parse_json_string(cls, value: Any) -> Any:
|
||||
def _normalize_parameters(cls, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
@ -59,7 +54,7 @@ class AppMCPServerResponse(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
|
||||
@ -70,7 +65,9 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@console_ns.doc(description="Get MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(
|
||||
200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__]
|
||||
)
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@ -85,7 +82,9 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(
|
||||
201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__]
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@ -111,13 +110,15 @@ class AppMCPServerController(Resource):
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201
|
||||
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(
|
||||
200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__]
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@get_app_model
|
||||
@ -154,7 +155,7 @@ class AppMCPServerRefreshController(Resource):
|
||||
@console_ns.doc("refresh_app_mcp_server")
|
||||
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@console_ns.doc(params={"server_id": "Server ID"})
|
||||
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@setup_required
|
||||
|
||||
@ -37,10 +37,9 @@ from fields.conversation_fields import (
|
||||
JSONValue,
|
||||
MessageFile,
|
||||
format_files_contained,
|
||||
to_timestamp,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
@ -144,9 +143,7 @@ class MessageDetailResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class MessageInfiniteScrollPaginationResponse(ResponseModel):
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
@ -9,8 +8,10 @@ from werkzeug.exceptions import BadRequest
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
@ -43,11 +44,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id: UUID):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@get_app_model
|
||||
def get(self, app_model: App):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
|
||||
trace_config = OpsService.get_tracing_app_config(
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider
|
||||
)
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
@ -65,13 +69,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
"""Create a new trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@ -90,13 +95,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def patch(self, app_model: App):
|
||||
"""Update an existing trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@ -113,12 +119,13 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def delete(self, app_model: App):
|
||||
"""Delete an existing trace app configuration"""
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -1,21 +1,17 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.schema import (
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0,
|
||||
register_response_schema_model,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@ -53,7 +49,7 @@ from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow, WorkflowKind
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
@ -161,23 +157,6 @@ class ConvertToWorkflowPayload(BaseModel):
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
keyword: str | None = Field(default=None, max_length=255)
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class WorkflowTypeConvertQuery(BaseModel):
|
||||
target_type: Literal["workflow", "evaluation"]
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
@ -199,26 +178,24 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(SyncDraftWorkflowPayload)
|
||||
reg(AdvancedChatWorkflowRunPayload)
|
||||
reg(IterationNodeRunPayload)
|
||||
reg(LoopNodeRunPayload)
|
||||
reg(DraftWorkflowRunPayload)
|
||||
reg(DraftWorkflowNodeRunPayload)
|
||||
reg(PublishWorkflowPayload)
|
||||
reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowTypeConvertQuery)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersPayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SyncDraftWorkflowPayload,
|
||||
AdvancedChatWorkflowRunPayload,
|
||||
IterationNodeRunPayload,
|
||||
LoopNodeRunPayload,
|
||||
DraftWorkflowRunPayload,
|
||||
DraftWorkflowNodeRunPayload,
|
||||
PublishWorkflowPayload,
|
||||
DefaultBlockConfigQuery,
|
||||
ConvertToWorkflowPayload,
|
||||
WorkflowListQuery,
|
||||
WorkflowUpdatePayload,
|
||||
WorkflowFeaturesPayload,
|
||||
WorkflowOnlineUsersPayload,
|
||||
DraftWorkflowTriggerRunPayload,
|
||||
DraftWorkflowTriggerRunAllPayload,
|
||||
)
|
||||
register_response_schema_model(console_ns, WorkflowRunNodeExecutionResponse)
|
||||
|
||||
|
||||
@ -898,54 +875,6 @@ class PublishedWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish/evaluation")
|
||||
class EvaluationPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("publish_evaluation_workflow")
|
||||
@console_ns.doc(description="Publish draft workflow as evaluation workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@console_ns.response(200, "Evaluation workflow published successfully")
|
||||
@console_ns.response(400, "Invalid workflow or unsupported node type")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Publish draft workflow as evaluation workflow.
|
||||
|
||||
Evaluation workflows cannot include trigger or human-input nodes.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = workflow_service.publish_evaluation_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
marked_name=args.marked_name or "",
|
||||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
# Keep workflow_id aligned with the latest published workflow.
|
||||
app_model_in_session = session.get(App, app_model.id)
|
||||
if app_model_in_session:
|
||||
app_model_in_session.workflow_id = workflow.id
|
||||
app_model_in_session.updated_by = current_user.id
|
||||
app_model_in_session.updated_at = naive_utc_now()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_default_block_configs")
|
||||
@ -1143,52 +1072,6 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/convert-type")
|
||||
class WorkflowTypeConvertApi(Resource):
|
||||
@console_ns.doc("convert_published_workflow_type")
|
||||
@console_ns.doc(description="Convert current effective published workflow type in-place")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__])
|
||||
@console_ns.response(200, "Workflow type converted successfully")
|
||||
@console_ns.response(400, "Invalid workflow type or unsupported workflow graph")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
target_type = WorkflowKind.EVALUATION if args.target_type == "evaluation" else WorkflowKind.STANDARD
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
workflow = workflow_service.convert_published_workflow_type(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
target_type=target_type,
|
||||
account=current_user,
|
||||
)
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"workflow_id": workflow.id,
|
||||
"type": workflow.type.value,
|
||||
"kind": workflow.kind_or_standard,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@console_ns.doc("update_workflow_by_id")
|
||||
|
||||
@ -16,6 +16,7 @@ from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
@ -82,9 +83,7 @@ class WorkflowRunForLogResponse(ResponseModel):
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowRunForArchivedLogResponse(ResponseModel):
|
||||
@ -104,28 +103,10 @@ class WorkflowRunForArchivedLogResponse(ResponseModel):
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
|
||||
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
|
||||
node_id: str
|
||||
type: str
|
||||
title: str
|
||||
|
||||
|
||||
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
|
||||
name: str
|
||||
value: Any = None
|
||||
details: dict[str, Any] | None = None
|
||||
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
|
||||
default=None,
|
||||
validation_alias="node_info",
|
||||
serialization_alias="nodeInfo",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: Any = None
|
||||
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
@ -135,14 +116,7 @@ class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
@field_validator("evaluation", mode="before")
|
||||
@classmethod
|
||||
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
|
||||
return value or []
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
@ -156,9 +130,7 @@ class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
@ -182,8 +154,6 @@ register_schema_models(
|
||||
WorkflowAppLogQuery,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowRunForArchivedLogResponse,
|
||||
WorkflowAppLogEvaluationNodeInfoResponse,
|
||||
WorkflowAppLogEvaluationItemResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowArchivedLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
|
||||
@ -1,22 +1,16 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
@ -51,6 +45,138 @@ class WorkflowCommentMentionUsersPayload(BaseModel):
|
||||
users: list[AccountWithRole]
|
||||
|
||||
|
||||
class WorkflowCommentAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
avatar: str | None = Field(default=None, exclude=True)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def avatar_url(self) -> str | None:
|
||||
return build_avatar_url(self.avatar)
|
||||
|
||||
|
||||
class WorkflowCommentReply(ResponseModel):
|
||||
id: str
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentMention(ResponseModel):
|
||||
mentioned_user_id: str
|
||||
mentioned_user_account: WorkflowCommentAccount | None = None
|
||||
reply_id: str | None = None
|
||||
|
||||
|
||||
class WorkflowCommentBasic(ResponseModel):
|
||||
id: str
|
||||
position_x: float
|
||||
position_y: float
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
resolved_by_account: WorkflowCommentAccount | None = None
|
||||
reply_count: int
|
||||
mention_count: int
|
||||
participants: list[WorkflowCommentAccount]
|
||||
|
||||
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentBasicList(ResponseModel):
|
||||
data: list[WorkflowCommentBasic]
|
||||
|
||||
|
||||
class WorkflowCommentDetail(ResponseModel):
|
||||
id: str
|
||||
position_x: float
|
||||
position_y: float
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
resolved_by_account: WorkflowCommentAccount | None = None
|
||||
replies: list[WorkflowCommentReply]
|
||||
mentions: list[WorkflowCommentMention]
|
||||
|
||||
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentCreate(ResponseModel):
|
||||
id: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentUpdate(ResponseModel):
|
||||
id: str
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentResolve(ResponseModel):
|
||||
id: str
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
|
||||
@field_validator("resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreate(ResponseModel):
|
||||
id: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdate(ResponseModel):
|
||||
id: str
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountWithRole,
|
||||
@ -59,17 +185,19 @@ register_schema_models(
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyPayload,
|
||||
)
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
WorkflowCommentAccount,
|
||||
WorkflowCommentReply,
|
||||
WorkflowCommentMention,
|
||||
WorkflowCommentBasic,
|
||||
WorkflowCommentBasicList,
|
||||
WorkflowCommentDetail,
|
||||
WorkflowCommentCreate,
|
||||
WorkflowCommentUpdate,
|
||||
WorkflowCommentResolve,
|
||||
WorkflowCommentReplyCreate,
|
||||
WorkflowCommentReplyUpdate,
|
||||
)
|
||||
|
||||
|
||||
@ -80,28 +208,26 @@ class WorkflowCommentListApi(Resource):
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@console_ns.response(200, "Comments retrieved successfully", console_ns.models[WorkflowCommentBasicList.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@console_ns.response(201, "Comment created successfully", console_ns.models[WorkflowCommentCreate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
@ -117,7 +243,7 @@ class WorkflowCommentListApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
return dump_response(WorkflowCommentCreate, result), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
@ -127,30 +253,28 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@console_ns.response(200, "Comment retrieved successfully", console_ns.models[WorkflowCommentDetail.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
return dump_response(WorkflowCommentDetail, comment)
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@console_ns.response(200, "Comment updated successfully", console_ns.models[WorkflowCommentUpdate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
@ -167,7 +291,7 @@ class WorkflowCommentDetailApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
return dump_response(WorkflowCommentUpdate, result)
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@ -197,12 +321,11 @@ class WorkflowCommentResolveApi(Resource):
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@console_ns.response(200, "Comment resolved successfully", console_ns.models[WorkflowCommentResolve.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
@ -213,7 +336,7 @@ class WorkflowCommentResolveApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
return dump_response(WorkflowCommentResolve, comment)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
@ -224,12 +347,11 @@ class WorkflowCommentReplyApi(Resource):
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@console_ns.response(201, "Reply created successfully", console_ns.models[WorkflowCommentReplyCreate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
@ -247,7 +369,7 @@ class WorkflowCommentReplyApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
return dump_response(WorkflowCommentReplyCreate, result), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
@ -258,12 +380,11 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@console_ns.response(200, "Reply updated successfully", console_ns.models[WorkflowCommentReplyUpdate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
@ -284,7 +405,7 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
return dump_response(WorkflowCommentReplyUpdate, reply)
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
@ -12,7 +10,6 @@ from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@ -80,39 +77,3 @@ class PartnerTenants(Resource):
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
||||
|
||||
_DEBUG_KEY = "billing:debug"
|
||||
_DEBUG_TTL = timedelta(days=7)
|
||||
|
||||
|
||||
class DebugDataPayload(BaseModel):
|
||||
type: str = Field(..., min_length=1, description="Data type key")
|
||||
data: str = Field(..., min_length=1, description="Data value to append")
|
||||
|
||||
|
||||
@console_ns.route("/billing/debug/data")
|
||||
class DebugData(Resource):
|
||||
def post(self):
|
||||
body = DebugDataPayload.model_validate(request.get_json(force=True))
|
||||
item = json.dumps({
|
||||
"type": body.type,
|
||||
"data": body.data,
|
||||
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
})
|
||||
redis_client.lpush(_DEBUG_KEY, item)
|
||||
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
|
||||
return {"result": "ok"}, 201
|
||||
|
||||
def get(self):
|
||||
recent = request.args.get("recent", 10, type=int)
|
||||
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
|
||||
return {
|
||||
"data": [
|
||||
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
|
||||
]
|
||||
}
|
||||
|
||||
def delete(self):
|
||||
redis_client.delete(_DEBUG_KEY)
|
||||
return {"result": "ok"}
|
||||
|
||||
@ -1,13 +1,10 @@
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -24,7 +21,6 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -33,7 +29,6 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
content_fields,
|
||||
@ -56,22 +51,13 @@ from fields.document_fields import document_status_fields
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
from services.errors.evaluation import (
|
||||
EvaluationDatasetInvalidError,
|
||||
EvaluationFrameworkNotConfiguredError,
|
||||
EvaluationMaxConcurrentRunsError,
|
||||
EvaluationNotFoundError,
|
||||
)
|
||||
from services.evaluation_service import EvaluationService
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
|
||||
@ -141,14 +127,6 @@ def _validate_doc_form(value: str | None) -> str | None:
|
||||
return value
|
||||
|
||||
|
||||
def _ensure_permission_keys(dataset: Dataset, *, enabled: bool) -> None:
|
||||
if not enabled:
|
||||
setattr(dataset, "permission_keys", [])
|
||||
return
|
||||
if not isinstance(getattr(dataset, "permission_keys", None), list):
|
||||
setattr(dataset, "permission_keys", [])
|
||||
|
||||
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field("", max_length=400)
|
||||
@ -351,19 +329,6 @@ class DatasetListApi(Resource):
|
||||
query.include_all,
|
||||
)
|
||||
|
||||
for dataset in datasets:
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
|
||||
if dify_config.RBAC_ENABLED and datasets:
|
||||
dataset_ids = [str(dataset.id) for dataset in datasets]
|
||||
permission_keys_map = enterprise_rbac_service.RBACService.DatasetPermissions.batch_get(
|
||||
str(current_tenant_id),
|
||||
current_user.id,
|
||||
dataset_ids,
|
||||
)
|
||||
for dataset in datasets:
|
||||
setattr(dataset, "permission_keys", permission_keys_map.get(str(dataset.id), []))
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||
@ -445,7 +410,6 @@ class DatasetListApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
@ -470,7 +434,6 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
@ -540,7 +503,6 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@ -1023,432 +985,3 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
|
||||
|
||||
|
||||
# ---- Knowledge Base Retrieval Evaluation ----
|
||||
|
||||
|
||||
def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"judgment": item.judgment_dict,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
|
||||
class DatasetEvaluationTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Download evaluation dataset template for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
|
||||
class DatasetEvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(
|
||||
session, current_tenant_id, "dataset", dataset_id_str
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
@console_ns.doc("save_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration saved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id):
|
||||
"""Save evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
|
||||
class DatasetEvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Start an evaluation run for the knowledge base retrieval."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
dataset_file_content=dataset_content,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
|
||||
class DatasetEvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation run history for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
|
||||
class DatasetEvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, run_id):
|
||||
"""Get evaluation run detail including per-item results."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return {
|
||||
"run": _serialize_dataset_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class DatasetEvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, run_id):
|
||||
"""Cancel a running knowledge base evaluation."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
|
||||
class DatasetEvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_metrics")
|
||||
@console_ns.response(200, "Available retrieval metrics retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get available evaluation metrics for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return {
|
||||
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.RETRIEVAL)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
|
||||
class DatasetEvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, file_id):
|
||||
"""Download evaluation test file or result file for the knowledge base."""
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
file_id_str = str(file_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id_str,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
@ -3,18 +3,19 @@ import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
@ -29,17 +30,16 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
document_fields,
|
||||
document_metadata_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
@ -72,27 +72,94 @@ from ..wraps import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_model = get_or_create_model("Dataset", dataset_fields)
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
return getattr(value, "value", value)
|
||||
|
||||
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
|
||||
|
||||
document_fields_copy = document_fields.copy()
|
||||
document_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_model = get_or_create_model("Document", document_fields_copy)
|
||||
class DatasetResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
permission: str | None = None
|
||||
data_source_type: str | None = None
|
||||
indexing_technique: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
document_with_segments_fields_copy = document_with_segments_fields.copy()
|
||||
document_with_segments_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
|
||||
@field_validator("data_source_type", "indexing_technique", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
|
||||
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
|
||||
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
|
||||
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentMetadataResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class DocumentResponse(ResponseModel):
|
||||
id: str
|
||||
position: int | None = None
|
||||
data_source_type: str | None = None
|
||||
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
|
||||
data_source_detail_dict: Any = None
|
||||
dataset_process_rule_id: str | None = None
|
||||
name: str
|
||||
created_from: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
tokens: int | None = None
|
||||
indexing_status: str | None = None
|
||||
error: str | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
archived: bool | None = None
|
||||
display_status: str | None = None
|
||||
word_count: int | None = None
|
||||
hit_count: int | None = None
|
||||
doc_form: str | None = None
|
||||
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
|
||||
summary_index_status: str | None = None
|
||||
need_summary: bool | None = None
|
||||
|
||||
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
@field_validator("doc_metadata", mode="before")
|
||||
@classmethod
|
||||
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "disabled_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentWithSegmentsResponse(DocumentResponse):
|
||||
process_rule_dict: Any = None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
|
||||
|
||||
class DatasetAndDocumentResponse(ResponseModel):
|
||||
dataset: DatasetResponse
|
||||
documents: list[DocumentResponse]
|
||||
batch: str
|
||||
|
||||
|
||||
class DocumentRetryPayload(BaseModel):
|
||||
@ -107,6 +174,11 @@ class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
|
||||
|
||||
class DocumentMetadataUpdatePayload(BaseModel):
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any = None
|
||||
|
||||
|
||||
class DocumentDatasetListParam(BaseModel):
|
||||
page: int = Field(1, title="Page", description="Page number.")
|
||||
limit: int = Field(20, title="Limit", description="Page size.")
|
||||
@ -124,7 +196,13 @@ register_schema_models(
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
GenerateSummaryPayload,
|
||||
DocumentMetadataUpdatePayload,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
DatasetResponse,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentWithSegmentsResponse,
|
||||
DatasetAndDocumentResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -360,10 +438,10 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
@ -401,7 +479,9 @@ class DatasetDocumentListApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return {"dataset": dataset, "documents": documents, "batch": batch}
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -429,12 +509,13 @@ class DatasetInitApi(Resource):
|
||||
@console_ns.doc("init_dataset")
|
||||
@console_ns.doc(description="Initialize dataset with documents")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||
@console_ns.response(
|
||||
201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
|
||||
)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@ -482,9 +563,9 @@ class DatasetInitApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
response = {"dataset": dataset, "documents": documents, "batch": batch}
|
||||
|
||||
return response
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
@ -991,15 +1072,7 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@console_ns.doc("update_document_metadata")
|
||||
@console_ns.doc(description="Update document metadata")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateDocumentMetadataRequest",
|
||||
{
|
||||
"doc_type": fields.String(description="Document type"),
|
||||
"doc_metadata": fields.Raw(description="Document metadata"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Document metadata updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@ -1012,10 +1085,10 @@ class DocumentMetadataApi(DocumentResource):
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
req_data = request.get_json()
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
doc_type = req_data.get("doc_type")
|
||||
doc_metadata = req_data.get("doc_metadata")
|
||||
doc_type = req_data.doc_type
|
||||
doc_metadata = req_data.doc_metadata
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1197,7 +1270,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(document_model)
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
@ -1215,7 +1288,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return document
|
||||
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
|
||||
@ -8,6 +8,7 @@ from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -19,12 +20,6 @@ from ..wraps import (
|
||||
)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
@ -61,7 +56,7 @@ class HitTestingSegment(ResponseModel):
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
|
||||
@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_query(query: Any) -> str:
|
||||
"""Return the user-visible query string from legacy and current response shapes."""
|
||||
if isinstance(query, str):
|
||||
return query
|
||||
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
@ -52,15 +49,15 @@ class DatasetsHitTestingBase:
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Coerce nullable collection fields into lists before response validation."""
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
if not isinstance(records, list):
|
||||
return []
|
||||
raise ValueError("Invalid hit testing records response")
|
||||
|
||||
normalized_records: list[dict[str, Any]] = []
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
raise ValueError("Invalid hit testing record response")
|
||||
|
||||
normalized_record = dict(record)
|
||||
segment = normalized_record.get("segment")
|
||||
@ -118,8 +115,8 @@ class DatasetsHitTestingBase:
|
||||
limit=10,
|
||||
)
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
|
||||
@ -1 +0,0 @@
|
||||
# Evaluation controller module
|
||||
@ -1,993 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.workflow import WorkflowListQuery
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.member_fields import simple_account_fields
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Dataset
|
||||
from models.evaluation import EvaluationTargetType
|
||||
from models.model import UploadFile
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.evaluation import (
|
||||
EvaluationDatasetInvalidError,
|
||||
EvaluationFrameworkNotConfiguredError,
|
||||
EvaluationMaxConcurrentRunsError,
|
||||
EvaluationNotFoundError,
|
||||
)
|
||||
from services.evaluation_service import EvaluationService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.evaluation import EvaluationRun, EvaluationRunItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVALUATE_TARGET_TYPES = {
|
||||
EvaluationTargetType.APPS.value,
|
||||
EvaluationTargetType.SNIPPETS.value,
|
||||
}
|
||||
|
||||
|
||||
class VersionQuery(BaseModel):
|
||||
"""Query parameters for version endpoint."""
|
||||
|
||||
version: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
VersionQuery,
|
||||
)
|
||||
|
||||
|
||||
# Response field definitions
|
||||
file_info_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_fields = {
|
||||
"created_at": TimestampField,
|
||||
"created_by": fields.String,
|
||||
"test_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationTestFile",
|
||||
file_info_fields,
|
||||
)
|
||||
),
|
||||
"result_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationResultFile",
|
||||
file_info_fields,
|
||||
),
|
||||
allow_null=True,
|
||||
),
|
||||
"version": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_list_model = console_ns.model(
|
||||
"EvaluationLogList",
|
||||
{
|
||||
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
|
||||
},
|
||||
)
|
||||
|
||||
evaluation_default_metric_node_info_fields = {
|
||||
"node_id": fields.String,
|
||||
"type": fields.String,
|
||||
"title": fields.String,
|
||||
}
|
||||
evaluation_default_metric_item_fields = {
|
||||
"metric": fields.String,
|
||||
"value_type": fields.String,
|
||||
"node_info_list": fields.List(
|
||||
fields.Nested(
|
||||
console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
customized_metrics_fields = {
|
||||
"evaluation_workflow_id": fields.String,
|
||||
"input_fields": fields.Raw,
|
||||
"output_fields": fields.Raw,
|
||||
}
|
||||
|
||||
judgment_condition_fields = {
|
||||
"variable_selector": fields.List(fields.String),
|
||||
"comparison_operator": fields.String,
|
||||
"value": fields.String,
|
||||
}
|
||||
|
||||
judgment_config_fields = {
|
||||
"logical_operator": fields.String,
|
||||
"conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))),
|
||||
}
|
||||
|
||||
evaluation_detail_fields = {
|
||||
"evaluation_model": fields.String,
|
||||
"evaluation_model_provider": fields.String,
|
||||
"default_metrics": fields.List(
|
||||
fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)),
|
||||
allow_null=True,
|
||||
),
|
||||
"customized_metrics": fields.Nested(
|
||||
console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
"judgment_config": fields.Nested(
|
||||
console_ns.model("EvaluationJudgmentConfig", judgment_config_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
}
|
||||
|
||||
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
|
||||
|
||||
available_evaluation_workflow_list_fields = {
|
||||
"id": fields.String,
|
||||
"app_id": fields.String,
|
||||
"app_name": fields.String,
|
||||
"type": fields.String,
|
||||
"kind": fields.String,
|
||||
"version": fields.String,
|
||||
"marked_name": fields.String,
|
||||
"marked_comment": fields.String,
|
||||
"hash": fields.String,
|
||||
"created_by": fields.Nested(simple_account_fields),
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
available_evaluation_workflow_pagination_fields = {
|
||||
"items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)),
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
}
|
||||
|
||||
available_evaluation_workflow_pagination_model = console_ns.model(
|
||||
"AvailableEvaluationWorkflowPagination",
|
||||
available_evaluation_workflow_pagination_fields,
|
||||
)
|
||||
|
||||
evaluation_default_metrics_response_model = console_ns.model(
|
||||
"EvaluationDefaultMetricsResponse",
|
||||
{
|
||||
"default_metrics": fields.List(
|
||||
fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
evaluation_dataset_columns_response_model = console_ns.model(
|
||||
"EvaluationDatasetColumnsResponse",
|
||||
{
|
||||
"columns": fields.List(
|
||||
fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationTemplateColumn",
|
||||
{
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_evaluation_target[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to resolve polymorphic evaluation target (apps or snippets).
|
||||
|
||||
Validates the target_type parameter and fetches the corresponding
|
||||
model (App or CustomizedSnippet) with tenant isolation.
|
||||
"""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
target_type = kwargs.get("evaluate_target_type")
|
||||
target_id = kwargs.get("evaluate_target_id")
|
||||
|
||||
if target_type not in EVALUATE_TARGET_TYPES:
|
||||
raise NotFound(f"Invalid evaluation target type: {target_type}")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
target_id = str(target_id)
|
||||
|
||||
# Remove path parameters
|
||||
del kwargs["evaluate_target_type"]
|
||||
del kwargs["evaluate_target_id"]
|
||||
|
||||
target: Union[App, CustomizedSnippet] | None = None
|
||||
|
||||
if target_type == EvaluationTargetType.APPS.value:
|
||||
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
|
||||
elif target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
target = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not target:
|
||||
raise NotFound(f"{str(target_type)} not found")
|
||||
|
||||
kwargs["target"] = target
|
||||
kwargs["target_type"] = target_type
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def _load_evaluation_run_request_and_dataset(tenant_id: str) -> tuple[EvaluationRunRequest, bytes, str]:
|
||||
"""Validate the run payload and load the uploaded dataset bytes."""
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
upload_file = db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=tenant_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
return run_request, dataset_content, upload_file.name
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
|
||||
class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_dataset_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(400, "Invalid target type or excluded app mode")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Download evaluation dataset template.
|
||||
|
||||
Generates an XLSX template based on the target's input parameters
|
||||
and streams it directly as a file attachment.
|
||||
"""
|
||||
try:
|
||||
xlsx_content, filename = EvaluationService.generate_dataset_template(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
|
||||
class EvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation configuration for the target.
|
||||
|
||||
Returns evaluation configuration including model settings,
|
||||
metrics config, and judgement conditions.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
|
||||
}
|
||||
|
||||
@console_ns.doc("save_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation configuration saved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Save evaluation configuration for the target.
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
body = request.get_json(force=True)
|
||||
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/template-columns")
|
||||
class EvaluationTemplateColumnsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_template_columns")
|
||||
@console_ns.response(200, "Evaluation dataset columns resolved", evaluation_dataset_columns_response_model)
|
||||
@console_ns.response(400, "Invalid request body")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""Return the dataset template columns implied by the current evaluation config."""
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
return {
|
||||
"columns": EvaluationService.get_dataset_column_names(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
data=config_data,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
|
||||
class EvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation run history for the target.
|
||||
|
||||
Returns a paginated list of evaluation runs.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run1")
|
||||
class EvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Expects JSON body with:
|
||||
- file_id: uploaded dataset file ID
|
||||
- evaluation_model: evaluation model name
|
||||
- evaluation_model_provider: evaluation model provider
|
||||
- default_metrics: list of default metric objects
|
||||
- customized_metrics: customized metrics object (optional)
|
||||
- judgment_config: judgment conditions config (optional)
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
if target_type == EvaluationTargetType.APPS.value:
|
||||
evaluation_run = EvaluationService.start_stub_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
dataset_filename=dataset_filename,
|
||||
run_request=run_request,
|
||||
)
|
||||
else:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
dataset_filename=dataset_filename,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
|
||||
class EvaluationRunRealApi(Resource):
|
||||
@console_ns.doc("start_evaluation_run_real")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
|
||||
"""Start the real evaluation execution flow on the temporary dev path."""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
dataset_filename=dataset_filename,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
|
||||
class EvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""
|
||||
Get evaluation run detail including items.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"run": _serialize_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class EvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""Cancel a running evaluation."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
return _serialize_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
|
||||
class EvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics retrieved")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get available evaluation metrics for the current framework.
|
||||
"""
|
||||
result = {}
|
||||
for category in EvaluationCategory:
|
||||
if category in EvaluationService.CONSOLE_DISABLED_CATEGORIES:
|
||||
continue
|
||||
result[category.value] = EvaluationService.get_supported_metrics(category)
|
||||
return {"metrics": result}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/default-metrics")
|
||||
class EvaluationDefaultMetricsApi(Resource):
|
||||
@console_ns.doc(
|
||||
"get_evaluation_default_metrics_with_nodes",
|
||||
description=(
|
||||
"List default metrics supported by the current evaluation framework with matching nodes "
|
||||
"from the target's published workflow only (draft is ignored)."
|
||||
),
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Default metrics and node candidates for the published workflow",
|
||||
evaluation_default_metrics_response_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
return {
|
||||
"default_metrics": [
|
||||
m.model_dump() for m in EvaluationService.filter_console_default_metrics(default_metrics)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
|
||||
class EvaluationNodeInfoApi(Resource):
|
||||
@console_ns.doc("get_evaluation_node_info")
|
||||
@console_ns.response(200, "Node info grouped by metric")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""Return workflow/snippet node info grouped by requested metrics.
|
||||
|
||||
Request body (JSON):
|
||||
- metrics: list[str] | None – metric names to query; omit or pass
|
||||
an empty list to get all nodes under key ``"all"``.
|
||||
|
||||
Response:
|
||||
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
|
||||
"""
|
||||
body = request.get_json(silent=True) or {}
|
||||
metrics: list[str] | None = body.get("metrics") or None
|
||||
|
||||
result = EvaluationService.get_nodes_for_metrics(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
metrics=metrics,
|
||||
)
|
||||
if not metrics:
|
||||
result = {
|
||||
"all": [
|
||||
node
|
||||
for node in result.get("all", [])
|
||||
if node.get("type") not in EvaluationService.CONSOLE_DISABLED_CATEGORIES
|
||||
]
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
metric: nodes
|
||||
for metric, nodes in result.items()
|
||||
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/evaluation/available-metrics")
|
||||
class EvaluationAvailableMetricsApi(Resource):
|
||||
@console_ns.doc("get_available_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics list")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Return the centrally-defined list of evaluation metrics."""
|
||||
return {
|
||||
"metrics": [
|
||||
metric
|
||||
for metric in EvaluationService.get_available_metrics()
|
||||
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
|
||||
class EvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated successfully")
|
||||
@console_ns.response(404, "Target or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
|
||||
"""
|
||||
Download evaluation test file or result file.
|
||||
|
||||
Looks up the specified file, verifies it belongs to the same tenant,
|
||||
and returns file info and download URL.
|
||||
"""
|
||||
file_id = str(file_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
|
||||
class EvaluationVersionApi(Resource):
|
||||
@console_ns.doc("get_evaluation_version_detail")
|
||||
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
|
||||
@console_ns.response(200, "Version details retrieved successfully")
|
||||
@console_ns.response(404, "Target or version not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation target version details.
|
||||
|
||||
Returns the workflow graph for the specified version.
|
||||
"""
|
||||
version = request.args.get("version")
|
||||
|
||||
if not version:
|
||||
return {"message": "version parameter is required"}, 400
|
||||
|
||||
graph = {}
|
||||
if target_type == EvaluationTargetType.SNIPPETS.value and isinstance(target, CustomizedSnippet):
|
||||
graph = target.graph_dict
|
||||
|
||||
return {
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/available-evaluation-workflows")
|
||||
class AvailableEvaluationWorkflowsApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@console_ns.doc("list_available_evaluation_workflows")
|
||||
@console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Available evaluation workflows retrieved",
|
||||
available_evaluation_workflow_pagination_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self):
|
||||
"""List published evaluation-type workflows for the current tenant (cross-app)."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
named_only = args.named_only
|
||||
keyword = args.keyword
|
||||
|
||||
if user_id and user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = workflow_service.list_published_evaluation_workflows(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
keyword=keyword,
|
||||
)
|
||||
|
||||
app_ids = {w.app_id for w in workflows}
|
||||
if app_ids:
|
||||
apps = session.scalars(select(App).where(App.id.in_(app_ids))).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
else:
|
||||
app_names = {}
|
||||
|
||||
items = []
|
||||
for wf in workflows:
|
||||
items.append(
|
||||
{
|
||||
"id": wf.id,
|
||||
"app_id": wf.app_id,
|
||||
"app_name": app_names.get(wf.app_id, ""),
|
||||
"type": wf.type.value,
|
||||
"kind": wf.kind_or_standard,
|
||||
"version": wf.version,
|
||||
"marked_name": wf.marked_name,
|
||||
"marked_comment": wf.marked_comment,
|
||||
"hash": wf.unique_hash,
|
||||
"created_by": wf.created_by_account,
|
||||
"created_at": wf.created_at,
|
||||
"updated_by": wf.updated_by_account,
|
||||
"updated_at": wf.updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
marshal(
|
||||
{"items": items, "page": page, "limit": limit, "has_more": has_more},
|
||||
available_evaluation_workflow_pagination_fields,
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/evaluation-workflows/<string:workflow_id>/associated-targets")
|
||||
class EvaluationWorkflowAssociatedTargetsApi(Resource):
|
||||
@console_ns.doc("list_evaluation_workflow_associated_targets")
|
||||
@console_ns.doc(
|
||||
description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics"
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, workflow_id: str):
|
||||
"""Return all evaluation targets that reference this workflow as customized metrics."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
configs = EvaluationService.list_targets_by_customized_workflow(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
customized_workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
target_ids_by_type: dict[str, list[str]] = {}
|
||||
for cfg in configs:
|
||||
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
|
||||
|
||||
app_names: dict[str, str] = {}
|
||||
if EvaluationTargetType.APPS.value in target_ids_by_type:
|
||||
apps = session.scalars(
|
||||
select(App).where(App.id.in_(target_ids_by_type[EvaluationTargetType.APPS.value]))
|
||||
).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
|
||||
snippet_names: dict[str, str] = {}
|
||||
if "snippets" in target_ids_by_type:
|
||||
snippets = session.scalars(
|
||||
select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"]))
|
||||
).all()
|
||||
snippet_names = {s.id: s.name for s in snippets}
|
||||
|
||||
dataset_names: dict[str, str] = {}
|
||||
if "knowledge_base" in target_ids_by_type:
|
||||
datasets = session.scalars(
|
||||
select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"]))
|
||||
).all()
|
||||
dataset_names = {d.id: d.name for d in datasets}
|
||||
|
||||
items = []
|
||||
for cfg in configs:
|
||||
name = ""
|
||||
if cfg.target_type == EvaluationTargetType.APPS.value:
|
||||
name = app_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
name = snippet_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "knowledge_base":
|
||||
name = dataset_names.get(cfg.target_id, "")
|
||||
|
||||
items.append(
|
||||
{
|
||||
"target_type": cfg.target_type,
|
||||
"target_id": cfg.target_id,
|
||||
"target_name": name,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": items}, 200
|
||||
|
||||
|
||||
# ---- Serialization Helpers ----
|
||||
|
||||
|
||||
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": run.metrics_summary_dict,
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"judgment": item.judgment_dict,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
|
||||
@ -16,6 +16,7 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
@ -105,9 +106,7 @@ class InstalledAppResponse(ResponseModel):
|
||||
@field_validator("last_used_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class InstalledAppListResponse(ResponseModel):
|
||||
|
||||
@ -64,28 +64,15 @@ class RecommendedAppListResponse(ResponseModel):
|
||||
categories: list[str]
|
||||
|
||||
|
||||
class LearnDifyAppListResponse(ResponseModel):
|
||||
recommended_apps: list[RecommendedAppResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RecommendedAppsQuery,
|
||||
RecommendedAppInfoResponse,
|
||||
RecommendedAppResponse,
|
||||
RecommendedAppListResponse,
|
||||
LearnDifyAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_language(language: str | None) -> str:
|
||||
if language and language in languages:
|
||||
return language
|
||||
if current_user and current_user.interface_language:
|
||||
return current_user.interface_language
|
||||
return languages[0]
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
|
||||
@ -95,7 +82,13 @@ class RecommendedAppListApi(Resource):
|
||||
def get(self):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language_prefix = _resolve_language(args.language)
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
language_prefix = current_user.interface_language
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
return RecommendedAppListResponse.model_validate(
|
||||
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
|
||||
@ -103,22 +96,6 @@ class RecommendedAppListApi(Resource):
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/learn-dify")
|
||||
class LearnDifyAppListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[LearnDifyAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language_prefix = _resolve_language(args.language)
|
||||
|
||||
return LearnDifyAppListResponse.model_validate(
|
||||
RecommendedAppService.get_learn_dify_apps(language_prefix),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/<uuid:app_id>")
|
||||
class RecommendedAppApi(Resource):
|
||||
@login_required
|
||||
|
||||
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
@ -40,12 +41,6 @@ def _mask_api_key(api_key: str) -> str:
|
||||
return api_key[:3] + "******" + api_key[-3:]
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class APIBasedExtensionResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -61,7 +56,7 @@ class APIBasedExtensionResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
|
||||
|
||||
@ -105,7 +105,8 @@ class FilePreviewApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
text = FileService(db.engine).get_file_preview(file_id, tenant_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -1,142 +0,0 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SnippetListQuery(BaseModel):
|
||||
"""Query parameters for listing snippets."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
is_published: bool | None = Field(default=None, description="Filter by published status")
|
||||
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
|
||||
|
||||
@field_validator("creators", mode="before")
|
||||
@classmethod
|
||||
def parse_creators(cls, value: object) -> list[str] | None:
|
||||
"""Normalize creators filter from query string or list input."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return [creator.strip() for creator in value.split(",") if creator.strip()] or None
|
||||
if isinstance(value, list):
|
||||
return [str(creator).strip() for creator in value if str(creator).strip()] or None
|
||||
return None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
"""Icon information model."""
|
||||
|
||||
icon: str | None = None
|
||||
icon_type: Literal["emoji", "image"] | None = None
|
||||
icon_background: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class InputFieldDefinition(BaseModel):
|
||||
"""Input field definition for snippet parameters."""
|
||||
|
||||
default: str | None = None
|
||||
hint: bool | None = None
|
||||
label: str | None = None
|
||||
max_length: int | None = None
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
required: bool | None = None
|
||||
type: str | None = None # e.g., "text-input"
|
||||
|
||||
|
||||
class CreateSnippetPayload(BaseModel):
|
||||
"""Payload for creating a new snippet."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
type: Literal["node", "group"] = "node"
|
||||
icon_info: IconInfo | None = None
|
||||
graph: dict[str, Any] | None = None
|
||||
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSnippetPayload(BaseModel):
|
||||
"""Payload for updating a snippet."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
icon_info: IconInfo | None = None
|
||||
|
||||
|
||||
class SnippetDraftSyncPayload(BaseModel):
|
||||
"""Payload for syncing snippet draft workflow."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description="Ignored. Snippet workflows do not persist conversation variables.",
|
||||
)
|
||||
input_fields: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetWorkflowListQuery(BaseModel):
|
||||
"""Query parameters for listing snippet published workflows."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SnippetDraftRunPayload(BaseModel):
|
||||
"""Payload for running snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetDraftNodeRunPayload(BaseModel):
|
||||
"""Payload for running a single node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetIterationNodeRunPayload(BaseModel):
|
||||
"""Payload for running an iteration node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetLoopNodeRunPayload(BaseModel):
|
||||
"""Payload for running a loop node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
"""Payload for publishing snippet workflow."""
|
||||
|
||||
knowledge_base_setting: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetImportPayload(BaseModel):
|
||||
"""Payload for importing snippet from DSL."""
|
||||
|
||||
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
|
||||
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
|
||||
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
|
||||
name: str | None = Field(default=None, description="Override snippet name")
|
||||
description: str | None = Field(default=None, description="Override snippet description")
|
||||
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
"""Query parameter for including secret variables in export."""
|
||||
|
||||
include_secret: str = Field(default="false", description="Whether to include secret variables")
|
||||
@ -1,617 +0,0 @@
|
||||
# import logging
|
||||
# from collections.abc import Callable
|
||||
# from functools import wraps
|
||||
|
||||
# from flask import request
|
||||
# from flask_restx import Resource, fields, marshal, marshal_with
|
||||
# from sqlalchemy.orm import Session
|
||||
# from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
# from controllers.common.schema import register_schema_models
|
||||
# from controllers.console import console_ns
|
||||
# from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
# from controllers.console.app.workflow import (
|
||||
# RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
|
||||
# workflow_model,
|
||||
# workflow_pagination_model,
|
||||
# )
|
||||
# from controllers.console.app.workflow_run import (
|
||||
# workflow_run_detail_model,
|
||||
# workflow_run_node_execution_list_model,
|
||||
# workflow_run_node_execution_model,
|
||||
# workflow_run_pagination_model,
|
||||
# )
|
||||
# from controllers.console.snippets.payloads import (
|
||||
# PublishWorkflowPayload,
|
||||
# SnippetDraftNodeRunPayload,
|
||||
# SnippetDraftRunPayload,
|
||||
# SnippetDraftSyncPayload,
|
||||
# SnippetIterationNodeRunPayload,
|
||||
# SnippetLoopNodeRunPayload,
|
||||
# SnippetWorkflowListQuery,
|
||||
# WorkflowRunQuery,
|
||||
# )
|
||||
# from controllers.console.wraps import (
|
||||
# account_initialization_required,
|
||||
# edit_permission_required,
|
||||
# setup_required,
|
||||
# )
|
||||
# from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
# from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
# from extensions.ext_database import db
|
||||
# from extensions.ext_redis import redis_client
|
||||
# from graphon.graph_engine.manager import GraphEngineManager
|
||||
# from libs import helper
|
||||
# from libs.helper import TimestampField
|
||||
# from libs.login import current_account_with_tenant, login_required
|
||||
# from models.snippet import CustomizedSnippet
|
||||
# from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
# from services.snippet_generate_service import SnippetGenerateService
|
||||
# from services.snippet_service import SnippetService
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# # Register Pydantic models with Swagger
|
||||
# register_schema_models(
|
||||
# console_ns,
|
||||
# SnippetDraftSyncPayload,
|
||||
# SnippetDraftNodeRunPayload,
|
||||
# SnippetDraftRunPayload,
|
||||
# SnippetIterationNodeRunPayload,
|
||||
# SnippetLoopNodeRunPayload,
|
||||
# SnippetWorkflowListQuery,
|
||||
# WorkflowRunQuery,
|
||||
# PublishWorkflowPayload,
|
||||
# )
|
||||
|
||||
|
||||
# snippet_workflow_model = console_ns.clone("SnippetWorkflow", workflow_model, {
|
||||
# "input_fields": fields.Raw(default=[]),
|
||||
# })
|
||||
|
||||
|
||||
# class SnippetNotFoundError(Exception):
|
||||
# """Snippet not found error."""
|
||||
|
||||
# pass
|
||||
|
||||
|
||||
# def get_snippet[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
# """Decorator to fetch and validate snippet access."""
|
||||
|
||||
# @wraps(view_func)
|
||||
# def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# if not kwargs.get("snippet_id"):
|
||||
# raise ValueError("missing snippet_id in path parameters")
|
||||
|
||||
# _, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# snippet_id = str(kwargs.get("snippet_id"))
|
||||
# del kwargs["snippet_id"]
|
||||
|
||||
# snippet = SnippetService.get_snippet_by_id(
|
||||
# snippet_id=snippet_id,
|
||||
# tenant_id=current_tenant_id,
|
||||
# )
|
||||
|
||||
# if not snippet:
|
||||
# raise NotFound("Snippet not found")
|
||||
|
||||
# kwargs["snippet"] = snippet
|
||||
|
||||
# return view_func(*args, **kwargs)
|
||||
|
||||
# return decorated_view
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||
# class SnippetDraftWorkflowApi(Resource):
|
||||
# @console_ns.doc("get_snippet_draft_workflow")
|
||||
# @console_ns.response(200, "Draft workflow retrieved successfully", snippet_workflow_model)
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# @marshal_with(snippet_workflow_model)
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get draft workflow for snippet."""
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
# if not workflow:
|
||||
# raise DraftWorkflowNotExist()
|
||||
|
||||
# db.session.expunge(workflow)
|
||||
# workflow.conversation_variables = []
|
||||
# workflow.input_fields = snippet.input_fields_list
|
||||
# return workflow
|
||||
|
||||
# @console_ns.doc("sync_snippet_draft_workflow")
|
||||
# @console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||
# @console_ns.response(200, "Draft workflow synced successfully")
|
||||
# @console_ns.response(400, "Hash mismatch")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet):
|
||||
# """Sync draft workflow for snippet."""
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
|
||||
# payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# try:
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.sync_draft_workflow(
|
||||
# snippet=snippet,
|
||||
# graph=payload.graph,
|
||||
# unique_hash=payload.hash,
|
||||
# account=current_user,
|
||||
# input_fields=payload.input_fields,
|
||||
# )
|
||||
# except WorkflowHashNotEqualError:
|
||||
# raise DraftWorkflowNotSync()
|
||||
# except ValueError as e:
|
||||
# return {"message": str(e)}, 400
|
||||
|
||||
# return {
|
||||
# "result": "success",
|
||||
# "hash": workflow.unique_hash,
|
||||
# "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||
# class SnippetDraftConfigApi(Resource):
|
||||
# @console_ns.doc("get_snippet_draft_config")
|
||||
# @console_ns.response(200, "Draft config retrieved successfully")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get snippet draft workflow configuration limits."""
|
||||
# return {
|
||||
# "parallel_depth_limit": 3,
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||
# class SnippetPublishedWorkflowApi(Resource):
|
||||
# @console_ns.doc("get_snippet_published_workflow")
|
||||
# @console_ns.response(200, "Published workflow retrieved successfully", snippet_workflow_model)
|
||||
# @console_ns.response(404, "Snippet not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# @marshal_with(snippet_workflow_model)
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get published workflow for snippet."""
|
||||
# if not snippet.is_published:
|
||||
# return None
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
|
||||
# if workflow:
|
||||
# workflow.input_fields = snippet.input_fields_list
|
||||
|
||||
# return workflow
|
||||
|
||||
# @console_ns.doc("publish_snippet_workflow")
|
||||
# @console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||
# @console_ns.response(200, "Workflow published successfully")
|
||||
# @console_ns.response(400, "No draft workflow found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet):
|
||||
# """Publish snippet workflow."""
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# snippet_service = SnippetService()
|
||||
|
||||
# with Session(db.engine) as session:
|
||||
# snippet = session.merge(snippet)
|
||||
# try:
|
||||
# workflow = snippet_service.publish_workflow(
|
||||
# session=session,
|
||||
# snippet=snippet,
|
||||
# account=current_user,
|
||||
# )
|
||||
# workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
# session.commit()
|
||||
# except ValueError as e:
|
||||
# return {"message": str(e)}, 400
|
||||
|
||||
# return {
|
||||
# "result": "success",
|
||||
# "created_at": workflow_created_at,
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||
# class SnippetDefaultBlockConfigsApi(Resource):
|
||||
# @console_ns.doc("get_snippet_default_block_configs")
|
||||
# @console_ns.response(200, "Default block configs retrieved successfully")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get default block configurations for snippet workflow."""
|
||||
# snippet_service = SnippetService()
|
||||
# return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows")
|
||||
# class SnippetPublishedAllWorkflowApi(Resource):
|
||||
# @console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
|
||||
# @console_ns.doc("get_all_snippet_published_workflows")
|
||||
# @console_ns.doc(description="Get all published workflows for a snippet")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
# @console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get all published workflow versions for snippet."""
|
||||
# args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# with Session(db.engine) as session:
|
||||
# workflows, has_more = snippet_service.get_all_published_workflows(
|
||||
# session=session,
|
||||
# snippet=snippet,
|
||||
# page=args.page,
|
||||
# limit=args.limit,
|
||||
# )
|
||||
# serialized_workflows = marshal(workflows, workflow_model)
|
||||
|
||||
# return {
|
||||
# "items": serialized_workflows,
|
||||
# "page": args.page,
|
||||
# "limit": args.limit,
|
||||
# "has_more": has_more,
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
|
||||
# class SnippetDraftWorkflowRestoreApi(Resource):
|
||||
# @console_ns.doc("restore_snippet_workflow_to_draft")
|
||||
# @console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
|
||||
# @console_ns.response(200, "Workflow restored successfully")
|
||||
# @console_ns.response(400, "Source workflow must be published")
|
||||
# @console_ns.response(404, "Workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, workflow_id: str):
|
||||
# """Restore a published snippet workflow version into the draft workflow."""
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# snippet_service = SnippetService()
|
||||
|
||||
# try:
|
||||
# workflow = snippet_service.restore_published_workflow_to_draft(
|
||||
# snippet=snippet,
|
||||
# workflow_id=workflow_id,
|
||||
# account=current_user,
|
||||
# )
|
||||
# except IsDraftWorkflowError as exc:
|
||||
# raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
# except WorkflowNotFoundError as exc:
|
||||
# raise NotFound(str(exc)) from exc
|
||||
# except ValueError as exc:
|
||||
# raise BadRequest(str(exc)) from exc
|
||||
|
||||
# return {
|
||||
# "result": "success",
|
||||
# "hash": workflow.unique_hash,
|
||||
# "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
# class SnippetWorkflowRunsApi(Resource):
|
||||
# @console_ns.doc("list_snippet_workflow_runs")
|
||||
# @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_pagination_model)
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """List workflow runs for snippet."""
|
||||
# query = WorkflowRunQuery.model_validate(
|
||||
# {
|
||||
# "last_id": request.args.get("last_id"),
|
||||
# "limit": request.args.get("limit", type=int, default=20),
|
||||
# }
|
||||
# )
|
||||
# args = {
|
||||
# "last_id": query.last_id,
|
||||
# "limit": query.limit,
|
||||
# }
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||
|
||||
# return result
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||
# class SnippetWorkflowRunDetailApi(Resource):
|
||||
# @console_ns.doc("get_snippet_workflow_run_detail")
|
||||
# @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
# @console_ns.response(404, "Workflow run not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_detail_model)
|
||||
# def get(self, snippet: CustomizedSnippet, run_id):
|
||||
# """Get workflow run detail for snippet."""
|
||||
# run_id = str(run_id)
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
|
||||
# if not workflow_run:
|
||||
# raise NotFound("Workflow run not found")
|
||||
|
||||
# return workflow_run
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
# class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||
# @console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||
# @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_node_execution_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet, run_id):
|
||||
# """List node executions for a workflow run."""
|
||||
# run_id = str(run_id)
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||
# snippet=snippet,
|
||||
# run_id=run_id,
|
||||
# )
|
||||
|
||||
# return {"data": node_executions}
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
# class SnippetDraftNodeRunApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_node")
|
||||
# @console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
|
||||
# @console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_node_execution_model)
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Run a single node in snippet draft workflow.
|
||||
|
||||
# Executes a specific node with provided inputs for single-step debugging.
|
||||
# Returns the node execution result including status, outputs, and timing.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# user_inputs = payload.inputs
|
||||
|
||||
# # Get draft workflow for file parsing
|
||||
# snippet_service = SnippetService()
|
||||
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if not draft_workflow:
|
||||
# raise NotFound("Draft workflow not found")
|
||||
|
||||
# files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
|
||||
|
||||
# workflow_node_execution = SnippetGenerateService.run_draft_node(
|
||||
# snippet=snippet,
|
||||
# node_id=node_id,
|
||||
# user_inputs=user_inputs,
|
||||
# account=current_user,
|
||||
# query=payload.query,
|
||||
# files=files,
|
||||
# )
|
||||
|
||||
# return workflow_node_execution
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
# class SnippetDraftNodeLastRunApi(Resource):
|
||||
# @console_ns.doc("get_snippet_draft_node_last_run")
|
||||
# @console_ns.doc(description="Get last run result for a node in snippet draft workflow")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
|
||||
# @console_ns.response(404, "Snippet, draft workflow, or node last run not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_node_execution_model)
|
||||
# def get(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Get the last run result for a specific node in snippet draft workflow.
|
||||
|
||||
# Returns the most recent execution record for the given node,
|
||||
# including status, inputs, outputs, and timing information.
|
||||
# """
|
||||
# snippet_service = SnippetService()
|
||||
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if not draft_workflow:
|
||||
# raise NotFound("Draft workflow not found")
|
||||
|
||||
# node_exec = snippet_service.get_snippet_node_last_run(
|
||||
# snippet=snippet,
|
||||
# workflow=draft_workflow,
|
||||
# node_id=node_id,
|
||||
# )
|
||||
# if node_exec is None:
|
||||
# raise NotFound("Node last run not found")
|
||||
|
||||
# return node_exec
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
# class SnippetDraftRunIterationNodeApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_iteration_node")
|
||||
# @console_ns.doc(description="Run draft workflow iteration node for snippet")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
|
||||
# @console_ns.response(200, "Iteration node run started successfully (SSE stream)")
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Run a draft workflow iteration node for snippet.
|
||||
|
||||
# Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
# Returns an SSE event stream with iteration progress and results.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
# try:
|
||||
# response = SnippetGenerateService.generate_single_iteration(
|
||||
# snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
# )
|
||||
|
||||
# return helper.compact_generate_response(response)
|
||||
# except ValueError as e:
|
||||
# raise e
|
||||
# except Exception:
|
||||
# logger.exception("internal server error.")
|
||||
# raise InternalServerError()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
# class SnippetDraftRunLoopNodeApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_loop_node")
|
||||
# @console_ns.doc(description="Run draft workflow loop node for snippet")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
|
||||
# @console_ns.response(200, "Loop node run started successfully (SSE stream)")
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Run a draft workflow loop node for snippet.
|
||||
|
||||
# Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
# Returns an SSE event stream with loop progress and results.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# try:
|
||||
# response = SnippetGenerateService.generate_single_loop(
|
||||
# snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
# )
|
||||
|
||||
# return helper.compact_generate_response(response)
|
||||
# except ValueError as e:
|
||||
# raise e
|
||||
# except Exception:
|
||||
# logger.exception("internal server error.")
|
||||
# raise InternalServerError()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
|
||||
# class SnippetDraftWorkflowRunApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_workflow")
|
||||
# @console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
|
||||
# @console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet):
|
||||
# """
|
||||
# Run draft workflow for snippet.
|
||||
|
||||
# Executes the snippet's draft workflow with the provided inputs
|
||||
# and returns an SSE event stream with execution progress and results.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
|
||||
# payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
# args = payload.model_dump(exclude_none=True)
|
||||
|
||||
# try:
|
||||
# response = SnippetGenerateService.generate(
|
||||
# snippet=snippet,
|
||||
# user=current_user,
|
||||
# args=args,
|
||||
# invoke_from=InvokeFrom.DEBUGGER,
|
||||
# streaming=True,
|
||||
# )
|
||||
|
||||
# return helper.compact_generate_response(response)
|
||||
# except ValueError as e:
|
||||
# raise e
|
||||
# except Exception:
|
||||
# logger.exception("internal server error.")
|
||||
# raise InternalServerError()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
# class SnippetWorkflowTaskStopApi(Resource):
|
||||
# @console_ns.doc("stop_snippet_workflow_task")
|
||||
# @console_ns.response(200, "Task stopped successfully")
|
||||
# @console_ns.response(404, "Snippet not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, task_id: str):
|
||||
# """
|
||||
# Stop a running snippet workflow task.
|
||||
|
||||
# Uses both the legacy stop flag mechanism and the graph engine
|
||||
# command channel for backward compatibility.
|
||||
# """
|
||||
# # Stop using both mechanisms for backward compatibility
|
||||
# # Legacy stop flag mechanism (without user check)
|
||||
# AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# # New graph engine command channel mechanism
|
||||
# GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
# return {"result": "success"}
|
||||
@ -1,316 +0,0 @@
|
||||
# """
|
||||
# Snippet draft workflow variable APIs.
|
||||
|
||||
# Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
|
||||
# using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
|
||||
|
||||
# Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
|
||||
# (`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
|
||||
# reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
|
||||
# Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
# """
|
||||
|
||||
# from collections.abc import Callable
|
||||
# from functools import wraps
|
||||
# from typing import Any
|
||||
|
||||
# from flask import Response, request
|
||||
# from flask_restx import Resource, marshal, marshal_with
|
||||
# from sqlalchemy.orm import Session
|
||||
|
||||
# from controllers.console import console_ns
|
||||
# from controllers.console.app.error import DraftWorkflowNotExist
|
||||
# from controllers.console.app.workflow_draft_variable import (
|
||||
# WorkflowDraftVariableListQuery,
|
||||
# WorkflowDraftVariableUpdatePayload,
|
||||
# _ensure_variable_access,
|
||||
# _file_access_controller,
|
||||
# validate_node_id,
|
||||
# workflow_draft_variable_list_model,
|
||||
# workflow_draft_variable_list_without_value_model,
|
||||
# workflow_draft_variable_model,
|
||||
# )
|
||||
# from controllers.console.snippets.snippet_workflow import get_snippet
|
||||
# from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
# from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
# from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
# from extensions.ext_database import db
|
||||
# from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
# from factories.variable_factory import build_segment_with_type
|
||||
# from graphon.variables.types import SegmentType
|
||||
# from libs.login import current_user, login_required
|
||||
# from models.snippet import CustomizedSnippet
|
||||
# from models.workflow import WorkflowDraftVariable
|
||||
# from services.snippet_service import SnippetService
|
||||
# from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
# _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
# {SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
# )
|
||||
|
||||
|
||||
# def _ensure_snippet_draft_variable_row_allowed(
|
||||
# *,
|
||||
# variable: WorkflowDraftVariable,
|
||||
# variable_id: str,
|
||||
# ) -> None:
|
||||
# """Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
|
||||
# if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
|
||||
# raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
# def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
# """Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# @wraps(f)
|
||||
# def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# return f(*args, **kwargs)
|
||||
|
||||
# return wrapper
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
|
||||
# class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
# @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
# @console_ns.doc("get_snippet_workflow_variables")
|
||||
# @console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
|
||||
# @console_ns.response(
|
||||
# 200,
|
||||
# "Workflow variables retrieved successfully",
|
||||
# workflow_draft_variable_list_without_value_model,
|
||||
# )
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
# args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# if snippet_service.get_draft_workflow(snippet=snippet) is None:
|
||||
# raise DraftWorkflowNotExist()
|
||||
|
||||
# with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
# workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
# app_id=snippet.id,
|
||||
# page=args.page,
|
||||
# limit=args.limit,
|
||||
# user_id=current_user.id,
|
||||
# exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
|
||||
# )
|
||||
|
||||
# return workflow_vars
|
||||
|
||||
# @console_ns.doc("delete_snippet_workflow_variables")
|
||||
# @console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
|
||||
# @console_ns.response(204, "Workflow variables deleted successfully")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def delete(self, snippet: CustomizedSnippet) -> Response:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
|
||||
# db.session.commit()
|
||||
# return Response("", 204)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
# class SnippetNodeVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_node_variables")
|
||||
# @console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
|
||||
# @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
# validate_node_id(node_id)
|
||||
# with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
# node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
|
||||
# return node_vars
|
||||
|
||||
# @console_ns.doc("delete_snippet_node_variables")
|
||||
# @console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
|
||||
# @console_ns.response(204, "Node variables deleted successfully")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
# validate_node_id(node_id)
|
||||
# srv = WorkflowDraftVariableService(db.session())
|
||||
# srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
# db.session.commit()
|
||||
# return Response("", 204)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
# class SnippetVariableApi(Resource):
|
||||
# @console_ns.doc("get_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
|
||||
# @console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_model)
|
||||
# def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
# return variable
|
||||
|
||||
# @console_ns.doc("update_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Update a draft workflow variable (snippet scope)")
|
||||
# @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
# @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_model)
|
||||
# def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
# new_name = args_model.name
|
||||
# raw_value = args_model.value
|
||||
# if new_name is None and raw_value is None:
|
||||
# return variable
|
||||
|
||||
# new_value = None
|
||||
# if raw_value is not None:
|
||||
# if variable.value_type == SegmentType.FILE:
|
||||
# if not isinstance(raw_value, dict):
|
||||
# raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
# raw_value = build_from_mapping(
|
||||
# mapping=raw_value,
|
||||
# tenant_id=snippet.tenant_id,
|
||||
# access_controller=_file_access_controller,
|
||||
# )
|
||||
# elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
# if not isinstance(raw_value, list):
|
||||
# raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
# if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
# raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
# raw_value = build_from_mappings(
|
||||
# mappings=raw_value,
|
||||
# tenant_id=snippet.tenant_id,
|
||||
# access_controller=_file_access_controller,
|
||||
# )
|
||||
# new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
# draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
# db.session.commit()
|
||||
# return variable
|
||||
|
||||
# @console_ns.doc("delete_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
|
||||
# @console_ns.response(204, "Variable deleted successfully")
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
# draft_var_srv.delete_variable(variable)
|
||||
# db.session.commit()
|
||||
# return Response("", 204)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
# class SnippetVariableResetApi(Resource):
|
||||
# @console_ns.doc("reset_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
|
||||
# @console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
# @console_ns.response(204, "Variable reset (no content)")
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# snippet_service = SnippetService()
|
||||
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if draft_workflow is None:
|
||||
# raise NotFoundError(
|
||||
# f"Draft workflow not found, snippet_id={snippet.id}",
|
||||
# )
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
# resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
# db.session.commit()
|
||||
# if resetted is None:
|
||||
# return Response("", 204)
|
||||
# return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
|
||||
# class SnippetConversationVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_conversation_variables")
|
||||
# @console_ns.doc(
|
||||
# description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
# )
|
||||
# @console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
# return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
|
||||
# class SnippetSystemVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_system_variables")
|
||||
# @console_ns.doc(
|
||||
# description="System variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
# )
|
||||
# @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
# return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
|
||||
# class SnippetEnvironmentVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_environment_variables")
|
||||
# @console_ns.doc(description="Get environment variables from snippet draft workflow graph")
|
||||
# @console_ns.response(200, "Environment variables retrieved successfully")
|
||||
# @console_ns.response(404, "Draft workflow not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if workflow is None:
|
||||
# raise DraftWorkflowNotExist()
|
||||
|
||||
# env_vars_list: list[dict[str, Any]] = []
|
||||
# for v in workflow.environment_variables:
|
||||
# env_vars_list.append(
|
||||
# {
|
||||
# "id": v.id,
|
||||
# "type": "env",
|
||||
# "name": v.name,
|
||||
# "description": v.description,
|
||||
# "selector": v.selector,
|
||||
# "value_type": v.value_type.exposed_type().value,
|
||||
# "value": v.value,
|
||||
# "edited": False,
|
||||
# "visible": True,
|
||||
# "editable": True,
|
||||
# }
|
||||
# )
|
||||
|
||||
# return {"items": env_vars_list}
|
||||
@ -25,6 +25,10 @@ class TagBasePayload(BaseModel):
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
|
||||
class TagUpdateRequestPayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
|
||||
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
||||
target_id: str = Field(description="Target ID to bind tags to")
|
||||
@ -68,6 +72,7 @@ class TagResponse(ResponseModel):
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagUpdateRequestPayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
@ -118,7 +123,7 @@ class TagListApi(Resource):
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@console_ns.expect(console_ns.models[TagUpdateRequestPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -129,8 +134,8 @@ class TagUpdateDeleteApi(Resource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
|
||||
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
@ -185,12 +185,6 @@ def _serialize_account(account) -> dict[str, Any]:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class AccountIntegrateResponse(ResponseModel):
|
||||
provider: str
|
||||
created_at: int | None = None
|
||||
@ -200,7 +194,7 @@ class AccountIntegrateResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AccountIntegrateListResponse(ResponseModel):
|
||||
@ -220,7 +214,7 @@ class EducationStatusResponse(ResponseModel):
|
||||
@field_validator("expire_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class EducationAutocompleteResponse(ResponseModel):
|
||||
|
||||
@ -30,14 +30,13 @@ from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class MemberInvitePayload(BaseModel):
|
||||
emails: list[str] = Field(default_factory=list)
|
||||
role: str
|
||||
role: TenantAccountRole
|
||||
language: str | None = None
|
||||
|
||||
|
||||
@ -71,18 +70,6 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _serialize_member_roles(current_role: str | None, member_roles: list[enterprise_rbac_service.MemberRoleSummary]) -> list[dict[str, str]]:
|
||||
if member_roles:
|
||||
return [{"id": role.id, "name": role.name} for role in member_roles]
|
||||
if current_role:
|
||||
return [{"id": current_role, "name": current_role}]
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_enum_value(value: object) -> str:
|
||||
normalized = getattr(value, "value", value)
|
||||
return str(normalized) if normalized is not None else ""
|
||||
|
||||
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
|
||||
if role != TenantAccountRole.DATASET_OPERATOR:
|
||||
return True
|
||||
@ -102,36 +89,7 @@ class MemberListApi(Resource):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
if dify_config.RBAC_ENABLED:
|
||||
member_ids = [member.id for member in members]
|
||||
member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
|
||||
str(current_user.current_tenant.id),
|
||||
current_user.id,
|
||||
member_ids,
|
||||
)
|
||||
roles_map = {item.account_id: item.roles for item in member_roles}
|
||||
else:
|
||||
roles_map = {}
|
||||
|
||||
serialized_members = []
|
||||
for member in members:
|
||||
current_role = _normalize_enum_value(member.current_role)
|
||||
serialized_members.append(
|
||||
{
|
||||
"id": member.id,
|
||||
"name": member.name,
|
||||
"email": member.email,
|
||||
"avatar": member.avatar,
|
||||
"last_login_at": member.last_login_at,
|
||||
"last_active_at": member.last_active_at,
|
||||
"created_at": member.created_at,
|
||||
"role": current_role,
|
||||
"roles": _serialize_member_roles(current_role, roles_map.get(member.id, [])),
|
||||
"status": _normalize_enum_value(member.status),
|
||||
}
|
||||
)
|
||||
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(serialized_members)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -152,9 +110,8 @@ class MemberInviteEmailApi(Resource):
|
||||
invitee_emails = args.emails
|
||||
invitee_role = args.role
|
||||
interface_language = args.language
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
if not TenantAccountRole.is_valid_role(invitee_role) or not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
import io
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model, register_enum_models, register_schema_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
@ -23,14 +23,6 @@ from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class AutoUpgradeSettingsResponse(TypedDict):
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting
|
||||
upgrade_time_of_day: int
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode
|
||||
exclude_plugins: list[str]
|
||||
include_plugins: list[str]
|
||||
|
||||
|
||||
class ParserList(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
|
||||
@ -94,8 +86,8 @@ class ParserUninstall(BaseModel):
|
||||
|
||||
|
||||
class ParserPermissionChange(BaseModel):
|
||||
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
|
||||
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
|
||||
install_permission: TenantPluginPermission.InstallPermission
|
||||
debug_permission: TenantPluginPermission.DebugPermission
|
||||
|
||||
|
||||
class ParserDynamicOptions(BaseModel):
|
||||
@ -131,22 +123,13 @@ class PluginAutoUpgradeSettingsPayload(BaseModel):
|
||||
include_plugins: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ParserAutoUpgradeChange(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory
|
||||
class ParserPreferencesChange(BaseModel):
|
||||
permission: PluginPermissionSettingsPayload
|
||||
auto_upgrade: PluginAutoUpgradeSettingsPayload
|
||||
|
||||
|
||||
class ParserAutoUpgradeFetch(BaseModel):
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory
|
||||
|
||||
|
||||
class ParserExcludePlugin(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
plugin_id: str
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory
|
||||
|
||||
|
||||
class ParserReadme(BaseModel):
|
||||
@ -173,8 +156,7 @@ register_schema_models(
|
||||
ParserPermissionChange,
|
||||
ParserDynamicOptions,
|
||||
ParserDynamicOptionsWithCredentials,
|
||||
ParserAutoUpgradeChange,
|
||||
ParserAutoUpgradeFetch,
|
||||
ParserPreferencesChange,
|
||||
ParserExcludePlugin,
|
||||
ParserReadme,
|
||||
)
|
||||
@ -182,36 +164,12 @@ register_schema_models(
|
||||
register_enum_models(
|
||||
console_ns,
|
||||
TenantPluginPermission.DebugPermission,
|
||||
TenantPluginAutoUpgradeStrategy.PluginCategory,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode,
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting,
|
||||
TenantPluginPermission.InstallPermission,
|
||||
)
|
||||
|
||||
|
||||
def _default_auto_upgrade_settings(
|
||||
tenant_id: str,
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory,
|
||||
) -> AutoUpgradeSettingsResponse:
|
||||
return {
|
||||
"strategy_setting": PluginAutoUpgradeService.default_strategy_setting_for_category(category),
|
||||
"upgrade_time_of_day": PluginAutoUpgradeService.default_upgrade_time_of_day(tenant_id),
|
||||
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
"exclude_plugins": [],
|
||||
"include_plugins": [],
|
||||
}
|
||||
|
||||
|
||||
def _auto_upgrade_settings_to_dict(strategy: TenantPluginAutoUpgradeStrategy) -> AutoUpgradeSettingsResponse:
|
||||
return {
|
||||
"strategy_setting": strategy.strategy_setting,
|
||||
"upgrade_time_of_day": strategy.upgrade_time_of_day,
|
||||
"upgrade_mode": strategy.upgrade_mode,
|
||||
"exclude_plugins": strategy.exclude_plugins,
|
||||
"include_plugins": strategy.include_plugins,
|
||||
}
|
||||
|
||||
|
||||
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
|
||||
"""
|
||||
Read the uploaded file and validate its actual size before delegating to the plugin service.
|
||||
@ -659,13 +617,11 @@ class PluginChangePermissionApi(Resource):
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
set_permission_result = PluginPermissionService.change_permission(
|
||||
tenant_id, args.install_permission, args.debug_permission
|
||||
)
|
||||
if not set_permission_result:
|
||||
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
|
||||
|
||||
return jsonable_encoder({"success": True})
|
||||
return {
|
||||
"success": PluginPermissionService.change_permission(
|
||||
tenant_id, args.install_permission, args.debug_permission
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/fetch")
|
||||
@ -754,9 +710,9 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
|
||||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/auto-upgrade/change")
|
||||
class PluginChangeAutoUpgradeApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserAutoUpgradeChange.__name__])
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -765,17 +721,38 @@ class PluginChangeAutoUpgradeApi(Resource):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = ParserAutoUpgradeChange.model_validate(console_ns.payload)
|
||||
args = ParserPreferencesChange.model_validate(console_ns.payload)
|
||||
|
||||
permission = args.permission
|
||||
|
||||
install_permission = permission.install_permission
|
||||
debug_permission = permission.debug_permission
|
||||
|
||||
auto_upgrade = args.auto_upgrade
|
||||
|
||||
strategy_setting = auto_upgrade.strategy_setting
|
||||
upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
|
||||
upgrade_mode = auto_upgrade.upgrade_mode
|
||||
exclude_plugins = auto_upgrade.exclude_plugins
|
||||
include_plugins = auto_upgrade.include_plugins
|
||||
|
||||
# set permission
|
||||
set_permission_result = PluginPermissionService.change_permission(
|
||||
tenant_id,
|
||||
install_permission,
|
||||
debug_permission,
|
||||
)
|
||||
if not set_permission_result:
|
||||
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
|
||||
|
||||
# set auto upgrade strategy
|
||||
set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy(
|
||||
tenant_id,
|
||||
auto_upgrade.strategy_setting,
|
||||
auto_upgrade.upgrade_time_of_day,
|
||||
auto_upgrade.upgrade_mode,
|
||||
auto_upgrade.exclude_plugins,
|
||||
auto_upgrade.include_plugins,
|
||||
category=args.category,
|
||||
strategy_setting,
|
||||
upgrade_time_of_day,
|
||||
upgrade_mode,
|
||||
exclude_plugins,
|
||||
include_plugins,
|
||||
)
|
||||
if not set_auto_upgrade_strategy_result:
|
||||
return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"})
|
||||
@ -783,32 +760,46 @@ class PluginChangeAutoUpgradeApi(Resource):
|
||||
return jsonable_encoder({"success": True})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/auto-upgrade/fetch")
|
||||
class PluginFetchAutoUpgradeApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(ParserAutoUpgradeFetch))
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
|
||||
class PluginFetchPreferencesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserAutoUpgradeFetch.model_validate(request.args.to_dict(flat=True))
|
||||
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id, args.category)
|
||||
auto_upgrade_dict = (
|
||||
_auto_upgrade_settings_to_dict(auto_upgrade)
|
||||
if auto_upgrade
|
||||
else _default_auto_upgrade_settings(tenant_id, args.category)
|
||||
)
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
permission_dict = {
|
||||
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
}
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"category": args.category,
|
||||
"auto_upgrade": auto_upgrade_dict,
|
||||
if permission:
|
||||
permission_dict["install_permission"] = permission.install_permission
|
||||
permission_dict["debug_permission"] = permission.debug_permission
|
||||
|
||||
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id)
|
||||
auto_upgrade_dict = {
|
||||
"strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
|
||||
"upgrade_time_of_day": 0,
|
||||
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
"exclude_plugins": [],
|
||||
"include_plugins": [],
|
||||
}
|
||||
|
||||
if auto_upgrade:
|
||||
auto_upgrade_dict = {
|
||||
"strategy_setting": auto_upgrade.strategy_setting,
|
||||
"upgrade_time_of_day": auto_upgrade.upgrade_time_of_day,
|
||||
"upgrade_mode": auto_upgrade.upgrade_mode,
|
||||
"exclude_plugins": auto_upgrade.exclude_plugins,
|
||||
"include_plugins": auto_upgrade.include_plugins,
|
||||
}
|
||||
)
|
||||
|
||||
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/auto-upgrade/exclude")
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
|
||||
@setup_required
|
||||
@ -820,9 +811,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
|
||||
args = ParserExcludePlugin.model_validate(console_ns.payload)
|
||||
|
||||
return jsonable_encoder(
|
||||
{"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id, args.category)}
|
||||
)
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/readme")
|
||||
|
||||
@ -1,614 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.enterprise import rbac_service as svc
|
||||
|
||||
|
||||
_LEGACY_ROLE_PERMISSION_KEYS: dict[str, list[str]] = {
|
||||
# This is a compatibility projection from the pre-RBAC workspace roles into
|
||||
# the 2.0 permission matrix documented in "权限整理2.0". It intentionally
|
||||
# models the product-facing role surface for the new RBAC UI instead of the
|
||||
# legacy backend's exact hard-authorization checks.
|
||||
"owner": [
|
||||
*svc._LEGACY_WORKSPACE_OWNER_KEYS,
|
||||
*svc._LEGACY_APP_OWNER_KEYS,
|
||||
*svc._LEGACY_DATASET_OWNER_KEYS,
|
||||
],
|
||||
"admin": [
|
||||
*svc._LEGACY_WORKSPACE_ADMIN_KEYS,
|
||||
*svc._LEGACY_APP_ADMIN_KEYS,
|
||||
*svc._LEGACY_DATASET_ADMIN_KEYS,
|
||||
],
|
||||
"editor": [
|
||||
*svc._LEGACY_WORKSPACE_EDITOR_KEYS,
|
||||
*svc._LEGACY_APP_EDITOR_KEYS,
|
||||
*svc._LEGACY_DATASET_EDITOR_KEYS,
|
||||
],
|
||||
"normal": [
|
||||
*svc._LEGACY_WORKSPACE_NORMAL_KEYS,
|
||||
*svc._LEGACY_APP_NORMAL_KEYS,
|
||||
],
|
||||
"dataset_operator": [
|
||||
*svc._LEGACY_WORKSPACE_DATASET_OPERATOR_KEYS,
|
||||
*svc._LEGACY_DATASET_DATASET_OPERATOR_KEYS,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _current_ids() -> tuple[str, str]:
|
||||
"""Return ``(tenant_id, account_id)`` for the authenticated user, or
|
||||
raise a 404 when no tenant is associated with the session.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not tenant_id:
|
||||
raise NotFound("Current workspace not found")
|
||||
return tenant_id, user.id
|
||||
|
||||
|
||||
def _payload(model: type[BaseModel]) -> Any:
|
||||
"""Validate the JSON body against ``model`` or raise ``ValidationError``.
|
||||
|
||||
``ValidationError`` bubbles up as HTTP 400 thanks to
|
||||
``controllers/common/helpers.py`` error handling.
|
||||
"""
|
||||
try:
|
||||
return model.model_validate(console_ns.payload or {})
|
||||
except ValidationError as exc:
|
||||
# Re-raise as-is so the upstream error handler renders a 400.
|
||||
raise exc
|
||||
|
||||
|
||||
def _dump(model: BaseModel) -> dict[str, Any]:
|
||||
return model.model_dump(mode="json")
|
||||
|
||||
|
||||
class _PaginationQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
page_number: int | None = Field(default=None, ge=1, validation_alias=AliasChoices("page", "page_number"))
|
||||
results_per_page: int | None = Field(
|
||||
default=None, ge=1, le=100, validation_alias=AliasChoices("limit", "results_per_page")
|
||||
)
|
||||
reverse: bool | None = None
|
||||
|
||||
def to_inner_options(self) -> svc.ListOption:
|
||||
return svc.ListOption.model_validate(self.model_dump())
|
||||
|
||||
|
||||
class _RolesListQuery(_PaginationQuery):
|
||||
include_owner: int = Field(default=0, ge=0, le=1)
|
||||
|
||||
|
||||
def _pagination_options() -> svc.ListOption:
|
||||
return _PaginationQuery.model_validate(request.args.to_dict(flat=True)).to_inner_options()
|
||||
|
||||
|
||||
def _filter_out_owner(paginated: svc.Paginated[svc.RBACRole]) -> svc.Paginated[svc.RBACRole]:
|
||||
filtered = [r for r in paginated.data if r.name not in {"所有者", "owner"}]
|
||||
return svc.Paginated[svc.RBACRole](
|
||||
data=filtered,
|
||||
pagination=paginated.pagination,
|
||||
)
|
||||
|
||||
|
||||
def _legacy_workspace_roles(options: svc.ListOption | None = None) -> svc.Paginated[svc.RBACRole]:
|
||||
"""Return the built-in legacy workspace roles in the RBAC list shape.
|
||||
|
||||
This keeps the new `/rbac/roles` endpoint compatible with the original
|
||||
Dify role model when enterprise RBAC is disabled.
|
||||
"""
|
||||
|
||||
legacy_roles = [
|
||||
svc.RBACRole(
|
||||
id=role_name,
|
||||
tenant_id="",
|
||||
type=svc.RBACRoleType.WORKSPACE.value,
|
||||
category="global_system_default",
|
||||
name=role_name,
|
||||
description="",
|
||||
is_builtin=True,
|
||||
permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]),
|
||||
role_tag="owner" if role_name == "owner" else "",
|
||||
)
|
||||
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
|
||||
]
|
||||
|
||||
page_number = options.page_number if options and options.page_number is not None else 1
|
||||
results_per_page = options.results_per_page if options and options.results_per_page is not None else len(legacy_roles)
|
||||
reverse = options.reverse if options and options.reverse is not None else False
|
||||
|
||||
ordered_roles = list(reversed(legacy_roles)) if reverse else legacy_roles
|
||||
start = max(page_number - 1, 0) * results_per_page
|
||||
end = start + results_per_page
|
||||
paged_roles = ordered_roles[start:end]
|
||||
total_count = len(legacy_roles)
|
||||
total_pages = (total_count + results_per_page - 1) // results_per_page if results_per_page > 0 else 0
|
||||
|
||||
return svc.Paginated[svc.RBACRole](
|
||||
data=paged_roles,
|
||||
pagination=svc.Pagination(
|
||||
total_count=total_count,
|
||||
per_page=results_per_page,
|
||||
current_page=page_number,
|
||||
total_pages=total_pages,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission catalogs.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
|
||||
class RBACWorkspaceCatalogApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Catalog.workspace(tenant_id, account_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/app")
|
||||
class RBACAppCatalogApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Catalog.app(tenant_id, account_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/dataset")
|
||||
class RBACDatasetCatalogApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Catalog.dataset(tenant_id, account_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Roles.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _RoleUpsertRequest(BaseModel):
|
||||
"""Accepts the payload sent by the Create/Edit Role dialog."""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
permission_keys: list[str] = []
|
||||
|
||||
def to_mutation(self) -> svc.RoleMutation:
|
||||
return svc.RoleMutation(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
permission_keys=list(self.permission_keys),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles")
|
||||
class RBACRolesApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
query = _RolesListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
options = query.to_inner_options()
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
result = _legacy_workspace_roles(options)
|
||||
else:
|
||||
result = svc.RBACService.Roles.list(tenant_id, account_id, options=options)
|
||||
if query.include_owner == 0:
|
||||
result = _filter_out_owner(result)
|
||||
|
||||
data = []
|
||||
for role in result.data:
|
||||
if role.name in {"所有者", "owner"}:
|
||||
role.role_tag = "owner"
|
||||
else:
|
||||
role.role_tag = ""
|
||||
data.append(role)
|
||||
result.data = data
|
||||
return _dump(result)
|
||||
|
||||
@login_required
|
||||
def post(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_RoleUpsertRequest)
|
||||
role = svc.RBACService.Roles.create(tenant_id, account_id, request.to_mutation())
|
||||
return _dump(role), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>")
|
||||
class RBACRoleItemApi(Resource):
|
||||
@login_required
|
||||
def get(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Roles.get(tenant_id, account_id, str(role_id)))
|
||||
|
||||
@login_required
|
||||
def put(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_RoleUpsertRequest)
|
||||
role = svc.RBACService.Roles.update(tenant_id, account_id, str(role_id), request.to_mutation())
|
||||
return _dump(role)
|
||||
|
||||
@login_required
|
||||
def delete(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
svc.RBACService.Roles.delete(tenant_id, account_id, str(role_id))
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/copy")
|
||||
class RBACRoleCopyApi(Resource):
|
||||
@login_required
|
||||
def post(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
role = svc.RBACService.Roles.copy(tenant_id, account_id, str(role_id))
|
||||
return _dump(role), 201
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access policies (tenant-level permission sets).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AccessPolicyCreateRequest(BaseModel):
|
||||
name: str
|
||||
resource_type: svc.RBACResourceType
|
||||
description: str = ""
|
||||
permission_keys: list[str] = []
|
||||
|
||||
|
||||
class _AccessPolicyUpdateRequest(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
permission_keys: list[str] = []
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/access-policies")
|
||||
class RBACAccessPoliciesApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
# `resource_type` is exposed as a query argument so the UI can show
|
||||
# only app-scoped or only dataset-scoped permission sets.
|
||||
resource_type = request.args.get("resource_type") or None
|
||||
return _dump(
|
||||
svc.RBACService.AccessPolicies.list(
|
||||
tenant_id,
|
||||
account_id,
|
||||
resource_type=resource_type,
|
||||
options=_pagination_options(),
|
||||
)
|
||||
)
|
||||
|
||||
@login_required
|
||||
def post(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_AccessPolicyCreateRequest)
|
||||
policy = svc.RBACService.AccessPolicies.create(
|
||||
tenant_id,
|
||||
account_id,
|
||||
svc.AccessPolicyCreate(
|
||||
name=request.name,
|
||||
resource_type=request.resource_type,
|
||||
description=request.description,
|
||||
permission_keys=list(request.permission_keys),
|
||||
),
|
||||
)
|
||||
return _dump(policy), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>")
|
||||
class RBACAccessPolicyItemApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.AccessPolicies.get(tenant_id, account_id, str(policy_id)))
|
||||
|
||||
@login_required
|
||||
def put(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_AccessPolicyUpdateRequest)
|
||||
policy = svc.RBACService.AccessPolicies.update(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(policy_id),
|
||||
svc.AccessPolicyUpdate(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
permission_keys=list(request.permission_keys),
|
||||
),
|
||||
)
|
||||
return _dump(policy)
|
||||
|
||||
@login_required
|
||||
def delete(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
svc.RBACService.AccessPolicies.delete(tenant_id, account_id, str(policy_id))
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>/copy")
|
||||
class RBACAccessPolicyCopyApi(Resource):
|
||||
@login_required
|
||||
def post(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
policy = svc.RBACService.AccessPolicies.copy(tenant_id, account_id, str(policy_id))
|
||||
return _dump(policy), 201
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-app access (App Access Config).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ReplaceBindingsRequest(BaseModel):
|
||||
role_ids: list[str] = []
|
||||
account_ids: list[str] = []
|
||||
|
||||
@field_validator("role_ids", "account_ids", mode="before")
|
||||
@classmethod
|
||||
def _coerce_bindings(cls, value: Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/my-permissions")
|
||||
class RBACMyPermissionsApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.MyPermissions.get(
|
||||
tenant_id,
|
||||
account_id,
|
||||
app_id=request.args.get("app_id") or None,
|
||||
dataset_id=request.args.get("dataset_id") or None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policy")
|
||||
class RBACAppMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self, app_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.AppAccess.matrix(tenant_id, account_id, str(app_id)))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACAppRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, app_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.AppAccess.list_role_bindings(tenant_id, account_id, str(app_id), str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/member-bindings")
|
||||
class RBACAppMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, app_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.AppAccess.list_member_bindings(tenant_id, account_id, str(app_id), str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACAppBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, app_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.AppAccess.replace_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(app_id),
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-dataset access (Knowledge Base Access Config).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policy")
|
||||
class RBACDatasetMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self, dataset_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.DatasetAccess.matrix(tenant_id, account_id, str(dataset_id)))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACDatasetRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, dataset_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.DatasetAccess.list_role_bindings(
|
||||
tenant_id, account_id, str(dataset_id), str(policy_id)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACDatasetBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, dataset_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.DatasetAccess.replace_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(dataset_id),
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/member-bindings"
|
||||
)
|
||||
class RBACDatasetMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, dataset_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.DatasetAccess.list_member_bindings(
|
||||
tenant_id, account_id, str(dataset_id), str(policy_id)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workspace-level access (Settings > Access Rules).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
|
||||
class RBACWorkspaceAppMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
options = _pagination_options()
|
||||
return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id, options=options))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACWorkspaceAppRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_app_role_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACWorkspaceAppBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.replace_app_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/member-bindings")
|
||||
class RBACWorkspaceAppMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_app_member_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy")
|
||||
class RBACWorkspaceDatasetMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
options = _pagination_options()
|
||||
return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id, options=options))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACWorkspaceDatasetRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_dataset_role_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACWorkspaceDatasetBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.replace_dataset_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/member-bindings")
|
||||
class RBACWorkspaceDatasetMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_dataset_member_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Member ↔ role bindings (Settings > Members > Assign roles).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ReplaceMemberRolesRequest(BaseModel):
|
||||
role_ids: list[str] = []
|
||||
|
||||
@field_validator("role_ids", mode="before")
|
||||
@classmethod
|
||||
def _coerce_role_ids(cls, value: Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/members/<uuid:member_id>/rbac-roles")
|
||||
class RBACMemberRolesApi(Resource):
|
||||
@login_required
|
||||
def get(self, member_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.MemberRoles.get(tenant_id, account_id, str(member_id)))
|
||||
|
||||
@login_required
|
||||
def put(self, member_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceMemberRolesRequest)
|
||||
return _dump(
|
||||
svc.RBACService.MemberRoles.replace(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(member_id),
|
||||
role_ids=list(request.role_ids),
|
||||
)
|
||||
)
|
||||
@ -1,380 +0,0 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.snippets.payloads import (
|
||||
CreateSnippetPayload,
|
||||
IncludeSecretQuery,
|
||||
SnippetImportPayload,
|
||||
SnippetListQuery,
|
||||
UpdateSnippetPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import SnippetType
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.snippet_dsl_service import SnippetDslService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetListQuery,
|
||||
CreateSnippetPayload,
|
||||
UpdateSnippetPayload,
|
||||
SnippetImportPayload,
|
||||
IncludeSecretQuery,
|
||||
)
|
||||
|
||||
# Create namespace models for marshaling
|
||||
snippet_model = console_ns.model("Snippet", snippet_fields)
|
||||
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
|
||||
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets")
|
||||
class CustomizedSnippetsApi(Resource):
|
||||
@console_ns.doc("list_customized_snippets")
|
||||
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
|
||||
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List customized snippets with pagination and search."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query_params = request.args.to_dict()
|
||||
query = SnippetListQuery.model_validate(query_params)
|
||||
|
||||
snippets, total, has_more = SnippetService.get_snippets(
|
||||
tenant_id=current_tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
is_published=query.is_published,
|
||||
creators=query.creators,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": marshal(snippets, snippet_list_fields),
|
||||
"page": query.page,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"has_more": has_more,
|
||||
}, 200
|
||||
|
||||
@console_ns.doc("create_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
|
||||
@console_ns.response(201, "Snippet created successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Create a new customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_type = SnippetType(payload.type)
|
||||
except ValueError:
|
||||
snippet_type = SnippetType.NODE
|
||||
|
||||
try:
|
||||
snippet = SnippetService.create_snippet(
|
||||
tenant_id=current_tenant_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
snippet_type=snippet_type,
|
||||
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
|
||||
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
|
||||
account=current_user,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
|
||||
class CustomizedSnippetDetailApi(Resource):
|
||||
@console_ns.doc("get_customized_snippet")
|
||||
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Get customized snippet details."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("update_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
|
||||
@console_ns.response(200, "Snippet updated successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, snippet_id: str):
|
||||
"""Update customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "icon_info" in update_data and update_data["icon_info"] is not None:
|
||||
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
snippet = session.merge(snippet)
|
||||
snippet = SnippetService.update_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("delete_customized_snippet")
|
||||
@console_ns.response(204, "Snippet deleted successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, snippet_id: str):
|
||||
"""Delete customized snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.delete_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
|
||||
class CustomizedSnippetExportApi(Resource):
|
||||
@console_ns.doc("export_customized_snippet")
|
||||
@console_ns.doc(description="Export snippet configuration as DSL")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
|
||||
@console_ns.response(200, "Snippet exported successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Export snippet as DSL."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
# Get include_secret parameter
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = SnippetDslService(session)
|
||||
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
|
||||
|
||||
# Set filename with .snippet extension
|
||||
filename = f"{snippet.name}.snippet"
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
response = Response(
|
||||
result,
|
||||
mimetype="application/x-yaml",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/x-yaml"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports")
|
||||
class CustomizedSnippetImportApi(Resource):
|
||||
@console_ns.doc("import_customized_snippet")
|
||||
@console_ns.doc(description="Import snippet from DSL")
|
||||
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
|
||||
@console_ns.response(200, "Snippet imported successfully")
|
||||
@console_ns.response(202, "Import pending confirmation")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Import snippet from DSL."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.import_snippet(
|
||||
account=current_user,
|
||||
import_mode=payload.mode,
|
||||
yaml_content=payload.yaml_content,
|
||||
yaml_url=payload.yaml_url,
|
||||
snippet_id=payload.snippet_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
|
||||
class CustomizedSnippetImportConfirmApi(Resource):
|
||||
@console_ns.doc("confirm_snippet_import")
|
||||
@console_ns.doc(description="Confirm a pending snippet import")
|
||||
@console_ns.doc(params={"import_id": "Import ID to confirm"})
|
||||
@console_ns.response(200, "Import confirmed successfully")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
"""Confirm a pending snippet import."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.confirm_import(import_id=import_id, account=current_user)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
|
||||
class CustomizedSnippetCheckDependenciesApi(Resource):
|
||||
@console_ns.doc("check_snippet_dependencies")
|
||||
@console_ns.doc(description="Check dependencies for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Dependencies checked successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Check dependencies for a snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.check_dependencies(snippet=snippet)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
|
||||
class CustomizedSnippetUseCountIncrementApi(Resource):
|
||||
@console_ns.doc("increment_snippet_use_count")
|
||||
@console_ns.doc(description="Increment snippet use count by 1")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Use count incremented successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, snippet_id: str):
|
||||
"""Increment snippet use count when it is inserted into a workflow."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.increment_use_count(session=session, snippet=snippet)
|
||||
session.commit()
|
||||
session.refresh(snippet)
|
||||
|
||||
return {"result": "success", "use_count": snippet.use_count}, 200
|
||||
@ -29,7 +29,7 @@ from controllers.console.wraps import (
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
@ -86,9 +86,7 @@ class TenantInfoResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None):
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
|
||||
@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel):
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
"""
|
||||
Get current user
|
||||
Get current user.
|
||||
|
||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
||||
As a result, it could only be considered as an end user id.
|
||||
As a result, it could only be considered as an end user id. Even when a
|
||||
concrete end-user ID is supplied, lookups must stay tenant-scoped so one
|
||||
tenant cannot bind another tenant's user record into the plugin request
|
||||
context.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
user_model = session.get(EndUser, user_id)
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
|
||||
@ -22,7 +22,7 @@ from fields.conversation_fields import (
|
||||
SimpleConversation,
|
||||
)
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.helper import UUIDStrOrEmpty, to_timestamp
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
@ -115,9 +115,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
|
||||
|
||||
@ -7,18 +7,18 @@ paused human input forms in workflow/chatflow runs.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from models.model import App, EndUser
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
@ -28,30 +28,14 @@ logger = logging.getLogger(__name__)
|
||||
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
}
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -39,6 +39,7 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import to_timestamp
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -68,12 +69,6 @@ class WorkflowLogQuery(BaseModel):
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _enum_value(value):
|
||||
return getattr(value, "value", value)
|
||||
|
||||
@ -109,7 +104,7 @@ class WorkflowRunResponse(ResponseModel):
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowRunForLogResponse(ResponseModel):
|
||||
@ -133,31 +128,13 @@ class WorkflowRunForLogResponse(ResponseModel):
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
|
||||
node_id: str
|
||||
type: str
|
||||
title: str
|
||||
|
||||
|
||||
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
|
||||
name: str
|
||||
value: Any = None
|
||||
details: dict[str, Any] | None = None
|
||||
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
|
||||
default=None,
|
||||
validation_alias="node_info",
|
||||
serialization_alias="nodeInfo",
|
||||
)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: dict | list | str | int | float | bool | None = None
|
||||
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
@ -172,12 +149,7 @@ class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
@field_validator("evaluation", mode="before")
|
||||
@classmethod
|
||||
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
|
||||
return value or []
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
@ -192,8 +164,6 @@ register_schema_models(
|
||||
service_api_ns,
|
||||
WorkflowRunResponse,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowAppLogEvaluationNodeInfoResponse,
|
||||
WorkflowAppLogEvaluationItemResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
)
|
||||
|
||||
@ -31,7 +31,9 @@ from services.tag_service import (
|
||||
TagBindingCreatePayload,
|
||||
TagBindingDeletePayload,
|
||||
TagService,
|
||||
UpdateTagPayload,
|
||||
)
|
||||
from services.tag_service import (
|
||||
UpdateTagPayload as UpdateTagServicePayload,
|
||||
)
|
||||
|
||||
register_enum_models(service_api_ns, DatasetPermissionEnum)
|
||||
@ -556,7 +558,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
|
||||
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
|
||||
@ -23,7 +23,6 @@ from . import (
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
human_input_file_upload,
|
||||
human_input_form,
|
||||
login,
|
||||
message,
|
||||
@ -47,7 +46,6 @@ __all__ = [
|
||||
"feature",
|
||||
"files",
|
||||
"forgot_password",
|
||||
"human_input_file_upload",
|
||||
"human_input_form",
|
||||
"login",
|
||||
"message",
|
||||
|
||||
@ -1,181 +0,0 @@
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
import services
|
||||
from controllers.common import helpers
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
RemoteFileUploadError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, FileWithSignedUrl
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.file_service import FileService
|
||||
from services.human_input_file_upload_service import (
|
||||
HITL_UPLOAD_TOKEN_PREFIX,
|
||||
HumanInputFileUploadService,
|
||||
InvalidUploadTokenError,
|
||||
)
|
||||
|
||||
|
||||
class InvalidUploadTokenBadRequestError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Invalid upload token."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidUploadTokenUnauthorizedError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Upload token is required."
|
||||
code = 401
|
||||
|
||||
|
||||
class InvalidUploadTokenForbiddenError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Upload token is invalid or expired."
|
||||
code = 403
|
||||
|
||||
|
||||
class HumanInputRemoteFileUploadPayload(BaseModel):
|
||||
url: HttpUrl = Field(description="Remote file URL")
|
||||
|
||||
|
||||
register_schema_models(web_ns, HumanInputRemoteFileUploadPayload, FileResponse, FileWithSignedUrl)
|
||||
|
||||
|
||||
def _extract_hitl_upload_token() -> str:
|
||||
"""Read HITL upload token from Authorization without invoking other bearer auth chains."""
|
||||
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization is None:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
|
||||
parts = authorization.split()
|
||||
if len(parts) != 2:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
|
||||
scheme, token = parts
|
||||
if scheme.lower() != "bearer":
|
||||
raise InvalidUploadTokenBadRequestError()
|
||||
if not token:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
if not token.startswith(HITL_UPLOAD_TOKEN_PREFIX):
|
||||
raise InvalidUploadTokenBadRequestError()
|
||||
return token
|
||||
|
||||
|
||||
def _validate_context(service: HumanInputFileUploadService, token: str):
|
||||
try:
|
||||
return service.validate_upload_token(token)
|
||||
except InvalidUploadTokenError as exc:
|
||||
raise InvalidUploadTokenForbiddenError() from exc
|
||||
|
||||
|
||||
def _parse_local_upload_file():
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
|
||||
raise FilenameNotExistsError()
|
||||
|
||||
return file
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/files/upload")
|
||||
class HumanInputFileUploadApi(Resource):
|
||||
def post(self):
|
||||
"""Upload one local file for a HITL human input form."""
|
||||
|
||||
token = _extract_hitl_upload_token()
|
||||
upload_service = HumanInputFileUploadService(db.engine)
|
||||
context = _validate_context(upload_service, token)
|
||||
file = _parse_local_upload_file()
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename or "",
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=context.owner,
|
||||
source=None,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError() from exc
|
||||
|
||||
upload_service.record_upload_file(context=context, file_id=upload_file.id)
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return response.model_dump(mode="json"), 201
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/files/remote-upload")
|
||||
class HumanInputRemoteFileUploadApi(Resource):
|
||||
def post(self):
|
||||
"""Upload one remote URL file for a HITL human input form."""
|
||||
|
||||
token = _extract_hitl_upload_token()
|
||||
upload_service = HumanInputFileUploadService(db.engine)
|
||||
context = _validate_context(upload_service, token)
|
||||
payload = HumanInputRemoteFileUploadPayload.model_validate(request.get_json(silent=True) or {})
|
||||
url = str(payload.url)
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as exc:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(exc)}")
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(resp)
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError()
|
||||
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=context.owner,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError() from exc
|
||||
|
||||
upload_service.record_upload_file(context=context, file_id=upload_file.id)
|
||||
payload1 = FileWithSignedUrl(
|
||||
id=upload_file.id,
|
||||
name=upload_file.name,
|
||||
size=upload_file.size,
|
||||
extension=upload_file.extension,
|
||||
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
mime_type=upload_file.mime_type,
|
||||
created_by=upload_file.created_by,
|
||||
created_at=int(upload_file.created_at.timestamp()),
|
||||
)
|
||||
return payload1.model_dump(mode="json"), 201
|
||||
@ -4,39 +4,27 @@ Web App Human Input Form APIs.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
from libs.helper import RateLimiter, extract_remote_ip, to_timestamp
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from services.human_input_file_upload_service import HumanInputFileUploadService
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputUploadTokenResponse(BaseModel):
|
||||
upload_token: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
register_schema_models(web_ns, HumanInputUploadTokenResponse)
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
@ -47,27 +35,6 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
_FORM_UPLOAD_TOKEN_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_upload_token_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class FormDefinitionPayload(TypedDict):
|
||||
@ -85,42 +52,15 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
|
||||
payload: FormDefinitionPayload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
}
|
||||
if site_payload is not None:
|
||||
payload["site"] = site_payload
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>/upload-token")
|
||||
class HumanInputFormUploadTokenApi(Resource):
|
||||
"""API for issuing HITL upload tokens for active human input forms."""
|
||||
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Issue an upload token for a human input form.
|
||||
|
||||
POST /api/form/human_input/<form_token>/upload-token
|
||||
"""
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_UPLOAD_TOKEN_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_UPLOAD_TOKEN_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
try:
|
||||
token = HumanInputFileUploadService(db.engine).issue_upload_token(form_token)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
response = HumanInputUploadTokenResponse(
|
||||
upload_token=token.upload_token,
|
||||
expires_at=_to_timestamp(token.expires_at),
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
@ -9,7 +9,7 @@ from datetime import datetime
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
@ -245,49 +245,50 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
human_input_responses: list[HumanInputRequiredResponse] = []
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, HumanInputRequiredResponse):
|
||||
human_input_responses.append(stream_response)
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
answer=self._task_state.answer,
|
||||
metadata=self._message_end_to_stream_response().metadata,
|
||||
created_at=self._message_created_at,
|
||||
paused_nodes=stream_response.data.paused_nodes,
|
||||
reasons=stream_response.data.reasons,
|
||||
status=stream_response.data.status,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
),
|
||||
)
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {}
|
||||
if stream_response.metadata:
|
||||
extras["metadata"] = stream_response.metadata
|
||||
match stream_response:
|
||||
case ErrorStreamResponse():
|
||||
raise stream_response.err
|
||||
case HumanInputRequiredResponse():
|
||||
human_input_responses.append(stream_response)
|
||||
case WorkflowPauseStreamResponse():
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
answer=self._task_state.answer,
|
||||
metadata=self._message_end_to_stream_response().metadata,
|
||||
created_at=self._message_created_at,
|
||||
paused_nodes=stream_response.data.paused_nodes,
|
||||
reasons=stream_response.data.reasons,
|
||||
status=stream_response.data.status,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
),
|
||||
)
|
||||
case MessageEndStreamResponse():
|
||||
extras = {}
|
||||
if stream_response.metadata:
|
||||
extras["metadata"] = stream_response.metadata
|
||||
|
||||
return ChatbotAppBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.answer,
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
return ChatbotAppBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.answer,
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
case _:
|
||||
continue
|
||||
|
||||
if human_input_responses:
|
||||
return self._build_paused_blocking_response_from_human_input(human_input_responses)
|
||||
@ -425,11 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id = run_id
|
||||
|
||||
with self._database_session() as session:
|
||||
message = self._get_message(session=session)
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
|
||||
message.workflow_run_id = run_id
|
||||
session.execute(update(Message).where(Message.id == self._message_id).values(workflow_run_id=run_id))
|
||||
|
||||
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
||||
@ -178,7 +178,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
if not isinstance(agent_mode, dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
||||
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
|
||||
# FIXME(-LAN-): Cast needed because static checkers do not narrow this dict value.
|
||||
agent_mode = cast(dict[str, Any], agent_mode)
|
||||
|
||||
if "enabled" not in agent_mode or not agent_mode["enabled"]:
|
||||
|
||||
@ -408,19 +408,17 @@ class WorkflowResponseConverter:
|
||||
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
|
||||
) -> HumanInputFormFilledResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
data = HumanInputFormFilledResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
return HumanInputFormFilledResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormFilledResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
),
|
||||
)
|
||||
if event.submitted_data is not None:
|
||||
runtime_type_converter = WorkflowRuntimeTypeConverter()
|
||||
|
||||
data.submitted_data = runtime_type_converter.value_to_json_encodable_recursive(event.submitted_data)
|
||||
|
||||
return HumanInputFormFilledResponse(task_id=task_id, workflow_run_id=run_id, data=data)
|
||||
|
||||
def human_input_form_timeout_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
|
||||
|
||||
@ -45,20 +45,24 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
match sub_stream_response:
|
||||
case PingStreamResponse():
|
||||
yield "ping"
|
||||
continue
|
||||
case ErrorStreamResponse():
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict, data))
|
||||
case _:
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict, data))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -74,20 +78,28 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
match sub_stream_response:
|
||||
case PingStreamResponse():
|
||||
yield "ping"
|
||||
continue
|
||||
case ErrorStreamResponse():
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict, data))
|
||||
case NodeStartStreamResponse() | NodeFinishStreamResponse():
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
|
||||
case _:
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict, data))
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
yield response_chunk
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
@ -58,25 +58,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@staticmethod
|
||||
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
|
||||
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
|
||||
if workflow.kind_or_standard != "snippet":
|
||||
return workflow
|
||||
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
|
||||
snippet = session.scalar(
|
||||
select(CustomizedSnippet).where(
|
||||
CustomizedSnippet.id == workflow.app_id,
|
||||
CustomizedSnippet.tenant_id == workflow.tenant_id,
|
||||
)
|
||||
)
|
||||
if snippet is None:
|
||||
return workflow
|
||||
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
|
||||
|
||||
@staticmethod
|
||||
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
|
||||
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
|
||||
@ -594,8 +575,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
|
||||
@ -10,7 +10,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import get_default_root_node_id
|
||||
from core.workflow.snippet_start import get_compatible_start_aliases
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -116,15 +115,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
),
|
||||
)
|
||||
root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict)
|
||||
add_node_inputs_to_pool(
|
||||
variable_pool,
|
||||
node_id=root_node_id,
|
||||
inputs=inputs,
|
||||
aliases=get_compatible_start_aliases(
|
||||
workflow_kind=getattr(self._workflow, "kind_or_standard", None),
|
||||
root_node_id=root_node_id,
|
||||
),
|
||||
)
|
||||
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph = self._init_graph(
|
||||
|
||||
@ -52,20 +52,24 @@ class WorkflowAppGenerateResponseConverter(
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
match sub_stream_response:
|
||||
case PingStreamResponse():
|
||||
yield "ping"
|
||||
continue
|
||||
case ErrorStreamResponse():
|
||||
response_chunk: dict[str, object] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
case _:
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
response_chunk: dict[str, object] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -81,20 +85,28 @@ class WorkflowAppGenerateResponseConverter(
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
match sub_stream_response:
|
||||
case PingStreamResponse():
|
||||
yield "ping"
|
||||
continue
|
||||
case ErrorStreamResponse():
|
||||
response_chunk: dict[str, object] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
case NodeStartStreamResponse() | NodeFinishStreamResponse():
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
case _:
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
response_chunk: dict[str, object] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@ -145,50 +145,51 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
human_input_responses: list[HumanInputRequiredResponse] = []
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, HumanInputRequiredResponse):
|
||||
human_input_responses.append(stream_response)
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id=stream_response.data.workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs or {},
|
||||
error=None,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=stream_response.data.created_at,
|
||||
finished_at=None,
|
||||
paused_nodes=stream_response.data.paused_nodes,
|
||||
reasons=stream_response.data.reasons,
|
||||
),
|
||||
)
|
||||
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
return WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.id,
|
||||
workflow_id=stream_response.data.workflow_id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs,
|
||||
error=stream_response.data.error,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
continue
|
||||
match stream_response:
|
||||
case ErrorStreamResponse():
|
||||
raise stream_response.err
|
||||
case HumanInputRequiredResponse():
|
||||
human_input_responses.append(stream_response)
|
||||
case WorkflowPauseStreamResponse():
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id=stream_response.data.workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs or {},
|
||||
error=None,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=stream_response.data.created_at,
|
||||
finished_at=None,
|
||||
paused_nodes=stream_response.data.paused_nodes,
|
||||
reasons=stream_response.data.reasons,
|
||||
),
|
||||
)
|
||||
case WorkflowFinishStreamResponse():
|
||||
return WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.id,
|
||||
workflow_id=stream_response.data.workflow_id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs,
|
||||
error=stream_response.data.error,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at)
|
||||
if stream_response.data.finished_at
|
||||
else None,
|
||||
),
|
||||
)
|
||||
case _:
|
||||
continue
|
||||
|
||||
if human_input_responses:
|
||||
return self._build_paused_blocking_response_from_human_input(human_input_responses)
|
||||
|
||||
@ -399,279 +399,281 @@ class WorkflowBasedAppRunner:
|
||||
:param workflow_entry: workflow entry
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, GraphRunPausedEvent):
|
||||
runtime_state = workflow_entry.graph_engine.graph_runtime_state
|
||||
paused_nodes = runtime_state.get_paused_nodes()
|
||||
self._enqueue_human_input_notifications(event.reasons)
|
||||
self._publish_event(
|
||||
QueueWorkflowPausedEvent(
|
||||
reasons=event.reasons,
|
||||
outputs=event.outputs,
|
||||
paused_nodes=paused_nodes,
|
||||
match event:
|
||||
case GraphRunStartedEvent():
|
||||
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
|
||||
case GraphRunSucceededEvent():
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
case GraphRunPartialSucceededEvent():
|
||||
self._publish_event(
|
||||
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormFilledEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
submitted_data=event.submitted_data,
|
||||
case GraphRunFailedEvent():
|
||||
self._publish_event(
|
||||
QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormTimeoutEvent(
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
case GraphRunAbortedEvent():
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
case GraphRunPausedEvent():
|
||||
runtime_state = workflow_entry.graph_engine.graph_runtime_state
|
||||
paused_nodes = runtime_state.get_paused_nodes()
|
||||
self._enqueue_human_input_notifications(event.reasons)
|
||||
self._publish_event(
|
||||
QueueWorkflowPausedEvent(
|
||||
reasons=event.reasons,
|
||||
outputs=event.outputs,
|
||||
paused_nodes=paused_nodes,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
inputs=inputs,
|
||||
outputs=node_run_result.outputs,
|
||||
)
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
case NodeRunHumanInputFormFilledEvent():
|
||||
self._publish_event(
|
||||
QueueHumanInputFormFilledEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
)
|
||||
case NodeRunHumanInputFormTimeoutEvent():
|
||||
self._publish_event(
|
||||
QueueHumanInputFormTimeoutEvent(
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
)
|
||||
)
|
||||
case NodeRunRetryEvent():
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error=event.error,
|
||||
execution_metadata=execution_metadata,
|
||||
retry_index=event.retry_index,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
outputs=node_run_result.outputs,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
agent_strategy=self._build_agent_strategy_info(event),
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error=event.error,
|
||||
execution_metadata=execution_metadata,
|
||||
retry_index=event.retry_index,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
inputs=inputs,
|
||||
outputs=node_run_result.outputs,
|
||||
)
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
case NodeRunStartedEvent():
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
agent_strategy=self._build_agent_strategy_info(event),
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
case NodeRunSucceededEvent():
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
outputs=node_run_result.outputs,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
inputs=event.node_run_result.inputs,
|
||||
outputs=event.node_run_result.outputs,
|
||||
)
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
case NodeRunFailedEvent():
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
outputs=event.node_run_result.outputs,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
inputs=event.node_run_result.inputs,
|
||||
outputs=event.node_run_result.outputs,
|
||||
)
|
||||
self._publish_event(
|
||||
QueueNodeExceptionEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
case NodeRunExceptionEvent():
|
||||
outputs = project_node_outputs_for_workflow_run(
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
outputs=event.node_run_result.outputs,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
self._publish_event(
|
||||
QueueNodeExceptionEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=[
|
||||
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
|
||||
],
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
case NodeRunStreamChunkEvent():
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunAgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
case NodeRunRetrieverResourceEvent():
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=[
|
||||
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
|
||||
],
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
case NodeRunAgentLogEvent():
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
case NodeRunIterationStartedEvent():
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
|
||||
case NodeRunIterationNextEvent():
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self._publish_event(
|
||||
QueueLoopStartEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
case NodeRunIterationSucceededEvent() | NodeRunIterationFailedEvent():
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self._publish_event(
|
||||
QueueLoopNextEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_loop_output,
|
||||
case NodeRunLoopStartedEvent():
|
||||
self._publish_event(
|
||||
QueueLoopStartEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueLoopCompletedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
|
||||
case NodeRunLoopNextEvent():
|
||||
self._publish_event(
|
||||
QueueLoopNextEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_loop_output,
|
||||
)
|
||||
)
|
||||
case NodeRunLoopSucceededEvent() | NodeRunLoopFailedEvent():
|
||||
self._publish_event(
|
||||
QueueLoopCompletedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
|
||||
for reason in reasons:
|
||||
|
||||
@ -11,7 +11,6 @@ from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReason
|
||||
from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.variables.segments import Segment
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
@ -509,10 +508,6 @@ class QueueHumanInputFormFilledEvent(AppQueueEvent):
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
# Keep the field name aligned with Graphon so the app-layer bridge does not
|
||||
# need to translate between two equivalent payload names.
|
||||
submitted_data: Mapping[str, Segment] | None = None
|
||||
|
||||
|
||||
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
@ -342,8 +342,6 @@ class HumanInputFormFilledResponse(StreamResponse):
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
submitted_data: Mapping[str, Any] | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
@ -1,6 +1,13 @@
|
||||
from .controller import DatabaseFileAccessController
|
||||
from .protocols import FileAccessControllerProtocol
|
||||
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope
|
||||
from .scope import (
|
||||
FileAccessScope,
|
||||
bind_file_access_scope,
|
||||
get_current_file_access_scope,
|
||||
grant_retriever_segment_access,
|
||||
grant_upload_file_access,
|
||||
is_retriever_segment_access_granted,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DatabaseFileAccessController",
|
||||
@ -8,4 +15,7 @@ __all__ = [
|
||||
"FileAccessScope",
|
||||
"bind_file_access_scope",
|
||||
"get_current_file_access_scope",
|
||||
"grant_retriever_segment_access",
|
||||
"grant_upload_file_access",
|
||||
"is_retriever_segment_access_granted",
|
||||
]
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
@ -18,7 +18,8 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
|
||||
Tenant scoping remains mandatory. When the current execution belongs to an
|
||||
end user, the lookup is additionally constrained to that end user's file
|
||||
ownership markers.
|
||||
ownership markers, plus upload files explicitly granted by the current
|
||||
execution context.
|
||||
"""
|
||||
|
||||
_scope_getter: Callable[[], FileAccessScope | None]
|
||||
@ -47,10 +48,19 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
if not resolved_scope.requires_user_ownership:
|
||||
return scoped_stmt
|
||||
|
||||
return scoped_stmt.where(
|
||||
user_owned_filter = and_(
|
||||
UploadFile.created_by_role == CreatorUserRole.END_USER,
|
||||
UploadFile.created_by == resolved_scope.user_id,
|
||||
)
|
||||
if not resolved_scope.granted_upload_file_ids:
|
||||
return scoped_stmt.where(user_owned_filter)
|
||||
|
||||
return scoped_stmt.where(
|
||||
or_(
|
||||
user_owned_filter,
|
||||
UploadFile.id.in_(resolved_scope.granted_upload_file_ids),
|
||||
)
|
||||
)
|
||||
|
||||
def apply_tool_file_filters(
|
||||
self,
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator # Changed from Iterator
|
||||
from collections.abc import Generator, Iterable
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field, replace
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
|
||||
@ -15,12 +15,23 @@ _current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FileAccessScope:
|
||||
"""Request-scoped ownership context used by workflow-layer file lookups."""
|
||||
"""Request-scoped ownership context used by workflow-layer file lookups.
|
||||
|
||||
``granted_upload_file_ids`` is execution-local: callers may add upload files
|
||||
that were returned by trusted retrieval paths without changing persistent
|
||||
ownership markers.
|
||||
|
||||
``granted_retriever_segment_ids`` gates lazy attachment loading by segment
|
||||
ID, so user-provided context cannot make a later LLM node load arbitrary
|
||||
same-tenant knowledge attachments.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
granted_upload_file_ids: frozenset[str] = field(default_factory=frozenset)
|
||||
granted_retriever_segment_ids: frozenset[str] = field(default_factory=frozenset)
|
||||
|
||||
@property
|
||||
def requires_user_ownership(self) -> bool:
|
||||
@ -31,8 +42,49 @@ def get_current_file_access_scope() -> FileAccessScope | None:
|
||||
return _current_file_access_scope.get()
|
||||
|
||||
|
||||
def grant_upload_file_access(upload_file_ids: Iterable[str]) -> None:
|
||||
scope = _current_file_access_scope.get()
|
||||
if scope is None:
|
||||
return
|
||||
|
||||
granted_upload_file_ids = frozenset(str(file_id) for file_id in upload_file_ids if file_id)
|
||||
if not granted_upload_file_ids:
|
||||
return
|
||||
|
||||
_current_file_access_scope.set(
|
||||
replace(
|
||||
scope,
|
||||
granted_upload_file_ids=scope.granted_upload_file_ids | granted_upload_file_ids,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def grant_retriever_segment_access(segment_ids: Iterable[str]) -> None:
|
||||
scope = _current_file_access_scope.get()
|
||||
if scope is None:
|
||||
return
|
||||
|
||||
granted_segment_ids = frozenset(str(segment_id) for segment_id in segment_ids if segment_id)
|
||||
if not granted_segment_ids:
|
||||
return
|
||||
|
||||
_current_file_access_scope.set(
|
||||
replace(
|
||||
scope,
|
||||
granted_retriever_segment_ids=scope.granted_retriever_segment_ids | granted_segment_ids,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def is_retriever_segment_access_granted(segment_id: str) -> bool:
|
||||
scope = _current_file_access_scope.get()
|
||||
if scope is None or not scope.requires_user_ownership:
|
||||
return True
|
||||
return str(segment_id) in scope.granted_retriever_segment_ids
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:
|
||||
token = _current_file_access_scope.set(scope)
|
||||
try:
|
||||
yield
|
||||
|
||||
@ -140,42 +140,43 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
:return:
|
||||
"""
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
|
||||
if self._conversation_mode == AppMode.COMPLETION:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
else:
|
||||
response = ChatbotAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
match stream_response:
|
||||
case ErrorStreamResponse():
|
||||
raise stream_response.err
|
||||
case MessageEndStreamResponse():
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
|
||||
if self._conversation_mode == AppMode.COMPLETION:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
else:
|
||||
response = ChatbotAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
return response
|
||||
case _:
|
||||
continue
|
||||
|
||||
raise RuntimeError("queue listening stopped unexpectedly.")
|
||||
|
||||
@ -265,104 +266,107 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
publisher.publish(message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self.error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
if event.llm_result:
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
self._handle_stop(event)
|
||||
match event:
|
||||
case QueueErrorEvent():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self.error_to_stream_response(err)
|
||||
break
|
||||
case QueueStopEvent() | QueueMessageEndEvent():
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
if event.llm_result:
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
self._handle_stop(event)
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
self._task_state.llm_result.message.get_text_content()
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
# handle output moderation
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
self._task_state.llm_result.message.get_text_content()
|
||||
)
|
||||
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
annotation = self._message_cycle_manager.handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
agent_thought_response = self._agent_thought_to_stream_response(event)
|
||||
if agent_thought_response is not None:
|
||||
yield agent_thought_response
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_cycle_manager.message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
continue
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
delta_text = ""
|
||||
for content in chunk.delta.message.content:
|
||||
logger.debug(
|
||||
"The content type %s in LLM chunk delta message content.: %r", type(content), content
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
delta_text += content.data
|
||||
elif isinstance(content, str):
|
||||
delta_text += content # failback to str
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported content type %s in LLM chunk delta message content.: %r",
|
||||
type(content),
|
||||
content,
|
||||
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
case QueueRetrieverResourcesEvent():
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
case QueueAnnotationReplyEvent():
|
||||
annotation = self._message_cycle_manager.handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
case QueueAgentThoughtEvent():
|
||||
agent_thought_response = self._agent_thought_to_stream_response(event)
|
||||
if agent_thought_response is not None:
|
||||
yield agent_thought_response
|
||||
case QueueMessageFileEvent():
|
||||
response = self._message_cycle_manager.message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
case QueueLLMChunkEvent() | QueueAgentMessageEvent():
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
continue
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
delta_text = ""
|
||||
for content in chunk.delta.message.content:
|
||||
logger.debug(
|
||||
"The content type %s in LLM chunk delta message content.: %r", type(content), content
|
||||
)
|
||||
continue
|
||||
match content:
|
||||
case TextPromptMessageContent():
|
||||
delta_text += content.data
|
||||
case str():
|
||||
delta_text += content # failback to str
|
||||
case _:
|
||||
logger.warning(
|
||||
"Unsupported content type %s in LLM chunk delta message content.: %r",
|
||||
type(content),
|
||||
content,
|
||||
)
|
||||
continue
|
||||
|
||||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
|
||||
if should_direct_answer:
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
current_content = cast(str, self._task_state.llm_result.message.content)
|
||||
current_content += cast(str, delta_text)
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
match event:
|
||||
case QueueLLMChunkEvent():
|
||||
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
|
||||
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
|
||||
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
|
||||
message_id=self._message_id
|
||||
)
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
event_type=self._precomputed_event_type,
|
||||
)
|
||||
case _:
|
||||
yield self._agent_message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
case QueueMessageReplaceEvent():
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||
case QueuePingEvent():
|
||||
yield self.ping_stream_response()
|
||||
case _:
|
||||
continue
|
||||
|
||||
current_content = cast(str, self._task_state.llm_result.message.content)
|
||||
current_content += cast(str, delta_text)
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
|
||||
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
|
||||
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
|
||||
message_id=self._message_id
|
||||
)
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
event_type=self._precomputed_event_type,
|
||||
)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self.ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
|
||||
@ -95,13 +95,14 @@ class AppGeneratorTTSPublisher:
|
||||
message_content = message.event.chunk.delta.message.content
|
||||
if not message_content:
|
||||
continue
|
||||
if isinstance(message_content, str):
|
||||
self.msg_text += message_content
|
||||
elif isinstance(message_content, list):
|
||||
for content in message_content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
self.msg_text += content.data
|
||||
match message_content:
|
||||
case str():
|
||||
self.msg_text += message_content
|
||||
case list():
|
||||
for content in message_content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
self.msg_text += content.data
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
@ -19,8 +19,6 @@ class HumanInputFormDefinition(BaseModel):
|
||||
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
|
||||
actions: Sequence[UserActionConfig] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
|
||||
# `form_token` is `None` if the corresponding form has been submitted.
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
@ -31,31 +29,16 @@ class HumanInputFormSubmissionData(BaseModel):
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
|
||||
# deprecate: the rendered_content is deprecated and only for historical reasons.
|
||||
rendered_content: str
|
||||
|
||||
# The identifier of action user has chosen.
|
||||
action_id: str
|
||||
# The button text of the action user has chosen.
|
||||
action_text: str
|
||||
|
||||
# submitted_data records the submitted form data.
|
||||
# Keys correspond to `output_variable_name` of HumanInput inputs.
|
||||
# Values are serialized JSON forms of runtime values, including file dictionaries.
|
||||
#
|
||||
# For form submitted before this field is introduced, this field is populated from
|
||||
# the stored submission data.
|
||||
submitted_data: Mapping[str, JsonValue] | None = None
|
||||
|
||||
|
||||
class HumanInputContent(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
workflow_run_id: str
|
||||
submitted: bool
|
||||
# Both the form_defintion and the form_submission_data are present in
|
||||
# HumanInputContent. For historical records, the
|
||||
form_definition: HumanInputFormDefinition | None = None
|
||||
form_submission_data: HumanInputFormSubmissionData | None = None
|
||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
||||
|
||||
@ -171,9 +171,9 @@ class ProviderConfiguration(BaseModel):
|
||||
current_credential_id = self.custom_configuration.provider.current_credential_id
|
||||
|
||||
if current_credential_id:
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
from core.helper.credential_utils import runtime_check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
runtime_check_credential_policy_compliance(
|
||||
credential_id=current_credential_id,
|
||||
provider=self.provider.provider,
|
||||
credential_type=PluginCredentialType.MODEL,
|
||||
@ -182,9 +182,9 @@ class ProviderConfiguration(BaseModel):
|
||||
# no current credential id, check all available credentials
|
||||
if self.custom_configuration.provider:
|
||||
for credential_configuration in self.custom_configuration.provider.available_credentials:
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
from core.helper.credential_utils import runtime_check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
runtime_check_credential_policy_compliance(
|
||||
credential_id=credential_configuration.credential_id,
|
||||
provider=self.provider.provider,
|
||||
credential_type=PluginCredentialType.MODEL,
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
class LLMError(ValueError):
|
||||
"""Base class for all LLM exceptions."""
|
||||
|
||||
description: str | None = None
|
||||
description: str = ""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
def __init__(self, description: str = ""):
|
||||
self.description = description
|
||||
|
||||
|
||||
|
||||
@ -1,279 +0,0 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
CustomizedMetrics,
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
NodeInfo,
|
||||
)
|
||||
from graphon.node_events.base import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEvaluationInstance(ABC):
|
||||
"""Abstract base class for evaluation framework adapters."""
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate LLM outputs using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate retrieval quality using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate agent outputs using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
"""Return the list of supported metric names for a given evaluation category."""
|
||||
...
|
||||
|
||||
def evaluate_with_customized_workflow(
|
||||
self,
|
||||
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
|
||||
customized_metrics: CustomizedMetrics,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate using a published workflow as the evaluator.
|
||||
|
||||
The evaluator workflow's output variables are treated as metrics:
|
||||
each output variable name becomes a metric name, and its value
|
||||
becomes the score.
|
||||
|
||||
Args:
|
||||
node_run_result_mapping_list: One mapping per test-data item,
|
||||
where each mapping is ``{node_id: NodeRunResult}`` from the
|
||||
target execution.
|
||||
customized_metrics: Contains ``evaluation_workflow_id`` (the
|
||||
published evaluator workflow) and ``input_fields`` (value
|
||||
sources for the evaluator's input variables).
|
||||
tenant_id: Tenant scope.
|
||||
|
||||
Returns:
|
||||
A list of ``EvaluationItemResult`` with metrics extracted from
|
||||
the evaluator workflow's output variables.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.evaluation.runners import get_service_account_for_app
|
||||
from models.engine import db
|
||||
from models.model import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
workflow_id = customized_metrics.evaluation_workflow_id
|
||||
if not workflow_id:
|
||||
raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator")
|
||||
|
||||
# Load the evaluator workflow resources using a dedicated session
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}")
|
||||
service_account = get_service_account_for_app(session, workflow_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
published_workflow = workflow_service.get_published_workflow(app_model=app)
|
||||
if not published_workflow:
|
||||
raise ValueError(f"No published workflow found for evaluation app {workflow_id}")
|
||||
|
||||
eval_results: list[EvaluationItemResult] = []
|
||||
for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list):
|
||||
try:
|
||||
workflow_inputs = self._build_workflow_inputs(
|
||||
customized_metrics.input_fields,
|
||||
node_run_result_mapping,
|
||||
)
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
response: Mapping[str, Any] = generator.generate(
|
||||
app_model=app,
|
||||
workflow=published_workflow,
|
||||
user=service_account,
|
||||
args={"inputs": workflow_inputs},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
metrics = self._extract_workflow_metrics(response, workflow_id)
|
||||
eval_results.append(
|
||||
EvaluationItemResult(
|
||||
index=idx,
|
||||
metrics=metrics,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Customized evaluator failed for item %d with workflow %s",
|
||||
idx,
|
||||
workflow_id,
|
||||
)
|
||||
eval_results.append(EvaluationItemResult(index=idx))
|
||||
|
||||
return eval_results
|
||||
|
||||
@staticmethod
|
||||
def _build_workflow_inputs(
|
||||
input_fields: dict[str, Any],
|
||||
node_run_result_mapping: dict[str, NodeRunResult],
|
||||
) -> dict[str, Any]:
|
||||
"""Build customized workflow inputs by resolving value sources.
|
||||
|
||||
Each entry in ``input_fields`` maps a workflow input variable name
|
||||
to its value source, which can be:
|
||||
|
||||
- **Constant**: a plain string without ``{{#…#}}`` used as-is.
|
||||
- **Expression**: a string containing one or more
|
||||
``{{#node_id.output_key#}}`` selectors (same format as
|
||||
``VariableTemplateParser``) resolved from
|
||||
``node_run_result_mapping``.
|
||||
|
||||
"""
|
||||
from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX
|
||||
|
||||
workflow_inputs: dict[str, Any] = {}
|
||||
|
||||
for field_name, value_source in input_fields.items():
|
||||
if not isinstance(value_source, str):
|
||||
# Non-string values (numbers, bools, dicts) are used directly.
|
||||
workflow_inputs[field_name] = value_source
|
||||
continue
|
||||
|
||||
# Check if the entire value is a single expression.
|
||||
full_match = VARIABLE_REGEX.fullmatch(value_source)
|
||||
if full_match:
|
||||
workflow_inputs[field_name] = resolve_variable_selector(
|
||||
full_match.group(1),
|
||||
node_run_result_mapping,
|
||||
)
|
||||
elif VARIABLE_REGEX.search(value_source):
|
||||
# Mixed template: interpolate all expressions as strings.
|
||||
workflow_inputs[field_name] = VARIABLE_REGEX.sub(
|
||||
lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)),
|
||||
value_source,
|
||||
)
|
||||
else:
|
||||
# Plain constant — no expression markers.
|
||||
workflow_inputs[field_name] = value_source
|
||||
|
||||
return workflow_inputs
|
||||
|
||||
@staticmethod
|
||||
def _extract_workflow_metrics(
|
||||
response: Mapping[str, object],
|
||||
evaluation_workflow_id: str,
|
||||
) -> list[EvaluationMetric]:
|
||||
"""Extract evaluation metrics from workflow output variables.
|
||||
|
||||
Each metric's ``node_info`` is set with *evaluation_workflow_id* as
|
||||
the ``node_id``, so that judgment conditions can reference customized
|
||||
metrics via ``variable_selector: [evaluation_workflow_id, metric_name]``.
|
||||
"""
|
||||
metrics: list[EvaluationMetric] = []
|
||||
node_info = NodeInfo(node_id=evaluation_workflow_id, type="customized", title="customized")
|
||||
|
||||
data = response.get("data")
|
||||
if not isinstance(data, Mapping):
|
||||
logger.warning("Unexpected workflow response format: missing 'data' dict")
|
||||
return metrics
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if not isinstance(outputs, dict):
|
||||
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
|
||||
return metrics
|
||||
|
||||
for key, raw_value in outputs.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
metrics.append(EvaluationMetric(name=key, value=raw_value, node_info=node_info))
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def resolve_variable_selector(
|
||||
selector_raw: str,
|
||||
node_run_result_mapping: dict[str, NodeRunResult],
|
||||
) -> object:
|
||||
"""
|
||||
Resolve a ``#node_id.output_key#`` selector against node run results.
|
||||
"""
|
||||
#
|
||||
cleaned = selector_raw.strip("#")
|
||||
parts = cleaned.split(".")
|
||||
|
||||
if len(parts) < 2:
|
||||
logger.warning(
|
||||
"Selector '%s' must have at least node_id.output_key",
|
||||
selector_raw,
|
||||
)
|
||||
return ""
|
||||
|
||||
node_id = parts[0]
|
||||
output_path = parts[1:]
|
||||
|
||||
node_result = node_run_result_mapping.get(node_id)
|
||||
if not node_result or not node_result.outputs:
|
||||
logger.warning(
|
||||
"Selector '%s': node '%s' not found or has no outputs",
|
||||
selector_raw,
|
||||
node_id,
|
||||
)
|
||||
return ""
|
||||
|
||||
# Traverse the output path to support nested keys.
|
||||
current: object = node_result.outputs
|
||||
for key in output_path:
|
||||
if isinstance(current, Mapping):
|
||||
next_val = current.get(key)
|
||||
if next_val is None:
|
||||
logger.warning(
|
||||
"Selector '%s': key '%s' not found in node '%s' outputs",
|
||||
selector_raw,
|
||||
key,
|
||||
node_id,
|
||||
)
|
||||
return ""
|
||||
current = next_val
|
||||
else:
|
||||
logger.warning(
|
||||
"Selector '%s': cannot traverse into non-dict value at key '%s'",
|
||||
selector_raw,
|
||||
key,
|
||||
)
|
||||
return ""
|
||||
|
||||
return current if current is not None else ""
|
||||
@ -1,27 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EvaluationFrameworkEnum(StrEnum):
|
||||
RAGAS = "ragas"
|
||||
DEEPEVAL = "deepeval"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class BaseEvaluationConfig(BaseModel):
|
||||
"""Base configuration for evaluation frameworks."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RagasConfig(BaseEvaluationConfig):
|
||||
"""RAGAS-specific configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DeepEvalConfig(BaseEvaluationConfig):
|
||||
"""DeepEval-specific configuration."""
|
||||
|
||||
pass
|
||||
@ -1,280 +0,0 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
|
||||
|
||||
|
||||
class EvaluationCategory(StrEnum):
|
||||
LLM = "llm"
|
||||
RETRIEVAL = "knowledge_retrieval"
|
||||
AGENT = "agent"
|
||||
WORKFLOW = "workflow"
|
||||
SNIPPET = "snippet"
|
||||
KNOWLEDGE_BASE = "knowledge_base"
|
||||
|
||||
|
||||
class EvaluationMetricName(StrEnum):
|
||||
"""Canonical metric names shared across all evaluation frameworks.
|
||||
|
||||
Each framework maps these names to its own internal implementation.
|
||||
A framework that does not support a given metric should log a warning
|
||||
and skip it rather than raising an error.
|
||||
|
||||
── LLM / general text-quality metrics ──────────────────────────────────
|
||||
FAITHFULNESS
|
||||
Measures whether every claim in the model's response is grounded in
|
||||
the provided retrieved context. A high score means the answer
|
||||
contains no hallucinated content — each statement can be traced back
|
||||
to a passage in the context.
|
||||
Required fields: user_input, response, retrieved_contexts.
|
||||
|
||||
ANSWER_RELEVANCY
|
||||
Measures how well the model's response addresses the user's question.
|
||||
A high score means the answer stays on-topic; a low score indicates
|
||||
irrelevant content or a failure to answer the actual question.
|
||||
Required fields: user_input, response.
|
||||
|
||||
ANSWER_CORRECTNESS
|
||||
Measures the factual accuracy and completeness of the model's answer
|
||||
relative to a ground-truth reference. It combines semantic similarity
|
||||
with key-fact coverage, so both meaning and content matter.
|
||||
Required fields: user_input, response, reference (expected_output).
|
||||
|
||||
SEMANTIC_SIMILARITY
|
||||
Measures the cosine similarity between the model's response and the
|
||||
reference answer in an embedding space. It evaluates whether the two
|
||||
texts convey the same meaning, independent of factual correctness.
|
||||
Required fields: response, reference (expected_output).
|
||||
|
||||
── Retrieval-quality metrics ────────────────────────────────────────────
|
||||
CONTEXT_PRECISION
|
||||
Measures the proportion of retrieved context chunks that are actually
|
||||
relevant to the question (precision). A high score means the retrieval
|
||||
pipeline returns little noise.
|
||||
Required fields: user_input, reference, retrieved_contexts.
|
||||
|
||||
CONTEXT_RECALL
|
||||
Measures the proportion of ground-truth information that is covered by
|
||||
the retrieved context chunks (recall). A high score means the retrieval
|
||||
pipeline does not miss important supporting evidence.
|
||||
Required fields: user_input, reference, retrieved_contexts.
|
||||
|
||||
CONTEXT_RELEVANCE
|
||||
Measures how relevant each individual retrieved chunk is to the query.
|
||||
Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
|
||||
than against a reference answer.
|
||||
Required fields: user_input, retrieved_contexts.
|
||||
|
||||
── Agent-quality metrics ────────────────────────────────────────────────
|
||||
TOOL_CORRECTNESS
|
||||
Measures the correctness of the tool calls made by the agent during
|
||||
task execution — both the choice of tool and the arguments passed.
|
||||
A high score means the agent's tool-use strategy matches the expected
|
||||
behavior.
|
||||
Required fields: actual tool calls vs. expected tool calls.
|
||||
|
||||
TASK_COMPLETION
|
||||
Measures whether the agent ultimately achieves the user's stated goal.
|
||||
It evaluates the reasoning chain, intermediate steps, and final output
|
||||
holistically; a high score means the task was fully accomplished.
|
||||
Required fields: user_input, actual_output.
|
||||
"""
|
||||
|
||||
# LLM / general text-quality metrics
|
||||
FAITHFULNESS = "faithfulness"
|
||||
ANSWER_RELEVANCY = "answer_relevancy"
|
||||
ANSWER_CORRECTNESS = "answer_correctness"
|
||||
SEMANTIC_SIMILARITY = "semantic_similarity"
|
||||
|
||||
# Retrieval-quality metrics
|
||||
CONTEXT_PRECISION = "context_precision"
|
||||
CONTEXT_RECALL = "context_recall"
|
||||
CONTEXT_RELEVANCE = "context_relevance"
|
||||
|
||||
# Agent-quality metrics
|
||||
TOOL_CORRECTNESS = "tool_correctness"
|
||||
TASK_COMPLETION = "task_completion"
|
||||
|
||||
|
||||
# Per-category canonical metric lists used by get_supported_metrics().
|
||||
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.FAITHFULNESS, # Every claim is grounded in context; no hallucinations
|
||||
EvaluationMetricName.ANSWER_RELEVANCY, # Response stays on-topic and addresses the question
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS, # Factual accuracy and completeness vs. reference
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY, # Semantic closeness to the reference answer
|
||||
]
|
||||
|
||||
RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.CONTEXT_PRECISION, # Fraction of retrieved chunks that are relevant (precision)
|
||||
EvaluationMetricName.CONTEXT_RECALL, # Fraction of ground-truth info covered by retrieval (recall)
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE, # Per-chunk relevance to the query
|
||||
]
|
||||
|
||||
AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.TOOL_CORRECTNESS, # Correct tool selection and arguments
|
||||
EvaluationMetricName.TASK_COMPLETION, # Whether the agent fully achieves the user's goal
|
||||
]
|
||||
|
||||
WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.FAITHFULNESS,
|
||||
EvaluationMetricName.ANSWER_RELEVANCY,
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS,
|
||||
]
|
||||
|
||||
METRIC_NODE_TYPE_MAPPING: dict[str, str] = {
|
||||
**{m.value: "llm" for m in LLM_METRIC_NAMES},
|
||||
**{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES},
|
||||
**{m.value: "agent" for m in AGENT_METRIC_NAMES},
|
||||
}
|
||||
|
||||
METRIC_VALUE_TYPE_MAPPING: dict[str, str] = {
|
||||
EvaluationMetricName.FAITHFULNESS: "number",
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: "number",
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS: "number",
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY: "number",
|
||||
EvaluationMetricName.CONTEXT_PRECISION: "number",
|
||||
EvaluationMetricName.CONTEXT_RECALL: "number",
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: "number",
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: "number",
|
||||
EvaluationMetricName.TASK_COMPLETION: "number",
|
||||
}
|
||||
|
||||
|
||||
class NodeInfo(BaseModel):
|
||||
node_id: str
|
||||
type: str
|
||||
title: str
|
||||
|
||||
|
||||
class EvaluationMetric(BaseModel):
|
||||
name: str
|
||||
value: Any
|
||||
details: dict[str, Any] = Field(default_factory=dict)
|
||||
node_info: NodeInfo | None = None
|
||||
|
||||
|
||||
class EvaluationItemInput(BaseModel):
|
||||
index: int
|
||||
inputs: dict[str, Any]
|
||||
output: str
|
||||
expected_output: str | None = None
|
||||
context: list[str] | None = None
|
||||
|
||||
|
||||
class EvaluationDatasetInput(BaseModel):
|
||||
"""Parsed dataset row used throughout evaluation execution.
|
||||
|
||||
``expected_output`` keeps backward compatibility with the original
|
||||
single-reference template. When users upload node-specific reference
|
||||
columns such as ``LLM 1 : expected_output``, they are stored in
|
||||
``expected_outputs`` and resolved by node title at execution time.
|
||||
"""
|
||||
|
||||
index: int
|
||||
inputs: dict[str, Any]
|
||||
expected_output: str | None = None
|
||||
expected_outputs: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
def get_expected_output_for_node(self, node_title: str | None) -> str | None:
|
||||
"""Return the best matching reference answer for the given node title."""
|
||||
if node_title:
|
||||
if node_title in self.expected_outputs:
|
||||
return self.expected_outputs[node_title]
|
||||
|
||||
if self.expected_output is not None:
|
||||
return self.expected_output
|
||||
|
||||
if len(self.expected_outputs) == 1:
|
||||
return next(iter(self.expected_outputs.values()))
|
||||
|
||||
return None
|
||||
|
||||
def serialize_expected_output(self) -> str | None:
|
||||
"""Serialize references for persistence and API responses.
|
||||
|
||||
Single-reference datasets stay unchanged, while multi-node references
|
||||
are stored as JSON so history/detail APIs can still expose the full
|
||||
uploaded payload without changing the database schema.
|
||||
"""
|
||||
if self.expected_output is not None and not self.expected_outputs:
|
||||
return self.expected_output
|
||||
|
||||
if not self.expected_outputs:
|
||||
return None
|
||||
|
||||
serialized_expected_outputs = dict(self.expected_outputs)
|
||||
if self.expected_output is not None:
|
||||
serialized_expected_outputs = {"expected_output": self.expected_output, **serialized_expected_outputs}
|
||||
|
||||
return json.dumps(serialized_expected_outputs, ensure_ascii=False, sort_keys=True)
|
||||
|
||||
def iter_expected_output_columns(self) -> list[tuple[str, str]]:
|
||||
"""Return uploaded expected-output columns in display order."""
|
||||
columns: list[tuple[str, str]] = []
|
||||
if self.expected_output is not None:
|
||||
columns.append(("expected_output", self.expected_output))
|
||||
|
||||
for node_title, value in self.expected_outputs.items():
|
||||
columns.append((f"{node_title} : expected_output", value))
|
||||
|
||||
return columns
|
||||
|
||||
|
||||
class EvaluationItemResult(BaseModel):
|
||||
index: int
|
||||
actual_output: str | None = None
|
||||
metrics: list[EvaluationMetric] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
judgment: JudgmentResult = Field(default_factory=JudgmentResult)
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class DefaultMetric(BaseModel):
|
||||
metric: str
|
||||
value_type: str = ""
|
||||
node_info_list: list[NodeInfo]
|
||||
|
||||
|
||||
class CustomizedMetricOutputField(BaseModel):
|
||||
variable: str
|
||||
value_type: str
|
||||
|
||||
|
||||
class CustomizedMetrics(BaseModel):
|
||||
evaluation_workflow_id: str
|
||||
input_fields: dict[str, Any]
|
||||
output_fields: list[CustomizedMetricOutputField]
|
||||
|
||||
|
||||
class EvaluationConfigData(BaseModel):
|
||||
"""Structured data for saving evaluation configuration."""
|
||||
|
||||
evaluation_model: str = ""
|
||||
evaluation_model_provider: str = ""
|
||||
default_metrics: list[DefaultMetric] = Field(default_factory=list)
|
||||
customized_metrics: CustomizedMetrics | None = None
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
|
||||
|
||||
class EvaluationRunRequest(EvaluationConfigData):
|
||||
"""Request body for starting an evaluation run."""
|
||||
|
||||
file_id: str
|
||||
|
||||
|
||||
class EvaluationRunData(BaseModel):
|
||||
"""Serializable data for Celery task."""
|
||||
|
||||
evaluation_run_id: str
|
||||
tenant_id: str
|
||||
target_type: str
|
||||
target_id: str
|
||||
evaluation_model_provider: str
|
||||
evaluation_model: str
|
||||
default_metrics: list[DefaultMetric] = Field(default_factory=list)
|
||||
customized_metrics: CustomizedMetrics | None = None
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
input_list: list[EvaluationDatasetInput]
|
||||
@ -1,118 +0,0 @@
|
||||
"""Judgment condition entities for evaluation metric assessment.
|
||||
|
||||
Condition structure mirrors the workflow if-else ``Condition`` model from
|
||||
``graphon.utils.condition.entities``. The left-hand side uses
|
||||
``variable_selector`` — a two-element list ``[node_id, metric_name]`` — to
|
||||
uniquely identify an evaluation metric (different nodes may produce metrics
|
||||
with the same name).
|
||||
|
||||
Operators reuse ``SupportedComparisonOperator`` from the workflow engine so
|
||||
that type semantics stay consistent across the platform.
|
||||
|
||||
Typical usage::
|
||||
|
||||
judgment_config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["node_abc", "faithfulness"],
|
||||
comparison_operator=">",
|
||||
value="0.8",
|
||||
)
|
||||
],
|
||||
)
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
|
||||
COMPARISON_OPERATOR_ALIASES: dict[str, str] = {
|
||||
"==": "=",
|
||||
"!=": "≠",
|
||||
">=": "≥",
|
||||
"<=": "≤",
|
||||
"is null": "null",
|
||||
"is not null": "not null",
|
||||
}
|
||||
|
||||
|
||||
class JudgmentCondition(BaseModel):
|
||||
"""A single judgment condition that checks one metric value.
|
||||
|
||||
Mirrors ``graphon.utils.condition.entities.Condition`` with the left-hand
|
||||
side being a metric selector instead of a workflow variable selector.
|
||||
|
||||
Attributes:
|
||||
variable_selector: ``[node_id, metric_name]`` identifying the metric.
|
||||
comparison_operator: Reuses workflow's ``SupportedComparisonOperator``.
|
||||
value: The comparison target (right side). For unary operators such
|
||||
as ``empty`` or ``null`` this can be ``None``.
|
||||
"""
|
||||
|
||||
variable_selector: list[str]
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | bool | None = None
|
||||
|
||||
@field_validator("comparison_operator", mode="before")
|
||||
@classmethod
|
||||
def normalize_comparison_operator(cls, value: Any) -> Any:
|
||||
"""Accept common ASCII/API aliases for workflow comparison operators."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
normalized_value = value.strip().lower()
|
||||
alias = COMPARISON_OPERATOR_ALIASES.get(normalized_value)
|
||||
if alias is not None:
|
||||
return alias
|
||||
return value.strip()
|
||||
|
||||
|
||||
class JudgmentConfig(BaseModel):
|
||||
"""A group of judgment conditions combined with a logical operator.
|
||||
|
||||
Attributes:
|
||||
logical_operator: How to combine condition results — "and" requires
|
||||
all conditions to pass, "or" requires at least one.
|
||||
conditions: The list of individual conditions to evaluate.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
conditions: list[JudgmentCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class JudgmentConditionResult(BaseModel):
|
||||
"""Result of evaluating a single judgment condition.
|
||||
|
||||
Attributes:
|
||||
variable_selector: ``[node_id, metric_name]`` that was checked.
|
||||
comparison_operator: The operator that was applied.
|
||||
expected_value: The resolved comparison value.
|
||||
actual_value: The actual metric value that was evaluated.
|
||||
passed: Whether this individual condition passed.
|
||||
error: Error message if the condition evaluation failed.
|
||||
"""
|
||||
|
||||
variable_selector: list[str]
|
||||
comparison_operator: str
|
||||
expected_value: Any = None
|
||||
actual_value: Any = None
|
||||
passed: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class JudgmentResult(BaseModel):
|
||||
"""Overall result of evaluating all judgment conditions for one item.
|
||||
|
||||
Attributes:
|
||||
passed: Whether the overall judgment passed (based on logical_operator).
|
||||
logical_operator: The logical operator used to combine conditions.
|
||||
condition_results: Detailed result for each individual condition.
|
||||
"""
|
||||
|
||||
passed: bool = False
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)
|
||||
@ -1,61 +0,0 @@
|
||||
import collections
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
|
||||
|
||||
def __getitem__(self, framework: str) -> dict[str, Any]:
|
||||
match framework:
|
||||
case EvaluationFrameworkEnum.RAGAS:
|
||||
from core.evaluation.entities.config_entity import RagasConfig
|
||||
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
|
||||
|
||||
return {
|
||||
"config_class": RagasConfig,
|
||||
"evaluator_class": RagasEvaluator,
|
||||
}
|
||||
case EvaluationFrameworkEnum.DEEPEVAL:
|
||||
raise NotImplementedError("DeepEval adapter is not yet implemented.")
|
||||
case _:
|
||||
raise ValueError(f"Unknown evaluation framework: {framework}")
|
||||
|
||||
|
||||
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
|
||||
|
||||
|
||||
class EvaluationManager:
|
||||
"""Factory for evaluation instances based on global configuration."""
|
||||
|
||||
@staticmethod
|
||||
def get_evaluation_instance() -> BaseEvaluationInstance | None:
|
||||
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
|
||||
framework = dify_config.EVALUATION_FRAMEWORK
|
||||
if not framework or framework == EvaluationFrameworkEnum.NONE:
|
||||
return None
|
||||
|
||||
try:
|
||||
config_map = evaluation_framework_config_map[framework]
|
||||
evaluator_class = config_map["evaluator_class"]
|
||||
config_class = config_map["config_class"]
|
||||
config = config_class()
|
||||
return evaluator_class(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to create evaluation instance for framework: %s", framework)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
|
||||
"""Return supported metrics for the current framework and given category."""
|
||||
instance = EvaluationManager.get_evaluation_instance()
|
||||
if instance is None:
|
||||
return []
|
||||
return instance.get_supported_metrics(category)
|
||||
@ -1,299 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import DeepEvalConfig
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
AGENT_METRIC_NAMES,
|
||||
LLM_METRIC_NAMES,
|
||||
RETRIEVAL_METRIC_NAMES,
|
||||
WORKFLOW_METRIC_NAMES,
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
EvaluationMetricName,
|
||||
)
|
||||
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name.
|
||||
# deepeval metric field requirements (LLMTestCase fields):
|
||||
# - faithfulness: input, actual_output, retrieval_context
|
||||
# - answer_relevancy: input, actual_output
|
||||
# - context_precision: input, actual_output, expected_output, retrieval_context
|
||||
# - context_recall: input, actual_output, expected_output, retrieval_context
|
||||
# - context_relevance: input, actual_output, retrieval_context
|
||||
# - tool_correctness: input, actual_output, expected_tools
|
||||
# - task_completion: input, actual_output
|
||||
# Metrics not listed here are unsupported by deepeval and will be skipped.
|
||||
_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = {
|
||||
EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric",
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric",
|
||||
EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric",
|
||||
EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric",
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric",
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric",
|
||||
EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric",
|
||||
}
|
||||
|
||||
|
||||
class DeepEvalEvaluator(BaseEvaluationInstance):
|
||||
"""DeepEval framework adapter for evaluation."""
|
||||
|
||||
def __init__(self, config: DeepEvalConfig):
|
||||
self.config = config
|
||||
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
match category:
|
||||
case EvaluationCategory.LLM:
|
||||
candidates = LLM_METRIC_NAMES
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
candidates = RETRIEVAL_METRIC_NAMES
|
||||
case EvaluationCategory.AGENT:
|
||||
candidates = AGENT_METRIC_NAMES
|
||||
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
|
||||
candidates = WORKFLOW_METRIC_NAMES
|
||||
case _:
|
||||
return []
|
||||
return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP]
|
||||
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
|
||||
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
|
||||
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
|
||||
|
||||
def evaluate_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Core evaluation logic using DeepEval."""
|
||||
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
|
||||
requested_metrics = metric_names or self.get_supported_metrics(category)
|
||||
|
||||
try:
|
||||
return self._evaluate_with_deepeval(items, requested_metrics, category)
|
||||
except ImportError:
|
||||
logger.warning("DeepEval not installed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
def _evaluate_with_deepeval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate using DeepEval library.
|
||||
|
||||
Builds LLMTestCase differently per category:
|
||||
- LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context
|
||||
- Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context
|
||||
- Agent: input=query, actual_output=output
|
||||
"""
|
||||
metric_pairs = _build_deepeval_metrics(requested_metrics)
|
||||
if not metric_pairs:
|
||||
logger.warning("No valid DeepEval metrics found for: %s", requested_metrics)
|
||||
return [EvaluationItemResult(index=item.index) for item in items]
|
||||
|
||||
results: list[EvaluationItemResult] = []
|
||||
for item in items:
|
||||
test_case = self._build_test_case(item, category)
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for canonical_name, metric in metric_pairs:
|
||||
try:
|
||||
metric.measure(test_case)
|
||||
if metric.score is not None:
|
||||
metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score)))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to compute metric %s for item %d",
|
||||
canonical_name,
|
||||
item.index,
|
||||
)
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
|
||||
"""Build a deepeval LLMTestCase with the correct fields per category."""
|
||||
from deepeval.test_case import LLMTestCase
|
||||
|
||||
user_input = _format_input(item.inputs, category)
|
||||
|
||||
match category:
|
||||
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
|
||||
# faithfulness needs: input, actual_output, retrieval_context
|
||||
# answer_relevancy needs: input, actual_output
|
||||
return LLMTestCase(
|
||||
input=user_input,
|
||||
actual_output=item.output,
|
||||
expected_output=item.expected_output or None,
|
||||
retrieval_context=item.context or None,
|
||||
)
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
# contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context
|
||||
return LLMTestCase(
|
||||
input=user_input,
|
||||
actual_output=item.output or "",
|
||||
expected_output=item.expected_output or "",
|
||||
retrieval_context=item.context or [],
|
||||
)
|
||||
case _:
|
||||
return LLMTestCase(
|
||||
input=user_input,
|
||||
actual_output=item.output,
|
||||
)
|
||||
|
||||
def _evaluate_simple(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Simple LLM-as-judge fallback when DeepEval is not available."""
|
||||
results: list[EvaluationItemResult] = []
|
||||
for item in items:
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for m_name in requested_metrics:
|
||||
try:
|
||||
score = self._judge_with_llm(model_wrapper, m_name, item)
|
||||
metrics.append(EvaluationMetric(name=m_name, value=score))
|
||||
except Exception:
|
||||
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
|
||||
def _judge_with_llm(
|
||||
self,
|
||||
model_wrapper: DifyModelWrapper,
|
||||
metric_name: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> float:
|
||||
"""Use the LLM to judge a single metric for a single item."""
|
||||
prompt = self._build_judge_prompt(metric_name, item)
|
||||
response = model_wrapper.invoke(prompt)
|
||||
return self._parse_score(response)
|
||||
|
||||
@staticmethod
|
||||
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
|
||||
"""Build a scoring prompt for the LLM judge."""
|
||||
parts = [
|
||||
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
|
||||
f"\nInput: {item.inputs}",
|
||||
f"\nOutput: {item.output}",
|
||||
]
|
||||
if item.expected_output:
|
||||
parts.append(f"\nExpected Output: {item.expected_output}")
|
||||
if item.context:
|
||||
parts.append(f"\nContext: {'; '.join(item.context)}")
|
||||
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _parse_score(response: str) -> float:
|
||||
"""Parse a float score from LLM response."""
|
||||
import re
|
||||
|
||||
cleaned = response.strip()
|
||||
try:
|
||||
score = float(cleaned)
|
||||
return max(0.0, min(1.0, score))
|
||||
except ValueError:
|
||||
match = re.search(r"(\d+\.?\d*)", cleaned)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
return max(0.0, min(1.0, score))
|
||||
return 0.0
|
||||
|
||||
|
||||
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
|
||||
"""Extract the user-facing input string from the inputs dict."""
|
||||
match category:
|
||||
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
|
||||
return str(inputs.get("prompt", ""))
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
return str(inputs.get("query", ""))
|
||||
case _:
|
||||
return str(next(iter(inputs.values()), "")) if inputs else ""
|
||||
|
||||
|
||||
def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]:
|
||||
"""Build DeepEval metric instances from canonical metric names.
|
||||
|
||||
Returns a list of (canonical_name, metric_instance) pairs so that callers
|
||||
can record the canonical name rather than the framework-internal class name.
|
||||
"""
|
||||
try:
|
||||
from deepeval.metrics import (
|
||||
AnswerRelevancyMetric,
|
||||
ContextualPrecisionMetric,
|
||||
ContextualRecallMetric,
|
||||
ContextualRelevancyMetric,
|
||||
FaithfulnessMetric,
|
||||
TaskCompletionMetric,
|
||||
ToolCorrectnessMetric,
|
||||
)
|
||||
|
||||
# Maps canonical name → deepeval metric class
|
||||
deepeval_class_map: dict[str, Any] = {
|
||||
EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric,
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric,
|
||||
EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric,
|
||||
EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric,
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric,
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric,
|
||||
EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric,
|
||||
}
|
||||
|
||||
pairs: list[tuple[str, Any]] = []
|
||||
for name in requested_metrics:
|
||||
metric_class = deepeval_class_map.get(name)
|
||||
if metric_class:
|
||||
pairs.append((name, metric_class(threshold=0.5)))
|
||||
else:
|
||||
logger.warning("Metric '%s' is not supported by DeepEval, skipping", name)
|
||||
return pairs
|
||||
except ImportError:
|
||||
logger.warning("DeepEval metrics not available")
|
||||
return []
|
||||
@ -1,345 +0,0 @@
|
||||
import logging
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import RagasConfig
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
AGENT_METRIC_NAMES,
|
||||
LLM_METRIC_NAMES,
|
||||
RETRIEVAL_METRIC_NAMES,
|
||||
WORKFLOW_METRIC_NAMES,
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
EvaluationMetricName,
|
||||
)
|
||||
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps canonical EvaluationMetricName to the corresponding ragas metric class.
|
||||
# Metrics not listed here are unsupported by ragas and will be skipped.
|
||||
_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = {
|
||||
EvaluationMetricName.FAITHFULNESS: "Faithfulness",
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy",
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness",
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity",
|
||||
EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision",
|
||||
EvaluationMetricName.CONTEXT_RECALL: "ContextRecall",
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance",
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy",
|
||||
}
|
||||
|
||||
|
||||
class RagasEvaluator(BaseEvaluationInstance):
|
||||
"""RAGAS framework adapter for evaluation."""
|
||||
|
||||
def __init__(self, config: RagasConfig):
|
||||
self.config = config
|
||||
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
match category:
|
||||
case EvaluationCategory.LLM:
|
||||
candidates = LLM_METRIC_NAMES
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
candidates = RETRIEVAL_METRIC_NAMES
|
||||
case EvaluationCategory.AGENT:
|
||||
candidates = AGENT_METRIC_NAMES
|
||||
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
|
||||
candidates = WORKFLOW_METRIC_NAMES
|
||||
case _:
|
||||
return []
|
||||
return [m for m in candidates if m in _RAGAS_METRIC_MAP]
|
||||
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
|
||||
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
|
||||
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
|
||||
|
||||
def evaluate_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Core evaluation logic using RAGAS."""
|
||||
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
|
||||
requested_metrics = metric_names or self.get_supported_metrics(category)
|
||||
|
||||
try:
|
||||
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
|
||||
except ImportError:
|
||||
logger.warning("RAGAS not installed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
def _evaluate_with_ragas(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate using RAGAS library.
|
||||
|
||||
Builds SingleTurnSample differently per category to match ragas requirements:
|
||||
- LLM/Workflow: user_input=prompt, response=output, reference=expected_output
|
||||
- Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context
|
||||
- Agent: Not supported via EvaluationDataset (requires message-based API)
|
||||
"""
|
||||
from ragas import evaluate as ragas_evaluate
|
||||
from ragas.dataset_schema import EvaluationDataset
|
||||
|
||||
samples: list[Any] = []
|
||||
for item in items:
|
||||
sample = self._build_sample(item, category)
|
||||
samples.append(sample)
|
||||
|
||||
dataset = EvaluationDataset(samples=samples)
|
||||
|
||||
ragas_metrics = self._build_ragas_metrics(requested_metrics)
|
||||
if not ragas_metrics:
|
||||
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
|
||||
return [EvaluationItemResult(index=item.index, actual_output=item.output) for item in items]
|
||||
|
||||
try:
|
||||
result = ragas_evaluate(
|
||||
dataset=dataset,
|
||||
metrics=ragas_metrics,
|
||||
llm=model_wrapper,
|
||||
)
|
||||
|
||||
results: list[EvaluationItemResult] = []
|
||||
result_df = result.to_pandas()
|
||||
for i, item in enumerate(items):
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for m_name in requested_metrics:
|
||||
if m_name in result_df.columns:
|
||||
score = result_df.iloc[i][m_name]
|
||||
if score is not None and not (isinstance(score, float) and score != score):
|
||||
metrics.append(EvaluationMetric(name=m_name, value=float(score)))
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics, actual_output=item.output))
|
||||
return results
|
||||
except Exception:
|
||||
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
@staticmethod
|
||||
def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
|
||||
"""Build a ragas SingleTurnSample with the correct fields per category.
|
||||
|
||||
ragas metric field requirements:
|
||||
- faithfulness: user_input, response, retrieved_contexts
|
||||
- answer_relevancy: user_input, response
|
||||
- answer_correctness: user_input, response, reference
|
||||
- semantic_similarity: user_input, response, reference
|
||||
- context_precision: user_input, reference, retrieved_contexts
|
||||
- context_recall: user_input, reference, retrieved_contexts
|
||||
- context_relevance: user_input, retrieved_contexts
|
||||
"""
|
||||
from ragas.dataset_schema import SingleTurnSample
|
||||
|
||||
user_input = _format_input(item.inputs, category)
|
||||
|
||||
match category:
|
||||
case EvaluationCategory.LLM:
|
||||
# response = actual LLM output, reference = expected output
|
||||
return SingleTurnSample(
|
||||
user_input=user_input,
|
||||
response=item.output,
|
||||
reference=item.expected_output or "",
|
||||
retrieved_contexts=item.context or [],
|
||||
)
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
# context_precision/recall only need reference + retrieved_contexts
|
||||
return SingleTurnSample(
|
||||
user_input=user_input,
|
||||
reference=item.expected_output or "",
|
||||
retrieved_contexts=item.context or [],
|
||||
)
|
||||
case _:
|
||||
return SingleTurnSample(
|
||||
user_input=user_input,
|
||||
response=item.output,
|
||||
)
|
||||
|
||||
def _evaluate_simple(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Simple LLM-as-judge fallback when RAGAS is not available."""
|
||||
results: list[EvaluationItemResult] = []
|
||||
for item in items:
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for m_name in requested_metrics:
|
||||
try:
|
||||
score = self._judge_with_llm(model_wrapper, m_name, item)
|
||||
metrics.append(EvaluationMetric(name=m_name, value=score))
|
||||
except Exception:
|
||||
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics, actual_output=item.output))
|
||||
return results
|
||||
|
||||
def _judge_with_llm(
|
||||
self,
|
||||
model_wrapper: DifyModelWrapper,
|
||||
metric_name: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> float:
|
||||
"""Use the LLM to judge a single metric for a single item."""
|
||||
prompt = self._build_judge_prompt(metric_name, item)
|
||||
response = model_wrapper.invoke(prompt)
|
||||
return self._parse_score(response)
|
||||
|
||||
@staticmethod
|
||||
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
|
||||
"""Build a scoring prompt for the LLM judge."""
|
||||
parts = [
|
||||
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
|
||||
f"\nInput: {item.inputs}",
|
||||
f"\nOutput: {item.output}",
|
||||
]
|
||||
if item.expected_output:
|
||||
parts.append(f"\nExpected Output: {item.expected_output}")
|
||||
if item.context:
|
||||
parts.append(f"\nContext: {'; '.join(item.context)}")
|
||||
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _parse_score(response: str) -> float:
|
||||
"""Parse a float score from LLM response."""
|
||||
import re
|
||||
|
||||
cleaned = response.strip()
|
||||
try:
|
||||
score = float(cleaned)
|
||||
return max(0.0, min(1.0, score))
|
||||
except ValueError:
|
||||
match = re.search(r"(\d+\.?\d*)", cleaned)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
return max(0.0, min(1.0, score))
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
|
||||
"""Build RAGAS metric instances from canonical metric names."""
|
||||
try:
|
||||
metrics_module = _import_ragas_metrics_module()
|
||||
|
||||
# Maps canonical name → ragas metric class
|
||||
ragas_class_map: dict[str, Any] = {
|
||||
EvaluationMetricName.FAITHFULNESS: getattr(metrics_module, "Faithfulness"),
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: getattr(metrics_module, "AnswerRelevancy"),
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS: getattr(metrics_module, "AnswerCorrectness"),
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY: getattr(metrics_module, "SemanticSimilarity"),
|
||||
EvaluationMetricName.CONTEXT_PRECISION: getattr(metrics_module, "ContextPrecision"),
|
||||
EvaluationMetricName.CONTEXT_RECALL: getattr(metrics_module, "ContextRecall"),
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: getattr(metrics_module, "ContextRelevance"),
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: getattr(metrics_module, "ToolCallAccuracy"),
|
||||
}
|
||||
|
||||
metrics = []
|
||||
for name in requested_metrics:
|
||||
metric_class = ragas_class_map.get(name)
|
||||
if metric_class:
|
||||
if name == EvaluationMetricName.ANSWER_CORRECTNESS:
|
||||
# ragas answer_correctness blends factuality with semantic
|
||||
# similarity. The latter requires an embeddings backend,
|
||||
# which is not wired through Dify's evaluation stack yet.
|
||||
# Keep the metric usable by relying on the factuality
|
||||
# component only for now.
|
||||
metrics.append(metric_class(weights=[1.0, 0.0], embeddings=_NoopRagasEmbeddings()))
|
||||
else:
|
||||
metrics.append(metric_class())
|
||||
else:
|
||||
logger.warning("Metric '%s' is not supported by RAGAS, skipping", name)
|
||||
return metrics
|
||||
except ImportError:
|
||||
logger.warning("RAGAS metrics not available")
|
||||
return []
|
||||
|
||||
|
||||
def _import_ragas_metrics_module() -> Any:
|
||||
"""Load ragas metric classes across supported ragas versions.
|
||||
|
||||
ragas 0.3.x exposes metric classes from ``ragas.metrics`` while some older
|
||||
versions used ``ragas.metrics.collections``. Support both so worker
|
||||
environments do not silently drop all metrics because of a module path
|
||||
mismatch.
|
||||
"""
|
||||
try:
|
||||
return import_module("ragas.metrics")
|
||||
except ImportError:
|
||||
return import_module("ragas.metrics.collections")
|
||||
|
||||
|
||||
class _NoopRagasEmbeddings:
|
||||
"""Placeholder embeddings for ragas metrics whose embedding branch is disabled.
|
||||
|
||||
ragas eagerly injects a default embeddings backend for any metric that
|
||||
subclasses ``MetricWithEmbeddings``. For answer_correctness we currently
|
||||
disable the semantic-similarity weight, so no real embedding call should
|
||||
happen. Supplying this placeholder keeps ragas from constructing its
|
||||
default OpenAI embeddings client during setup.
|
||||
"""
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
del text
|
||||
return [0.0]
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return [[0.0] for _ in texts]
|
||||
|
||||
|
||||
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
|
||||
"""Extract the user-facing input string from the inputs dict."""
|
||||
match category:
|
||||
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
|
||||
return str(inputs.get("prompt", ""))
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
return str(inputs.get("query", ""))
|
||||
case _:
|
||||
return str(next(iter(inputs.values()), "")) if inputs else ""
|
||||
@ -1,165 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
try:
|
||||
from ragas.llms.base import BaseRagasLLM
|
||||
except ImportError:
|
||||
class BaseRagasLLM: # type: ignore[no-redef]
|
||||
"""Lightweight shim so the module stays importable without ragas installed."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
del args, kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_temperature(n: int) -> float:
|
||||
return 0.3 if n > 1 else 1e-8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyModelWrapper(BaseRagasLLM):
|
||||
"""Bridge Dify model invocation to ragas and fallback LLM-as-judge flows.
|
||||
|
||||
ragas can accept a custom ``BaseRagasLLM`` instance. Using one here keeps
|
||||
evaluation requests on Dify's provider stack instead of falling back to
|
||||
ragas' default OpenAI factory, which would require standalone environment
|
||||
credentials and bypass tenant-scoped model configuration.
|
||||
"""
|
||||
|
||||
model_provider: str
|
||||
model_name: str
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
|
||||
def __init__(self, model_provider: str, model_name: str, tenant_id: str, user_id: str | None = None):
|
||||
super().__init__()
|
||||
self.model_provider = model_provider
|
||||
self.model_name = model_name
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
|
||||
def _get_model_instance(self) -> Any:
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_manager
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
model_manager = create_plugin_model_manager(tenant_id=self.tenant_id, user_id=self.user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.model_provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=self.model_name,
|
||||
)
|
||||
return model_instance
|
||||
|
||||
def invoke(self, prompt: str) -> str:
|
||||
"""Invoke the configured Dify model with a plain-text evaluation prompt."""
|
||||
from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
model_instance = self._get_model_instance()
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
|
||||
UserPromptMessage(content=prompt),
|
||||
],
|
||||
model_parameters={"temperature": 0.0},
|
||||
stream=False,
|
||||
)
|
||||
return result.message.content
|
||||
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: PromptValue,
|
||||
n: int = 1,
|
||||
temperature: float = 1e-8,
|
||||
stop: list[str] | None = None,
|
||||
callbacks: Any = None,
|
||||
) -> LLMResult:
|
||||
"""Implement ragas' sync LLM interface on top of Dify's model runtime."""
|
||||
del callbacks # Dify's invocation path does not currently use LangChain callbacks here.
|
||||
prompt_messages = _convert_prompt_value(prompt)
|
||||
model_instance = self._get_model_instance()
|
||||
|
||||
generations: list[list[ChatGeneration]] = [[]]
|
||||
completions = max(1, n)
|
||||
for _ in range(completions):
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={"temperature": temperature},
|
||||
stop=stop,
|
||||
stream=False,
|
||||
)
|
||||
text = result.message.content
|
||||
generations[0].append(
|
||||
ChatGeneration(
|
||||
text=text,
|
||||
message=AIMessage(content=text, response_metadata={"finish_reason": "stop"}),
|
||||
generation_info={"finish_reason": "stop"},
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def agenerate_text(
|
||||
self,
|
||||
prompt: PromptValue,
|
||||
n: int = 1,
|
||||
temperature: float | None = None,
|
||||
stop: list[str] | None = None,
|
||||
callbacks: Any = None,
|
||||
) -> LLMResult:
|
||||
"""Async ragas hook backed by the sync Dify invocation path."""
|
||||
return await asyncio.to_thread(
|
||||
self.generate_text,
|
||||
prompt,
|
||||
n,
|
||||
self.get_temperature(n) if temperature is None else temperature,
|
||||
stop,
|
||||
callbacks,
|
||||
)
|
||||
|
||||
|
||||
def _convert_prompt_value(prompt: PromptValue) -> list[Any]:
|
||||
"""Translate LangChain prompt values into graphon prompt messages."""
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
prompt_messages: list[Any] = []
|
||||
for message in prompt.to_messages():
|
||||
content = _message_content_to_text(message)
|
||||
if isinstance(message, SystemMessage):
|
||||
prompt_messages.append(SystemPromptMessage(content=content))
|
||||
elif isinstance(message, AIMessage):
|
||||
prompt_messages.append(AssistantPromptMessage(content=content))
|
||||
elif isinstance(message, HumanMessage):
|
||||
prompt_messages.append(UserPromptMessage(content=content))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=content))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _message_content_to_text(message: BaseMessage) -> str:
|
||||
"""Flatten LangChain message content into a plain-text string for Dify."""
|
||||
if isinstance(message.content, str):
|
||||
return message.content
|
||||
|
||||
if isinstance(message.content, list):
|
||||
parts: list[str] = []
|
||||
for block in message.content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, dict):
|
||||
text = block.get("text")
|
||||
if text:
|
||||
parts.append(str(text))
|
||||
return "\n".join(part for part in parts if part)
|
||||
|
||||
return str(message.content or "")
|
||||
@ -1,160 +0,0 @@
|
||||
"""Judgment condition processor for evaluation metrics.
|
||||
|
||||
Evaluates pass/fail judgment conditions against evaluation metric values.
|
||||
Each condition uses ``variable_selector`` (``[node_id, metric_name]``) to
|
||||
look up the metric value, then delegates the actual comparison to the
|
||||
workflow condition engine (``graphon.utils.condition.processor``).
|
||||
|
||||
The processor is intentionally decoupled from evaluation frameworks and
|
||||
runners. It operates on plain ``dict`` mappings and can be invoked
|
||||
anywhere that already has per-item metric results.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.evaluation.entities.judgment_entity import (
|
||||
JudgmentCondition,
|
||||
JudgmentConditionResult,
|
||||
JudgmentConfig,
|
||||
JudgmentResult,
|
||||
)
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"})
|
||||
|
||||
|
||||
class JudgmentProcessor:
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
metric_values: dict[tuple[str, str], Any],
|
||||
config: JudgmentConfig,
|
||||
) -> JudgmentResult:
|
||||
"""Evaluate all judgment conditions against the given metric values.
|
||||
|
||||
Args:
|
||||
metric_values: Mapping of ``(node_id, metric_name)`` → metric
|
||||
value (e.g. ``{("node_abc", "faithfulness"): 0.85}``).
|
||||
config: The judgment configuration with logical_operator and
|
||||
conditions.
|
||||
|
||||
Returns:
|
||||
JudgmentResult with overall pass/fail and per-condition details.
|
||||
"""
|
||||
if not config.conditions:
|
||||
return JudgmentResult(
|
||||
passed=True,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=[],
|
||||
)
|
||||
|
||||
condition_results: list[JudgmentConditionResult] = []
|
||||
|
||||
for condition in config.conditions:
|
||||
result = JudgmentProcessor._evaluate_single_condition(metric_values, condition)
|
||||
condition_results.append(result)
|
||||
|
||||
if config.logical_operator == "and" and not result.passed:
|
||||
return JudgmentResult(
|
||||
passed=False,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
if config.logical_operator == "or" and result.passed:
|
||||
return JudgmentResult(
|
||||
passed=True,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
|
||||
if config.logical_operator == "and":
|
||||
final_passed = all(r.passed for r in condition_results)
|
||||
else:
|
||||
final_passed = any(r.passed for r in condition_results)
|
||||
|
||||
return JudgmentResult(
|
||||
passed=final_passed,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_single_condition(
|
||||
metric_values: dict[tuple[str, str], Any],
|
||||
condition: JudgmentCondition,
|
||||
) -> JudgmentConditionResult:
|
||||
"""Evaluate a single judgment condition.
|
||||
|
||||
Steps:
|
||||
1. Extract ``(node_id, metric_name)`` from ``variable_selector``.
|
||||
2. Look up the metric value from ``metric_values``.
|
||||
3. Delegate comparison to the workflow condition engine.
|
||||
"""
|
||||
selector = condition.variable_selector
|
||||
if len(selector) < 2:
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=None,
|
||||
passed=False,
|
||||
error=f"variable_selector must have at least 2 elements, got {len(selector)}",
|
||||
)
|
||||
|
||||
node_id, metric_name = selector[0], selector[1]
|
||||
actual_value = metric_values.get((node_id, metric_name))
|
||||
|
||||
if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS:
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=None,
|
||||
passed=False,
|
||||
error=f"Metric '{metric_name}' on node '{node_id}' not found in evaluation results",
|
||||
)
|
||||
|
||||
try:
|
||||
expected = condition.value
|
||||
# Numeric operators need the actual value coerced to int/float
|
||||
# so that the workflow engine's numeric assertions work correctly.
|
||||
coerced_actual: object = actual_value
|
||||
if (
|
||||
condition.comparison_operator in {"=", "≠", ">", "<", "≥", "≤"}
|
||||
and actual_value is not None
|
||||
and not isinstance(actual_value, (int, float, bool))
|
||||
):
|
||||
coerced_actual = float(actual_value)
|
||||
|
||||
passed = _evaluate_condition(
|
||||
operator=cast(SupportedComparisonOperator, condition.comparison_operator),
|
||||
value=coerced_actual,
|
||||
expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected),
|
||||
)
|
||||
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=expected,
|
||||
actual_value=actual_value,
|
||||
passed=passed,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Judgment condition evaluation failed for [%s, %s]: %s",
|
||||
node_id,
|
||||
metric_name,
|
||||
str(e),
|
||||
)
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=actual_value,
|
||||
passed=False,
|
||||
error=str(e),
|
||||
)
|
||||
@ -1,52 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, App, CustomizedSnippet, TenantAccountJoin
|
||||
|
||||
|
||||
def get_service_account_for_app(session: Session, app_id: str) -> Account:
|
||||
"""Get the creator account for an app with tenant context set up.
|
||||
|
||||
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
|
||||
"""
|
||||
app = session.scalar(select(App).where(App.id == app_id))
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator")
|
||||
|
||||
account = session.scalar(select(Account).where(Account.id == app.created_by))
|
||||
if not account:
|
||||
raise ValueError(f"Creator account not found for app {app_id}")
|
||||
|
||||
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {account.id}")
|
||||
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
return account
|
||||
|
||||
|
||||
def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
|
||||
"""Get the creator account for a snippet with tenant context set up.
|
||||
|
||||
Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.
|
||||
"""
|
||||
snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
|
||||
if not snippet:
|
||||
raise ValueError(f"Snippet with id {snippet_id} not found")
|
||||
|
||||
if not snippet.created_by:
|
||||
raise ValueError(f"Snippet with id {snippet_id} has no creator")
|
||||
|
||||
account = session.scalar(select(Account).where(Account.id == snippet.created_by))
|
||||
if not account:
|
||||
raise ValueError(f"Creator account not found for snippet {snippet_id}")
|
||||
|
||||
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {account.id}")
|
||||
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
return account
|
||||
@ -1,66 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationDatasetInput,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
NodeInfo,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for agent evaluation: collects tool calls and final output."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
dataset_items: list[EvaluationDatasetInput] | None = None,
|
||||
node_info: NodeInfo | None = None,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute agent evaluation metrics."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list)
|
||||
return self.evaluation_instance.evaluate_agent(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
|
||||
"""Create EvaluationItemInput list from NodeRunResult for agent evaluation."""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
output = _extract_agent_output(item.outputs)
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs=dict(item.inputs),
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _extract_agent_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the primary output text from agent NodeRunResult.outputs."""
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -1,55 +0,0 @@
|
||||
"""Base evaluation runner.
|
||||
|
||||
Provides the abstract interface for metric computation. Each concrete runner
|
||||
(LLM, Retrieval, Agent, Workflow, Snippet) implements ``evaluate_metrics``
|
||||
to compute scores for a specific node type.
|
||||
|
||||
Orchestration (merging results from multiple runners, applying judgment, and
|
||||
persisting to the database) is handled by the evaluation task, not the runner.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationDatasetInput,
|
||||
NodeInfo,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEvaluationRunner(ABC):
|
||||
"""Abstract base class for evaluation runners.
|
||||
|
||||
Runners are stateless metric calculators: they receive node execution
|
||||
results and a metric specification, then return scored results. They
|
||||
do **not** touch the database or apply judgment logic.
|
||||
"""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
self.evaluation_instance = evaluation_instance
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
dataset_items: list[EvaluationDatasetInput] | None = None,
|
||||
node_info: NodeInfo | None = None,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute evaluation metrics on the collected results.
|
||||
|
||||
The returned ``EvaluationItemResult.index`` values are positional
|
||||
(0-based) and correspond to the order of *node_run_result_list*.
|
||||
The caller is responsible for mapping them back to the original
|
||||
dataset indices.
|
||||
"""
|
||||
...
|
||||
@ -1,107 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationDatasetInput,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
NodeInfo,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for LLM evaluation: extracts prompts/outputs then evaluates."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
dataset_items: list[EvaluationDatasetInput] | None = None,
|
||||
node_info: NodeInfo | None = None,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Use the evaluation instance to compute LLM metrics."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list, dataset_items, node_info)
|
||||
return self.evaluation_instance.evaluate_llm(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(
|
||||
items: list[NodeRunResult],
|
||||
dataset_items: list[EvaluationDatasetInput] | None = None,
|
||||
node_info: NodeInfo | None = None,
|
||||
) -> list[EvaluationItemInput]:
|
||||
"""Create new items from NodeRunResult for ragas evaluation.
|
||||
|
||||
Extracts prompts from process_data and concatenates them into a single
|
||||
string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ...").
|
||||
The last assistant message in outputs is used as the actual output.
|
||||
"""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
prompts = item.process_data.get("prompts", [])
|
||||
prompt = _format_prompts(prompts)
|
||||
output = _extract_llm_output(item.outputs)
|
||||
dataset_item = dataset_items[i] if dataset_items and i < len(dataset_items) else None
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs={"prompt": prompt},
|
||||
output=output,
|
||||
expected_output=dataset_item.get_expected_output_for_node(node_info.title) if dataset_item else None,
|
||||
context=_extract_context_blocks(prompts),
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _format_prompts(prompts: list[dict[str, Any]]) -> str:
|
||||
"""Concatenate a list of prompt messages into a single string for evaluation.
|
||||
|
||||
Each message is formatted as "role: text" and joined with newlines.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for msg in prompts:
|
||||
role = msg.get("role", "unknown")
|
||||
text = msg.get("text", "")
|
||||
parts.append(f"{role}: {text}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _extract_llm_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the LLM output text from NodeRunResult.outputs."""
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
|
||||
|
||||
def _extract_context_blocks(prompts: list[dict[str, Any]]) -> list[str] | None:
|
||||
"""Extract tagged context blocks from rendered prompts.
|
||||
|
||||
Evaluation only treats prompt content wrapped in ``<context>...</context>``
|
||||
as retrieved evidence. This keeps faithfulness-style metrics opt-in and
|
||||
avoids guessing which arbitrary prompt text should be considered context.
|
||||
"""
|
||||
prompt_text = "\n".join(str(prompt.get("text", "")) for prompt in prompts)
|
||||
matches = re.findall(r"<context>(.*?)</context>", prompt_text, flags=re.DOTALL)
|
||||
contexts = [match.strip() for match in matches if match.strip()]
|
||||
return contexts or None
|
||||
@ -1,67 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationDatasetInput,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
NodeInfo,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
dataset_items: list[EvaluationDatasetInput] | None = None,
|
||||
node_info: NodeInfo | None = None,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute retrieval evaluation metrics."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
|
||||
merged_items = []
|
||||
for i, node_result in enumerate(node_run_result_list):
|
||||
outputs = node_result.outputs
|
||||
query = self._extract_query(dict(node_result.inputs))
|
||||
result_list = outputs.get("result", [])
|
||||
contexts = [item.get("content", "") for item in result_list if item.get("content")]
|
||||
output = "\n---\n".join(contexts)
|
||||
dataset_item = dataset_items[i] if dataset_items and i < len(dataset_items) else None
|
||||
|
||||
merged_items.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs={"query": query},
|
||||
output=output,
|
||||
expected_output=dataset_item.get_expected_output_for_node(node_info.title) if dataset_item else None,
|
||||
context=contexts,
|
||||
)
|
||||
)
|
||||
|
||||
return self.evaluation_instance.evaluate_retrieval(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_query(inputs: dict[str, Any]) -> str:
|
||||
for key in ("query", "question", "input", "text"):
|
||||
if key in inputs:
|
||||
return str(inputs[key])
|
||||
values = list(inputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -1,72 +0,0 @@
|
||||
"""Runner for Snippet evaluation.
|
||||
|
||||
Snippets are essentially workflows, so we reuse ``evaluate_workflow`` from
|
||||
the evaluation instance for metric computation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationDatasetInput,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
NodeInfo,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SnippetEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for snippet evaluation: evaluates a published Snippet workflow."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
dataset_items: list[EvaluationDatasetInput] | None = None,
|
||||
node_info: NodeInfo | None = None,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute evaluation metrics for snippet outputs."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list)
|
||||
return self.evaluation_instance.evaluate_workflow(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
|
||||
"""Create EvaluationItemInput list from NodeRunResult for snippet evaluation."""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
output = _extract_snippet_output(item.outputs)
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs=dict(item.inputs),
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _extract_snippet_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the primary output text from snippet NodeRunResult.outputs."""
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -1,66 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationDatasetInput,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
NodeInfo,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
dataset_items: list[EvaluationDatasetInput] | None = None,
|
||||
node_info: NodeInfo | None = None,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute workflow evaluation metrics (end-to-end)."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list)
|
||||
return self.evaluation_instance.evaluate_workflow(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
|
||||
"""Create EvaluationItemInput list from NodeRunResult for workflow evaluation."""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
output = _extract_workflow_output(item.outputs)
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs=dict(item.inputs),
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _extract_workflow_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the primary output text from workflow NodeRunResult.outputs."""
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -2,6 +2,7 @@
|
||||
Credential utility functions for checking credential existence and policy compliance.
|
||||
"""
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
|
||||
|
||||
@ -39,6 +40,16 @@ def is_credential_exists(credential_id: str, credential_type: "PluginCredentialT
|
||||
return False
|
||||
|
||||
|
||||
def runtime_check_credential_policy_compliance(
|
||||
credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
|
||||
):
|
||||
if dify_config.ENTERPRISE_DISABLE_RUNTIME_CREDENTIAL_CHECK:
|
||||
return
|
||||
check_credential_policy_compliance(
|
||||
credential_id=credential_id, provider=provider, credential_type=credential_type, check_existence=check_existence
|
||||
)
|
||||
|
||||
|
||||
def check_credential_policy_compliance(
|
||||
credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
|
||||
) -> None:
|
||||
|
||||
@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
|
||||
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.llm_generator.prompts import (
|
||||
CONVERSATION_TITLE_PROMPT,
|
||||
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
GENERATOR_QA_PROMPT,
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
@ -217,8 +215,8 @@ class LLMGenerator:
|
||||
else:
|
||||
# Default-model generation keeps the built-in suggested-questions tuning.
|
||||
model_parameters = {
|
||||
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
"max_tokens": 2560,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
stop = []
|
||||
|
||||
@ -437,7 +435,7 @@ class LLMGenerator:
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Runtime type check since pyright has issues with the overload
|
||||
# Runtime type check for overload narrowing.
|
||||
if not isinstance(result, LLMResult):
|
||||
raise TypeError("Expected LLMResult when stream=False")
|
||||
response = result
|
||||
|
||||
@ -104,10 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
'["question1","question2","question3"]\n'
|
||||
)
|
||||
|
||||
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
|
||||
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
|
||||
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
|
||||
" in the long text. Please think step by step."
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, model_validator
|
||||
from pydantic.networks import AnyUrl, UrlConstraints
|
||||
|
||||
"""
|
||||
@ -173,7 +173,21 @@ class JSONRPCError(BaseModel):
|
||||
|
||||
|
||||
class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]):
|
||||
pass
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _select_message_type(
|
||||
cls, value: Any
|
||||
) -> JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError | Any:
|
||||
if isinstance(value, dict):
|
||||
if "result" in value:
|
||||
return JSONRPCResponse.model_validate(value)
|
||||
if "error" in value:
|
||||
return JSONRPCError.model_validate(value)
|
||||
if "method" in value:
|
||||
if "id" in value:
|
||||
return JSONRPCRequest.model_validate(value)
|
||||
return JSONRPCNotification.model_validate(value)
|
||||
return value
|
||||
|
||||
|
||||
class EmptyResult(Result):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user