mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 12:09:27 +08:00
Compare commits
211 Commits
purity
...
feat/trigg
| Author | SHA1 | Date | |
|---|---|---|---|
| 2dbf76299b | |||
| c5de91ba94 | |||
| bc1e6e011b | |||
| 906028b1fb | |||
| 034602969f | |||
| 4ca14bfdad | |||
| 59f56d8c94 | |||
| e2ebb9ea38 | |||
| d4191d1f2f | |||
| 63d26f0478 | |||
| 13fe2ca8fe | |||
| eae65e55ce | |||
| 0edf06329f | |||
| 6943a379c9 | |||
| e49534b70c | |||
| 344616ca2f | |||
| 1264e7d4f6 | |||
| 4f45978cd9 | |||
| 5a0bf8e028 | |||
| ffa163a8a8 | |||
| 8f86f5749d | |||
| 00d3bf15f3 | |||
| 7196c09e9d | |||
| fadd9e0bf4 | |||
| d8b4bbe067 | |||
| 24611e375a | |||
| ccec582cea | |||
| b2e4107c17 | |||
| 87aa070486 | |||
| 0e287a9c93 | |||
| 21230a8eb2 | |||
| 85cda47c70 | |||
| 8141f53af5 | |||
| 5a6cb0d887 | |||
| 7dadb33003 | |||
| 26e7677595 | |||
| 814b0e1fe8 | |||
| a173dc5c9d | |||
| a567facf2b | |||
| e76d80defe | |||
| b5a7e64e19 | |||
| 4a17025467 | |||
| bd1fcd3525 | |||
| b283b10d3e | |||
| 0cb0cea167 | |||
| ecb22226d6 | |||
| 8635aacb46 | |||
| bdd85b36a4 | |||
| a0c7713494 | |||
| ee68a685a7 | |||
| abf4955c26 | |||
| 74340e3c04 | |||
| b98b389baf | |||
| c78bd492af | |||
| 6857bb4406 | |||
| dcf3ee6982 | |||
| 76850749e4 | |||
| 91e5e33440 | |||
| 11e55088c9 | |||
| 57c0bc9fb6 | |||
| c3ebb22a4b | |||
| 1562d00037 | |||
| e9e843b27d | |||
| ec33b9908e | |||
| 67004368d9 | |||
| 50bff270b6 | |||
| bd5cf1c272 | |||
| d22404994a | |||
| 9898730cc5 | |||
| b0f1e55a87 | |||
| 6566824807 | |||
| 9249a2af0d | |||
| 112fc3b1d1 | |||
| 37299b3bd7 | |||
| 8f65ce995a | |||
| 4a743e6dc1 | |||
| 07dda61929 | |||
| 0d8438ef40 | |||
| 96bb638969 | |||
| e74962272e | |||
| 5a15419baf | |||
| e8403977b9 | |||
| add2ca85f2 | |||
| fbb7b02e90 | |||
| 249b62c9de | |||
| b433322e8d | |||
| 1c8850fc95 | |||
| dc16f1b65a | |||
| ff30395dc1 | |||
| 8e600f3302 | |||
| 5a1e0a8379 | |||
| 2a3ce6baa9 | |||
| 01b2f9cff6 | |||
| ac38614171 | |||
| eb95c5cd07 | |||
| a799b54b9e | |||
| 98ba0236e6 | |||
| b6c552df07 | |||
| e2827e475d | |||
| 58cbd337b5 | |||
| a91e59d544 | |||
| 814787677a | |||
| 85caa5bd0c | |||
| e04083fc0e | |||
| cf532e5e0d | |||
| c097fc2c48 | |||
| 0371d71409 | |||
| 81ef7343d4 | |||
| 8e4b59c90c | |||
| 68f73410fc | |||
| 88af8ed374 | |||
| 015f82878e | |||
| 3874e58dc2 | |||
| 9f8c159583 | |||
| d8f6f9ce19 | |||
| eab03e63d4 | |||
| 461829274a | |||
| e751c0c535 | |||
| 1fffc79c32 | |||
| 83fab4bc19 | |||
| f60e28d2f5 | |||
| a62d7aa3ee | |||
| cc84a45244 | |||
| 5cf3d24018 | |||
| 4bdbe617fe | |||
| 33c867fd8c | |||
| 2013ceb9d2 | |||
| 7120c6414c | |||
| 5ce7b2d98d | |||
| cb82198271 | |||
| 5e5ffaa416 | |||
| 4b253e1f73 | |||
| dd929dbf0e | |||
| 97a9d34e96 | |||
| 602070ec9c | |||
| afd8989150 | |||
| 694197a701 | |||
| 2f08306695 | |||
| 6acc77d86d | |||
| 5ddd5e49ee | |||
| 72f9e77368 | |||
| a46c9238fa | |||
| 87120ad4ac | |||
| 7544b5ec9a | |||
| ff4a62d1e7 | |||
| 41daa51988 | |||
| d522350c99 | |||
| 1d1bb9451e | |||
| 1fce1a61d4 | |||
| 883a6caf96 | |||
| a239c39f09 | |||
| e925a8ab99 | |||
| bccaf939e6 | |||
| 676648e0b3 | |||
| 4ae19e6dde | |||
| 4d0ff5c281 | |||
| 327b354cc2 | |||
| 6d307cc9fc | |||
| adc7134af5 | |||
| 10f19cd0c2 | |||
| 9ed45594c6 | |||
| c138f4c3a6 | |||
| a35be05790 | |||
| 60b5ed8e5d | |||
| d8ddbc4d87 | |||
| 19c0fc85e2 | |||
| a58df35ead | |||
| 9789bd02d8 | |||
| d94e54923f | |||
| 64c7be59b7 | |||
| 89ad6ad902 | |||
| 4f73bc9693 | |||
| add6b79231 | |||
| c90dad566f | |||
| 5cbe6bf8f8 | |||
| 4ef6ff217e | |||
| 87abfbf515 | |||
| 73e65fd838 | |||
| e53edb0fc2 | |||
| 17908fbf6b | |||
| 3dae108f84 | |||
| 5bbf685035 | |||
| a63d1e87b1 | |||
| 7129de98cd | |||
| 2984dbc0df | |||
| 392db7f611 | |||
| 5a427b8daa | |||
| 18f2e6f166 | |||
| e78903302f | |||
| 4084ade86c | |||
| 6b0d919dbd | |||
| a7b558b38b | |||
| 6aed7e3ff4 | |||
| 8e93a8a2e2 | |||
| e38a86e37b | |||
| 392e3530bf | |||
| 833c902b2b | |||
| 6eaea64b3f | |||
| 5303b50737 | |||
| 6acbcfe679 | |||
| 16ef5ebb97 | |||
| acfb95f9c2 | |||
| aacea166d7 | |||
| f7bb3b852a | |||
| d4ff1e031a | |||
| 6a3d135d49 | |||
| 5c4bf7aabd | |||
| e9c7dc7464 | |||
| 74ad21b145 | |||
| f214eeb7b1 | |||
| ae25f90f34 |
@ -1,4 +1,4 @@
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
npm add -g pnpm@10.15.0
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
4
.github/workflows/autofix.yml
vendored
4
.github/workflows/autofix.yml
vendored
@ -2,6 +2,8 @@ name: autofix.ci
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@ -22,7 +24,7 @@ jobs:
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
uv run ruff format ..
|
||||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
|
||||
2
.github/workflows/build-push.yml
vendored
2
.github/workflows/build-push.yml
vendored
@ -8,6 +8,8 @@ on:
|
||||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "deploy/rag-dev"
|
||||
- "feat/rag-2"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
7
.github/workflows/deploy-dev.yml
vendored
7
.github/workflows/deploy-dev.yml
vendored
@ -4,7 +4,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/dev"
|
||||
- "deploy/rag-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -12,12 +12,13 @@ jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
|
||||
6
.github/workflows/style.yml
vendored
6
.github/workflows/style.yml
vendored
@ -12,7 +12,6 @@ permissions:
|
||||
statuses: write
|
||||
contents: read
|
||||
|
||||
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
@ -44,6 +43,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Import Linter
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --directory api --dev lint-imports
|
||||
|
||||
- name: Run Basedpyright Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: dev/basedpyright-check
|
||||
@ -99,7 +102,6 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpm run lint
|
||||
pnpm run eslint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -230,4 +230,6 @@ api/.env.backup
|
||||
|
||||
# Benchmark
|
||||
scripts/stress-test/setup/config/
|
||||
scripts/stress-test/reports/
|
||||
scripts/stress-test/reports/
|
||||
# mcp
|
||||
.serena
|
||||
|
||||
88
AGENTS.md
Normal file
88
AGENTS.md
Normal file
@ -0,0 +1,88 @@
|
||||
# AGENTS.md
|
||||
|
||||
## Project Overview
|
||||
|
||||
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
||||
|
||||
The codebase consists of:
|
||||
|
||||
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
|
||||
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Backend (API)
|
||||
|
||||
All Python commands must be prefixed with `uv run --project api`:
|
||||
|
||||
```bash
|
||||
# Start development servers
|
||||
./dev/start-api # Start API server
|
||||
./dev/start-worker # Start Celery worker
|
||||
|
||||
# Run tests
|
||||
uv run --project api pytest # Run all tests
|
||||
uv run --project api pytest tests/unit_tests/ # Unit tests only
|
||||
uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --directory api basedpyright # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint # Run ESLint
|
||||
pnpm eslint-fix # Fix ESLint issues
|
||||
pnpm test # Run Jest tests
|
||||
```
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
### Backend Testing
|
||||
|
||||
- Use `pytest` for all backend tests
|
||||
- Write tests first (TDD approach)
|
||||
- Test structure: Arrange-Act-Assert
|
||||
|
||||
## Code Style Requirements
|
||||
|
||||
### Python
|
||||
|
||||
- Use type hints for all functions and class attributes
|
||||
- No `Any` types unless absolutely necessary
|
||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
- Strict TypeScript configuration
|
||||
- ESLint with Prettier integration
|
||||
- Avoid `any` type
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
|
||||
- **Comments**: Only write meaningful comments that explain "why", not "what"
|
||||
- **File Creation**: Always prefer editing existing files over creating new ones
|
||||
- **Documentation**: Don't create documentation files unless explicitly requested
|
||||
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. Create controller in `/api/controllers/`
|
||||
1. Add service logic in `/api/services/`
|
||||
1. Update routes in controller's `__init__.py`
|
||||
1. Write tests in `/api/tests/`
|
||||
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
|
||||
89
CLAUDE.md
89
CLAUDE.md
@ -1,89 +0,0 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
||||
|
||||
The codebase consists of:
|
||||
|
||||
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
|
||||
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Backend (API)
|
||||
|
||||
All Python commands must be prefixed with `uv run --project api`:
|
||||
|
||||
```bash
|
||||
# Start development servers
|
||||
./dev/start-api # Start API server
|
||||
./dev/start-worker # Start Celery worker
|
||||
|
||||
# Run tests
|
||||
uv run --project api pytest # Run all tests
|
||||
uv run --project api pytest tests/unit_tests/ # Unit tests only
|
||||
uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --directory api basedpyright # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint # Run ESLint
|
||||
pnpm eslint-fix # Fix ESLint issues
|
||||
pnpm test # Run Jest tests
|
||||
```
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
### Backend Testing
|
||||
|
||||
- Use `pytest` for all backend tests
|
||||
- Write tests first (TDD approach)
|
||||
- Test structure: Arrange-Act-Assert
|
||||
|
||||
## Code Style Requirements
|
||||
|
||||
### Python
|
||||
|
||||
- Use type hints for all functions and class attributes
|
||||
- No `Any` types unless absolutely necessary
|
||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
- Strict TypeScript configuration
|
||||
- ESLint with Prettier integration
|
||||
- Avoid `any` type
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
|
||||
- **Comments**: Only write meaningful comments that explain "why", not "what"
|
||||
- **File Creation**: Always prefer editing existing files over creating new ones
|
||||
- **Documentation**: Don't create documentation files unless explicitly requested
|
||||
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. Create controller in `/api/controllers/`
|
||||
1. Add service logic in `/api/services/`
|
||||
1. Update routes in controller's `__init__.py`
|
||||
1. Write tests in `/api/tests/`
|
||||
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
@ -76,6 +76,7 @@ DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
@ -328,7 +329,7 @@ MATRIXONE_DATABASE=dify
|
||||
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
|
||||
LINDORM_USERNAME=admin
|
||||
LINDORM_PASSWORD=admin
|
||||
USING_UGC_INDEX=False
|
||||
LINDORM_USING_UGC=True
|
||||
LINDORM_QUERY_TIMEOUT=1
|
||||
|
||||
# OceanBase Vector configuration
|
||||
@ -435,6 +436,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||
|
||||
# Webhook request configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
@ -461,6 +465,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
# Seconds of idle time before scaling down workers (default: 5.0)
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
# Workflow storage configuration
|
||||
# Options: rdbms, hybrid
|
||||
# rdbms: Use only the relational database (default)
|
||||
@ -503,6 +517,12 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
|
||||
105
api/.importlinter
Normal file
105
api/.importlinter
Normal file
@ -0,0 +1,105 @@
|
||||
[importlinter]
|
||||
root_packages =
|
||||
core
|
||||
configs
|
||||
controllers
|
||||
models
|
||||
tasks
|
||||
services
|
||||
|
||||
[importlinter:contract:workflow]
|
||||
name = Workflow
|
||||
type=layers
|
||||
layers =
|
||||
graph_engine
|
||||
graph_events
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
ignore_imports =
|
||||
core.workflow.nodes.base.node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||
|
||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
response_coordinator
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:worker]
|
||||
name = Worker
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
worker
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:graph-engine-architecture]
|
||||
name = Graph Engine Architecture
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
orchestration
|
||||
command_processing
|
||||
event_management
|
||||
error_handler
|
||||
graph_traversal
|
||||
graph_state_manager
|
||||
worker_management
|
||||
domain
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:domain-isolation]
|
||||
name = Domain Model Isolation
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.domain
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
core.workflow.graph_engine.command_channels
|
||||
core.workflow.graph_engine.layers
|
||||
core.workflow.graph_engine.protocols
|
||||
|
||||
[importlinter:contract:worker-management]
|
||||
name = Worker Management
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.orchestration
|
||||
core.workflow.graph_engine.command_processing
|
||||
core.workflow.graph_engine.event_management
|
||||
|
||||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = layers
|
||||
layers =
|
||||
edge_processor
|
||||
skip_propagator
|
||||
containers =
|
||||
core.workflow.graph_engine.graph_traversal
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.command_channels.in_memory_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
@ -5,7 +5,7 @@ line-length = 120
|
||||
quote-style = "double"
|
||||
|
||||
[lint]
|
||||
preview = false
|
||||
preview = true
|
||||
select = [
|
||||
"B", # flake8-bugbear rules
|
||||
"C4", # flake8-comprehensions
|
||||
@ -45,6 +45,7 @@ select = [
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
"UP042", # use StrEnum
|
||||
]
|
||||
|
||||
ignore = [
|
||||
@ -64,6 +65,7 @@ ignore = [
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
"B901", # allow return in yield
|
||||
"B903", # class-as-data-structure
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
"dataset,generation,mail,ops_trace,app_deletion,workflow"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
27
api/app.py
27
api/app.py
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@ -17,20 +16,20 @@ else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
from gevent import monkey
|
||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
# from gevent import monkey
|
||||
#
|
||||
# # gevent
|
||||
# monkey.patch_all()
|
||||
#
|
||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
#
|
||||
# # grpc gevent
|
||||
# grpc_gevent.init_gevent()
|
||||
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
# import psycogreen.gevent # type: ignore
|
||||
#
|
||||
# psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
|
||||
22
api/celery_entrypoint.py
Normal file
22
api/celery_entrypoint.py
Normal file
@ -0,0 +1,22 @@
|
||||
import logging
|
||||
|
||||
import psycogreen.gevent as pscycogreen_gevent # type: ignore
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log(message: str):
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
_log("gRPC patched with gevent.")
|
||||
pscycogreen_gevent.patch_psycopg()
|
||||
_log("psycopg2 patched with gevent.")
|
||||
|
||||
|
||||
from app import app, celery
|
||||
|
||||
__all__ = ["app", "celery"]
|
||||
516
api/commands.py
516
api/commands.py
@ -13,7 +13,9 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
@ -23,19 +25,25 @@ from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -953,7 +961,7 @@ def clear_orphaned_file_records(force: bool):
|
||||
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
|
||||
query = "DELETE FROM message_files WHERE id IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
|
||||
conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
|
||||
click.echo(
|
||||
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
|
||||
)
|
||||
@ -1219,6 +1227,55 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
@ -1245,15 +1302,17 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
|
||||
def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
Count orphaned draft variables by app.
|
||||
Count orphaned draft variables by app, including associated file counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about orphaned variables
|
||||
Dictionary with statistics about orphaned variables and files
|
||||
"""
|
||||
query = """
|
||||
# Count orphaned variables by app
|
||||
variables_query = """
|
||||
SELECT
|
||||
wdv.app_id,
|
||||
COUNT(*) as variable_count
|
||||
COUNT(*) as variable_count,
|
||||
COUNT(wdv.file_id) as file_count
|
||||
FROM workflow_draft_variables AS wdv
|
||||
WHERE NOT EXISTS(
|
||||
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
|
||||
@ -1263,14 +1322,21 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
with db.engine.connect() as conn:
|
||||
result = conn.execute(sa.text(query))
|
||||
orphaned_by_app = {row[0]: row[1] for row in result}
|
||||
result = conn.execute(sa.text(variables_query))
|
||||
orphaned_by_app = {}
|
||||
total_files = 0
|
||||
|
||||
total_orphaned = sum(orphaned_by_app.values())
|
||||
for row in result:
|
||||
app_id, variable_count, file_count = row
|
||||
orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count}
|
||||
total_files += file_count
|
||||
|
||||
total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values())
|
||||
app_count = len(orphaned_by_app)
|
||||
|
||||
return {
|
||||
"total_orphaned_variables": total_orphaned,
|
||||
"total_orphaned_files": total_files,
|
||||
"orphaned_app_count": app_count,
|
||||
"orphaned_by_app": orphaned_by_app,
|
||||
}
|
||||
@ -1299,6 +1365,7 @@ def cleanup_orphaned_draft_variables(
|
||||
stats = _count_orphaned_draft_variables()
|
||||
|
||||
logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
|
||||
logger.info("Found %s associated offload files", stats["total_orphaned_files"])
|
||||
logger.info("Across %s non-existent apps", stats["orphaned_app_count"])
|
||||
|
||||
if stats["total_orphaned_variables"] == 0:
|
||||
@ -1307,10 +1374,10 @@ def cleanup_orphaned_draft_variables(
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN: Would delete the following:")
|
||||
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
|
||||
for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[
|
||||
:10
|
||||
]: # Show top 10
|
||||
logger.info(" App %s: %s variables", app_id, count)
|
||||
logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"])
|
||||
if len(stats["orphaned_by_app"]) > 10:
|
||||
logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
|
||||
return
|
||||
@ -1319,7 +1386,8 @@ def cleanup_orphaned_draft_variables(
|
||||
if not force:
|
||||
click.confirm(
|
||||
f"Are you sure you want to delete {stats['total_orphaned_variables']} "
|
||||
f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
|
||||
f"orphaned draft variables and {stats['total_orphaned_files']} associated files "
|
||||
f"from {stats['orphaned_app_count']} apps?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
@ -1352,3 +1420,425 @@ def cleanup_orphaned_draft_variables(
|
||||
continue
|
||||
|
||||
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
|
||||
|
||||
|
||||
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_datasource_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup datasource oauth client
|
||||
"""
|
||||
provider_id = DatasourceProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
|
||||
oauth_client = DatasourceOauthParamConfig(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
system_credentials=client_params_dict,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"provider: {provider_name}", fg="green"))
|
||||
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
|
||||
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
|
||||
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
def transform_datasource_credentials():
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
try:
|
||||
installer_manager = PluginInstaller()
|
||||
plugin_migration = PluginMigration()
|
||||
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
if notion_credentials:
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for notion_credential in notion_credentials:
|
||||
tenant_id = notion_credential.tenant_id
|
||||
if tenant_id not in notion_credentials_tenant_mapping:
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
if firecrawl_credentials:
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for firecrawl_credential in firecrawl_credentials:
|
||||
tenant_id = firecrawl_credential.tenant_id
|
||||
if tenant_id not in firecrawl_credentials_tenant_mapping:
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
|
||||
if jina_credentials:
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for jina_credential in jina_credentials:
|
||||
tenant_id = jina_credential.tenant_id
|
||||
if tenant_id not in jina_credentials_tenant_mapping:
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
print(jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
|
||||
click.echo(
|
||||
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
|
||||
)
|
||||
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_rag_pipeline_plugins(input_file, output_file, workers):
|
||||
"""
|
||||
Install rag pipeline plugins
|
||||
"""
|
||||
click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
|
||||
plugin_migration = PluginMigration()
|
||||
plugin_migration.install_rag_pipeline_plugins(
|
||||
input_file,
|
||||
output_file,
|
||||
workers,
|
||||
)
|
||||
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
"migrate-oss",
|
||||
help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).",
|
||||
)
|
||||
@click.option(
|
||||
"--path",
|
||||
"paths",
|
||||
multiple=True,
|
||||
help="Storage path prefixes to migrate (repeatable). Defaults: privkeys, upload_files, image_files,"
|
||||
" tools, website_files, keyword_files, ops_trace",
|
||||
)
|
||||
@click.option(
|
||||
"--source",
|
||||
type=click.Choice(["local", "opendal"], case_sensitive=False),
|
||||
default="opendal",
|
||||
show_default=True,
|
||||
help="Source storage type to read from",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading")
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts")
|
||||
@click.option(
|
||||
"--update-db/--no-update-db",
|
||||
default=True,
|
||||
help="Update upload_files.storage_type from source type to current storage after migration",
|
||||
)
|
||||
def migrate_oss(
|
||||
paths: tuple[str, ...],
|
||||
source: str,
|
||||
overwrite: bool,
|
||||
dry_run: bool,
|
||||
force: bool,
|
||||
update_db: bool,
|
||||
):
|
||||
"""
|
||||
Copy all files under selected prefixes from a source storage
|
||||
(Local filesystem or OpenDAL-backed) into the currently configured
|
||||
destination storage backend, then optionally update DB records.
|
||||
|
||||
Expected usage: set STORAGE_TYPE (and its credentials) to your target backend.
|
||||
"""
|
||||
# Ensure target storage is not local/opendal
|
||||
if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL):
|
||||
click.echo(
|
||||
click.style(
|
||||
"Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n"
|
||||
"Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n"
|
||||
"volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Default paths if none specified
|
||||
default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace")
|
||||
path_list = list(paths) if paths else list(default_paths)
|
||||
is_source_local = source.lower() == "local"
|
||||
|
||||
click.echo(click.style("Preparing migration to target storage.", fg="yellow"))
|
||||
click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white"))
|
||||
if is_source_local:
|
||||
src_root = dify_config.STORAGE_LOCAL_PATH
|
||||
click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white"))
|
||||
else:
|
||||
click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white"))
|
||||
click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white"))
|
||||
click.echo("")
|
||||
|
||||
if not force:
|
||||
click.confirm("Proceed with migration?", abort=True)
|
||||
|
||||
# Instantiate source storage
|
||||
try:
|
||||
if is_source_local:
|
||||
src_root = dify_config.STORAGE_LOCAL_PATH
|
||||
source_storage = OpenDALStorage(scheme="fs", root=src_root)
|
||||
else:
|
||||
source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
total_files = 0
|
||||
copied_files = 0
|
||||
skipped_files = 0
|
||||
errored_files = 0
|
||||
copied_upload_file_keys: list[str] = []
|
||||
|
||||
for prefix in path_list:
|
||||
click.echo(click.style(f"Scanning source path: {prefix}", fg="white"))
|
||||
try:
|
||||
keys = source_storage.scan(path=prefix, files=True, directories=False)
|
||||
except FileNotFoundError:
|
||||
click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow"))
|
||||
continue
|
||||
except NotImplementedError:
|
||||
click.echo(click.style(" -> Source storage does not support scanning.", fg="red"))
|
||||
return
|
||||
except Exception as e:
|
||||
click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white"))
|
||||
|
||||
for key in keys:
|
||||
total_files += 1
|
||||
|
||||
# check destination existence
|
||||
if not overwrite:
|
||||
try:
|
||||
if storage.exists(key):
|
||||
skipped_files += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
# existence check failures should not block migration attempt
|
||||
# but should be surfaced to user as a warning for visibility
|
||||
click.echo(
|
||||
click.style(
|
||||
f" -> Warning: failed target existence check for {key}: {str(e)}",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
copied_files += 1
|
||||
continue
|
||||
|
||||
# read from source and write to destination
|
||||
try:
|
||||
data = source_storage.load_once(key)
|
||||
except FileNotFoundError:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Missing on source: {key}", fg="yellow"))
|
||||
continue
|
||||
except Exception as e:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
try:
|
||||
storage.save(key, data)
|
||||
copied_files += 1
|
||||
if prefix == "upload_files":
|
||||
copied_upload_file_keys.append(key)
|
||||
except Exception as e:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
click.echo("")
|
||||
click.echo(click.style("Migration summary:", fg="yellow"))
|
||||
click.echo(click.style(f" Total: {total_files}", fg="white"))
|
||||
click.echo(click.style(f" Copied: {copied_files}", fg="green"))
|
||||
click.echo(click.style(f" Skipped: {skipped_files}", fg="white"))
|
||||
if errored_files:
|
||||
click.echo(click.style(f" Errors: {errored_files}", fg="red"))
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry-run complete. No changes were made.", fg="green"))
|
||||
return
|
||||
|
||||
if errored_files:
|
||||
click.echo(
|
||||
click.style(
|
||||
"Some files failed to migrate. Review errors above before updating DB records.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
if update_db and not force:
|
||||
if not click.confirm("Proceed to update DB storage_type despite errors?", default=False):
|
||||
update_db = False
|
||||
|
||||
# Optionally update DB records for upload_files.storage_type (only for successfully copied upload_files)
|
||||
if update_db:
|
||||
if not copied_upload_file_keys:
|
||||
click.echo(click.style("No upload_files copied. Skipping DB storage_type update.", fg="yellow"))
|
||||
else:
|
||||
try:
|
||||
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
|
||||
updated = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from .app_config import DifyConfig
|
||||
|
||||
dify_config = DifyConfig()
|
||||
dify_config = DifyConfig() # type: ignore
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import (
|
||||
@ -153,6 +154,17 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TriggerConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for trigger
|
||||
"""
|
||||
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for webhook request bodies in bytes",
|
||||
default=10485760,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
@ -505,6 +517,22 @@ class UpdateConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class WorkflowVariableTruncationConfig(BaseSettings):
|
||||
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
||||
# 100KB
|
||||
1024_000,
|
||||
description="Maximum size for variable to trigger final truncation.",
|
||||
)
|
||||
WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
|
||||
100000,
|
||||
description="maximum length for string to trigger tuncation, measure in number of characters",
|
||||
)
|
||||
WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
|
||||
1000,
|
||||
description="maximum length for array to trigger truncation.",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for workflow execution
|
||||
@ -535,6 +563,28 @@ class WorkflowConfig(BaseSettings):
|
||||
default=200 * 1024,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
description="Maximum number of workers per GraphEngine instance",
|
||||
default=10,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
|
||||
description="Queue depth threshold that triggers worker scale up",
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
|
||||
description="Seconds of idle time before scaling down workers",
|
||||
default=5.0,
|
||||
ge=0.1,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
"""
|
||||
@ -673,11 +723,35 @@ class ToolConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TemplateMode(StrEnum):
|
||||
# unsafe mode allows flexible operations in templates, but may cause security vulnerabilities
|
||||
UNSAFE = "unsafe"
|
||||
|
||||
# sandbox mode restricts some unsafe operations like accessing __class__.
|
||||
# however, it is still not 100% safe, for example, cpu exploitation can happen.
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
# templating is disabled
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class MailConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for email services
|
||||
"""
|
||||
|
||||
MAIL_TEMPLATING_MODE: TemplateMode = Field(
|
||||
description="Template mode for email services",
|
||||
default=TemplateMode.SANDBOX,
|
||||
)
|
||||
|
||||
MAIL_TEMPLATING_TIMEOUT: int = Field(
|
||||
description="""
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Only available in sandbox mode.""",
|
||||
default=3,
|
||||
)
|
||||
|
||||
MAIL_TYPE: str | None = Field(
|
||||
description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
|
||||
default=None,
|
||||
@ -887,6 +961,22 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||
description="Enable workflow schedule poller task",
|
||||
default=True,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||
description="Workflow schedule poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||
description="Maximum number of schedules to process in each poll batch",
|
||||
default=100,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
@ -1010,6 +1100,7 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
@ -1041,5 +1132,6 @@ class FeatureConfig(
|
||||
CeleryBeatConfig,
|
||||
CeleryScheduleTasksConfig,
|
||||
WorkflowLogConfig,
|
||||
WorkflowVariableTruncationConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@ -220,11 +220,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching pipeline templates
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="remote",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="Domain for fetching remote pipeline templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class HostedServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
HostedAnthropicConfig,
|
||||
HostedAzureOpenAiConfig,
|
||||
HostedFetchAppTemplateConfig,
|
||||
HostedFetchPipelineTemplateConfig,
|
||||
HostedMinmaxConfig,
|
||||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
|
||||
@ -187,6 +187,11 @@ class DatabaseConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_TIMEOUT: NonNegativeInt = Field(
|
||||
description="Number of seconds to wait for a connection from the pool before raising a timeout error.",
|
||||
default=30,
|
||||
)
|
||||
|
||||
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
|
||||
description="Number of processes for the retrieval service, default to CPU cores.",
|
||||
default=os.cpu_count() or 1,
|
||||
@ -216,6 +221,7 @@ class DatabaseConfig(BaseSettings):
|
||||
"connect_args": connect_args,
|
||||
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
|
||||
"pool_reset_on_return": None,
|
||||
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -19,15 +19,15 @@ class LindormConfig(BaseSettings):
|
||||
description="Lindorm password",
|
||||
default=None,
|
||||
)
|
||||
DEFAULT_INDEX_TYPE: str | None = Field(
|
||||
LINDORM_INDEX_TYPE: str | None = Field(
|
||||
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
|
||||
default="hnsw",
|
||||
)
|
||||
DEFAULT_DISTANCE_TYPE: str | None = Field(
|
||||
LINDORM_DISTANCE_TYPE: str | None = Field(
|
||||
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
|
||||
)
|
||||
USING_UGC_INDEX: bool | None = Field(
|
||||
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
|
||||
default=False,
|
||||
LINDORM_USING_UGC: bool | None = Field(
|
||||
description="Using UGC index will store indexes with the same IndexType/Dimension in a single big index.",
|
||||
default=True,
|
||||
)
|
||||
LINDORM_QUERY_TIMEOUT: float | None = Field(description="The lindorm search request timeout (s)", default=2.0)
|
||||
|
||||
@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:
|
||||
|
||||
|
||||
# Returns whether the obtained value is obtained, and None if it does not
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any:
|
||||
if namespace_cache:
|
||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||
if kv_data is None:
|
||||
|
||||
@ -5,9 +5,12 @@ from typing import TYPE_CHECKING
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
"""
|
||||
@ -32,3 +35,19 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
|
||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
|
||||
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
|
||||
)
|
||||
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers")
|
||||
)
|
||||
|
||||
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers_lock")
|
||||
)
|
||||
|
||||
@ -61,6 +61,7 @@ from . import (
|
||||
init_validate,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
version,
|
||||
)
|
||||
|
||||
@ -86,6 +87,7 @@ from .app import (
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
@ -114,6 +116,15 @@ from .datasets import (
|
||||
metadata,
|
||||
website,
|
||||
)
|
||||
from .datasets.rag_pipeline import (
|
||||
datasource_auth,
|
||||
datasource_content_preview,
|
||||
rag_pipeline,
|
||||
rag_pipeline_datasets,
|
||||
rag_pipeline_draft_variable,
|
||||
rag_pipeline_import,
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
@ -213,6 +224,21 @@ api.add_resource(
|
||||
|
||||
api.add_namespace(console_ns)
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import (
|
||||
account,
|
||||
agent_providers,
|
||||
endpoint,
|
||||
load_balancing_config,
|
||||
members,
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"account",
|
||||
"activate",
|
||||
@ -238,6 +264,8 @@ __all__ = [
|
||||
"datasets",
|
||||
"datasets_document",
|
||||
"datasets_segments",
|
||||
"datasource_auth",
|
||||
"datasource_content_preview",
|
||||
"email_register",
|
||||
"endpoint",
|
||||
"extension",
|
||||
@ -263,13 +291,20 @@ __all__ = [
|
||||
"parameter",
|
||||
"ping",
|
||||
"plugin",
|
||||
"rag_pipeline",
|
||||
"rag_pipeline_datasets",
|
||||
"rag_pipeline_draft_variable",
|
||||
"rag_pipeline_import",
|
||||
"rag_pipeline_workflow",
|
||||
"recommended_app",
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trigger_providers",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
|
||||
@ -12,11 +12,15 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
@ -195,19 +199,14 @@ class InstructionGenerateApi(Resource):
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
else (JavascriptCodeProvider.get_default_code())
|
||||
if args["language"] == "javascript"
|
||||
else ""
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
from models import App, db
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
@ -261,6 +260,7 @@ class InstructionGenerateApi(Resource):
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
workflow_service=WorkflowService(),
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@ -20,6 +20,8 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
@ -34,6 +36,7 @@ from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.trigger_debug_service import TriggerDebugService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -513,7 +516,7 @@ class DraftWorkflowRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
@api.doc("stop_workflow_task")
|
||||
@api.doc(description="Stop running workflow task")
|
||||
@ -536,7 +539,12 @@ class WorkflowTaskStopApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -682,7 +690,7 @@ class PublishedWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-block-configs")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@api.doc("get_default_block_configs")
|
||||
@api.doc(description="Get default block configurations for workflow")
|
||||
@ -708,7 +716,7 @@ class DefaultBlockConfigsApi(Resource):
|
||||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-block-configs/<string:block_type>")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@api.doc("get_default_block_config")
|
||||
@api.doc(description="Get default block configuration by type")
|
||||
@ -791,7 +799,7 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/config")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
|
||||
class WorkflowConfigApi(Resource):
|
||||
"""Resource for workflow configuration."""
|
||||
|
||||
@ -809,7 +817,7 @@ class WorkflowConfigApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.doc("get_all_published_workflows")
|
||||
@api.doc(description="Get all published workflows for an application")
|
||||
@ -865,7 +873,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<uuid:workflow_id>")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@api.doc("update_workflow_by_id")
|
||||
@api.doc(description="Update workflow by ID")
|
||||
@ -998,3 +1006,165 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
if node_exec is None:
|
||||
raise NotFound("last run not found")
|
||||
return node_exec
|
||||
|
||||
|
||||
class DraftWorkflowTriggerNodeApi(Resource):
|
||||
"""
|
||||
Single node debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_node")
|
||||
@api.doc(description="Poll for trigger events and execute single node when event arrives")
|
||||
@api.doc(params={
|
||||
"app_id": "Application ID",
|
||||
"node_id": "Node ID"
|
||||
})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerNodeRequest",
|
||||
{
|
||||
"trigger_name": fields.String(required=True, description="Trigger name"),
|
||||
"subscription_id": fields.String(required=True, description="Subscription ID"),
|
||||
}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trigger event received and node executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Poll for trigger events and execute single node when event arrives
|
||||
"""
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
trigger_name = args["trigger_name"]
|
||||
subscription_id = args["subscription_id"]
|
||||
|
||||
event = TriggerDebugService.poll_event(
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
subscription_id=subscription_id,
|
||||
node_id=node_id,
|
||||
trigger_name=trigger_name,
|
||||
)
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting"})
|
||||
|
||||
try:
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
user_inputs = event.model_dump()
|
||||
node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model,
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query="",
|
||||
files=[],
|
||||
)
|
||||
return jsonable_encoder(node_execution)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger node")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 500
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunApi(Resource):
|
||||
"""
|
||||
Full workflow debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_run")
|
||||
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerRunRequest",
|
||||
{
|
||||
"node_id": fields.String(required=True, description="Node ID"),
|
||||
"trigger_name": fields.String(required=True, description="Trigger name"),
|
||||
"subscription_id": fields.String(required=True, description="Subscription ID"),
|
||||
}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trigger event received and workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
trigger_name = args["trigger_name"]
|
||||
subscription_id = args["subscription_id"]
|
||||
|
||||
event = TriggerDebugService.poll_event(
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
subscription_id=subscription_id,
|
||||
node_id=node_id,
|
||||
trigger_name=trigger_name,
|
||||
)
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting"})
|
||||
|
||||
workflow_args = {
|
||||
"inputs": event.model_dump(),
|
||||
"query": "",
|
||||
"files": [],
|
||||
}
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
workflow_args["external_trace_id"] = external_trace_id
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
return helper.compact_generate_response(response)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger run")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 500
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
|
||||
@ -13,14 +13,16 @@ from controllers.console.app.error import (
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
@ -74,6 +76,22 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
return value_type.exposed_type().value
|
||||
|
||||
|
||||
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
"""Serialize full_content information for large variables."""
|
||||
if not variable.is_truncated():
|
||||
return None
|
||||
|
||||
variable_file = variable.variable_file
|
||||
assert variable_file is not None
|
||||
|
||||
return {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
@ -83,11 +101,13 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"value_type": fields.String(attribute=_serialize_variable_type),
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
full_content=fields.Raw(attribute=_serialize_full_content),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
|
||||
249
api/controllers/console/app/workflow_trigger.py
Normal file
249
api/controllers/console/app/workflow_trigger.py
Normal file
@ -0,0 +1,249 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account, AppMode
|
||||
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||
|
||||
|
||||
class PluginTriggerApi(Resource):
|
||||
"""Workflow Plugin Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def post(self, app_model):
|
||||
"""Create plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=False, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=False, location="json")
|
||||
parser.add_argument("trigger_name", type=str, required=False, location="json")
|
||||
parser.add_argument("subscription_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
node_id=args["node_id"],
|
||||
provider_id=args["provider_id"],
|
||||
trigger_name=args["trigger_name"],
|
||||
subscription_id=args["subscription_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def get(self, app_model):
|
||||
"""Get plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def put(self, app_model):
|
||||
"""Update plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
subscription_id=args["subscription_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def delete(self, app_model):
|
||||
"""Delete plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
WorkflowPluginTriggerService.delete_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model):
|
||||
"""Get webhook trigger for a node"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
node_id = args["node_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.filter(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
# Add computed fields for marshal_with
|
||||
base_url = dify_config.SERVICE_API_URL
|
||||
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" # type: ignore
|
||||
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" # type: ignore
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_fields)
|
||||
def get(self, app_model):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Add computed icon field for each trigger
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
for trigger in triggers:
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
|
||||
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
trigger_id = args["trigger_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not trigger:
|
||||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
|
||||
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
@ -1,4 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
@ -9,6 +11,8 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
@ -19,6 +23,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@ -111,6 +116,18 @@ class DataSourceNotionListApi(Resource):
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not credential:
|
||||
raise NotFound("Credential not found.")
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
@ -134,31 +151,49 @@ class DataSourceNotionListApi(Resource):
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = session.scalars(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
if credential:
|
||||
datasource_runtime.runtime.credentials = credential
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
)
|
||||
try:
|
||||
pages = []
|
||||
workspace_info = {}
|
||||
for message in online_document_result:
|
||||
result = message.result
|
||||
for info in result:
|
||||
workspace_info = {
|
||||
"workspace_id": info.workspace_id,
|
||||
"workspace_name": info.workspace_name,
|
||||
"workspace_icon": info.workspace_icon,
|
||||
}
|
||||
for page in info.pages:
|
||||
page_info = {
|
||||
"page_id": page.page_id,
|
||||
"page_name": page.page_name,
|
||||
"type": page.type,
|
||||
"parent_id": page.parent_id,
|
||||
"is_bound": page.page_id in exist_page_ids,
|
||||
"page_icon": page.page_icon,
|
||||
}
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
@ -166,27 +201,25 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
||||
@ -211,10 +244,12 @@ class DataSourceNotionApi(Resource):
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
|
||||
@ -20,7 +20,6 @@ from controllers.console.wraps import (
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
@ -33,6 +32,7 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
@ -337,6 +337,15 @@ class DatasetApi(Resource):
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid icon info.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
@ -387,7 +396,7 @@ class DatasetApi(Resource):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@ -503,10 +512,12 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
@ -730,6 +741,19 @@ class DatasetApiDeleteApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
|
||||
class DatasetEnableApiApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, status):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-base-info")
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@api.doc("get_dataset_api_base_info")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
@ -53,6 +54,7 @@ from fields.document_fields import (
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
||||
@ -499,6 +501,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
||||
class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -541,6 +544,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
@ -591,6 +595,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
raise IndexingEstimateError(str(e))
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -714,7 +719,7 @@ class DocumentApi(DocumentResource):
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -910,6 +915,7 @@ class DocumentMetadataApi(DocumentResource):
|
||||
return {"result": "success", "message": "Document metadata updated."}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
||||
class DocumentStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -946,6 +952,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
||||
class DocumentPauseApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -979,6 +986,7 @@ class DocumentPauseApi(DocumentResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
class DocumentRecoverApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1009,6 +1017,7 @@ class DocumentRecoverApi(DocumentResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
|
||||
class DocumentRetryApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1052,6 +1061,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
class DocumentRenameApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1075,6 +1085,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
return document
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1100,3 +1111,64 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
DocumentService.sync_website_document(dataset_id, document)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
"datasource_info": None,
|
||||
"datasource_type": None,
|
||||
"input_data": None,
|
||||
"datasource_node_id": None,
|
||||
}, 200
|
||||
return {
|
||||
"datasource_info": json.loads(log.datasource_info),
|
||||
"datasource_type": log.datasource_type,
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
|
||||
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DatasetInitApi, "/datasets/init")
|
||||
api.add_resource(
|
||||
DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
|
||||
)
|
||||
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
||||
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(
|
||||
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
|
||||
)
|
||||
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
||||
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
||||
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
|
||||
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
api.add_resource(
|
||||
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
|
||||
)
|
||||
|
||||
@ -71,3 +71,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
|
||||
error_code = "child_chunk_delete_index_error"
|
||||
description = "Delete child chunk index failed: {message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class PipelineNotFoundError(BaseHTTPException):
|
||||
error_code = "pipeline_not_found"
|
||||
description = "Pipeline not found."
|
||||
code = 404
|
||||
|
||||
@ -148,7 +148,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
|
||||
|
||||
362
api/controllers/console/datasets/rag_pipeline/datasource_auth.py
Normal file
362
api/controllers/console/datasets/rag_pipeline/datasource_auth.py
Normal file
@ -0,0 +1,362 @@
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
oauth_config = DatasourceProviderService().get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_config:
|
||||
raise ValueError(f"No OAuth Client Config for {provider_id}")
|
||||
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_config,
|
||||
)
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class DatasourceOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider_id: str):
|
||||
context_id = request.cookies.get("context_id") or request.args.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
oauth_client_params = datasource_provider_service.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_client_params:
|
||||
raise NotFound()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||
)
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOAuthCallback,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/plugin/datasource/<path:provider_id>",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDeleteApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/delete",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthListApi,
|
||||
"/auth/plugin/datasource/list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceHardCodeAuthListApi,
|
||||
"/auth/plugin/datasource/default-list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthOauthCustomClient,
|
||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDefaultApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/default",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceUpdateProviderNameApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||
)
|
||||
@ -0,0 +1,57 @@
|
||||
from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DataSourceContentPreviewApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||
)
|
||||
164
api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
Normal file
164
api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
Normal file
@ -0,0 +1,164 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
knowledge_pipeline_publish_enabled,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
|
||||
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
@ -0,0 +1,114 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restx import Resource, marshal, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=IconInfo(
|
||||
icon="📙",
|
||||
icon_background="#FFF4ED",
|
||||
icon_type="emoji",
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
yaml_content=args["yaml_content"],
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
|
||||
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=IconInfo(
|
||||
icon="📙",
|
||||
icon_background="#FFF4ED",
|
||||
icon_type="emoji",
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
@ -0,0 +1,389 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=pipeline.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(pipeline.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class RagPipelineVariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "local_file",
|
||||
# "url": "",
|
||||
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||
# }
|
||||
#
|
||||
# Remote File:
|
||||
#
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "remote_url",
|
||||
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, pipeline_id={pipeline.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
|
||||
)
|
||||
@ -0,0 +1,149 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("pipeline_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
dataset_name=args.get("name"),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=str, default="false", location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=args["include_secret"] == "true"
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
46
api/controllers/console/datasets/wraps.py
Normal file
46
api/controllers/console/datasets/wraps.py
Normal file
@ -0,0 +1,46 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def get_rag_pipeline(
|
||||
view: Callable | None = None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user is not an account")
|
||||
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
raise PipelineNotFoundError()
|
||||
|
||||
kwargs["pipeline"] = pipeline
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@ -20,6 +20,7 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models.model import AppMode, InstalledApp
|
||||
@ -82,6 +83,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -20,6 +20,7 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
@ -68,10 +69,11 @@ class FileApi(Resource):
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
@ -92,7 +94,7 @@ class FilePreviewApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
text = FileService.get_file_preview(file_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from controllers.common.errors import (
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
@ -61,7 +62,7 @@ class RemoteFileUploadApi(Resource):
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
|
||||
35
api/controllers/console/spec.py
Normal file
35
api/controllers/console/spec.py
Normal file
@ -0,0 +1,35 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
from libs.login import login_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpecSchemaDefinitionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""
|
||||
Get system JSON Schema definitions specification
|
||||
Used for frontend component type mapping
|
||||
"""
|
||||
try:
|
||||
schema_manager = SchemaManager()
|
||||
schema_definitions = schema_manager.get_all_schema_definitions()
|
||||
return schema_definitions, 200
|
||||
except Exception:
|
||||
logger.exception("Failed to get schema definitions from local registry")
|
||||
# Return empty array as fallback
|
||||
return [], 200
|
||||
|
||||
|
||||
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
|
||||
@ -516,18 +516,20 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
parser.add_argument("provider", type=str, required=True, location="args")
|
||||
parser.add_argument("action", type=str, required=True, location="args")
|
||||
parser.add_argument("parameter", type=str, required=True, location="args")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="args")
|
||||
parser.add_argument("provider_type", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id,
|
||||
user_id,
|
||||
args["plugin_id"],
|
||||
args["provider"],
|
||||
args["action"],
|
||||
args["parameter"],
|
||||
args["provider_type"],
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=args["plugin_id"],
|
||||
provider=args["provider"],
|
||||
action=args["action"],
|
||||
parameter=args["parameter"],
|
||||
credential_id=args["credential_id"],
|
||||
provider_type=args["provider_type"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
589
api/controllers/console/workspace/trigger_providers.py
Normal file
589
api/controllers/console/workspace/trigger_providers.py
Normal file
@ -0,0 +1,589 @@
|
||||
import logging
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error listing trigger providers", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credential_type=credential_type,
|
||||
)
|
||||
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error adding provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error updating provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
|
||||
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
|
||||
except Exception as e:
|
||||
logger.exception("Error getting request logs for subscription builder", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
),
|
||||
)
|
||||
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
return 200
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error building provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id):
|
||||
"""Delete a subscription instance"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Delete trigger provider subscription
|
||||
TriggerProviderService.delete_trigger_provider(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
# Delete plugin triggers
|
||||
WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Create subscription builder
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=provider_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
)
|
||||
|
||||
# Create OAuth handler and proxy context
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
extra_data={
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
},
|
||||
)
|
||||
|
||||
# Build redirect URI for callback
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(
|
||||
jsonable_encoder(
|
||||
{
|
||||
"authorization_url": authorization_url_response.authorization_url,
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
"subscription_builder": subscription_builder,
|
||||
}
|
||||
)
|
||||
)
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error initiating OAuth flow", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
# Use and validate proxy context
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Get OAuth credentials from callback
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
credentials_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
|
||||
credentials = credentials_response.credentials
|
||||
expires_at = credentials_response.expires_at
|
||||
|
||||
if not credentials:
|
||||
raise Exception("Failed to get OAuth credentials")
|
||||
|
||||
# Update subscription builder
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=credentials,
|
||||
credential_expires_at=expires_at,
|
||||
),
|
||||
)
|
||||
# Redirect to OAuth callback page
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
# Get custom OAuth client params if exists
|
||||
custom_params = TriggerProviderService.get_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if custom client is enabled
|
||||
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if there's a system OAuth client
|
||||
system_client = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"configured": bool(custom_params or system_client),
|
||||
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
|
||||
"custom_configured": bool(custom_params),
|
||||
"custom_enabled": is_custom_enabled,
|
||||
"redirect_uri": redirect_uri,
|
||||
"params": custom_params or {},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=args.get("client_params"),
|
||||
enabled=args.get("enabled"),
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error configuring OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error removing OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
# Trigger Subscription
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
api.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
||||
# Trigger Subscription Builder
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
@ -227,7 +227,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
|
||||
@ -279,3 +279,14 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
abort(403)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def knowledge_pipeline_publish_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||
import services
|
||||
from controllers.common.errors import UnsupportedFileTypeError
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_database import db
|
||||
from services.account_service import TenantService
|
||||
from services.file_service import FileService
|
||||
|
||||
@ -28,7 +29,7 @@ class ImagePreviewApi(Resource):
|
||||
return {"content": "Invalid request."}, 400
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_image_preview(
|
||||
generator, mimetype = FileService(db.engine).get_image_preview(
|
||||
file_id=file_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
@ -57,7 +58,7 @@ class FilePreviewApi(Resource):
|
||||
return {"content": "Invalid request."}, 400
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService.get_file_generator_by_file_id(
|
||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||
file_id=file_id,
|
||||
timestamp=args["timestamp"],
|
||||
nonce=args["nonce"],
|
||||
@ -108,7 +109,7 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
raise NotFound("webapp logo is not found")
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_public_image_preview(
|
||||
generator, mimetype = FileService(db.engine).get_public_image_preview(
|
||||
webapp_logo_file_id,
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
|
||||
@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
|
||||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
|
||||
@ -420,7 +420,12 @@ class PluginUploadFileRequestApi(Resource):
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||
# generate signed url
|
||||
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
|
||||
url = get_signed_file_url_for_plugin(
|
||||
filename=payload.filename,
|
||||
mimetype=payload.mimetype,
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
|
||||
@ -32,11 +32,20 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not user_model:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -9,10 +9,9 @@ from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
|
||||
|
||||
class MCPRequestError(Exception):
|
||||
@ -195,50 +194,6 @@ class MCPAppApi(Resource):
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
|
||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
|
||||
"""Get end user from existing session - optimized query"""
|
||||
return (
|
||||
session.query(EndUser)
|
||||
.where(EndUser.tenant_id == tenant_id)
|
||||
.where(EndUser.session_id == mcp_server_id)
|
||||
.where(EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _create_end_user(
|
||||
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
|
||||
) -> EndUser:
|
||||
"""Create end user in existing session"""
|
||||
end_user = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=mcp_server_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.flush() # Use flush instead of commit to keep transaction open
|
||||
session.refresh(end_user)
|
||||
return end_user
|
||||
|
||||
def _handle_mcp_request(
|
||||
self,
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
mcp_request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
request_id: Union[int, str],
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
|
||||
"""Handle MCP request and return response"""
|
||||
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
|
||||
|
||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||
client_info = mcp_request.root.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
# Commit the session before creating end user to avoid transaction conflicts
|
||||
session.commit()
|
||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||
|
||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
@ -12,8 +12,9 @@ from controllers.common.errors import (
|
||||
)
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import build_file_model
|
||||
from models.model import App, EndUser
|
||||
from models import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -52,7 +53,7 @@ class FileApi(Resource):
|
||||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
|
||||
@ -26,7 +26,8 @@ from core.errors.error import (
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
|
||||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
@ -124,7 +124,12 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
@ -134,6 +139,9 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
@ -199,7 +207,11 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both text and name must be strings.")
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
@ -301,8 +313,9 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
@ -390,10 +403,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
@ -571,7 +588,7 @@ class DocumentApi(DatasetApiResource):
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -604,7 +621,7 @@ class DocumentApi(DatasetApiResource):
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
|
||||
@ -47,3 +47,9 @@ class DatasetInUseError(BaseHTTPException):
|
||||
error_code = "dataset_in_use"
|
||||
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
||||
code = 409
|
||||
|
||||
|
||||
class PipelineRunError(BaseHTTPException):
|
||||
error_code = "pipeline_run_error"
|
||||
description = "An error occurred while running the pipeline."
|
||||
code = 500
|
||||
|
||||
@ -133,7 +133,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/metadata/built-in")
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
|
||||
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
@service_api_ns.doc("get_built_in_fields")
|
||||
@service_api_ns.doc(description="Get all built-in metadata fields")
|
||||
@ -143,7 +143,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all built-in metadata fields."""
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return {"fields": built_in_fields}, 200
|
||||
|
||||
@ -0,0 +1,242 @@
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.reqparse import ParseResult, RequestParser
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.engine import db
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
|
||||
@service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
|
||||
@service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
path={
|
||||
"dataset_id": "Dataset ID",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
params={
|
||||
"is_published": "Whether to get published or draft datasource plugins "
|
||||
"(true for published, false for draft, default: true)"
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Datasource plugins retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Get query parameter to determine published or draft
|
||||
is_published: bool = request.args.get("is_published", default=True, type=bool)
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
|
||||
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
|
||||
)
|
||||
return datasource_plugins, 200
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
|
||||
class DatasourceNodeRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
|
||||
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
path={
|
||||
"dataset_id": "Dataset ID",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
body={
|
||||
"inputs": "User input variables",
|
||||
"datasource_type": "Datasource type, e.g. online_document",
|
||||
"credential_id": "Credential ID",
|
||||
"is_published": "Whether to get published or draft datasource plugins "
|
||||
"(true for published, false for draft, default: true)",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Datasource node run successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Get query parameter to determine published or draft
|
||||
parser: RequestParser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
||||
parser.add_argument("is_published", type=bool, required=True, location="json")
|
||||
args: ParseResult = parser.parse_args()
|
||||
|
||||
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
return helper.compact_generate_response(
|
||||
PipelineGenerator.convert_to_event_stream(
|
||||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=datasource_node_run_api_entity.inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_node_run_api_entity.datasource_type,
|
||||
is_published=datasource_node_run_api_entity.is_published,
|
||||
credential_id=datasource_node_run_api_entity.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
|
||||
class PipelineRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
|
||||
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
path={
|
||||
"dataset_id": "Dataset ID",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
body={
|
||||
"inputs": "User input variables",
|
||||
"datasource_type": "Datasource type, e.g. online_document",
|
||||
"datasource_info_list": "Datasource info list",
|
||||
"start_node_id": "Start node ID",
|
||||
"is_published": "Whether to get published or draft datasource plugins "
|
||||
"(true for published, false for draft, default: true)",
|
||||
"streaming": "Whether to stream the response(streaming or blocking), default: streaming",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Pipeline run successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
parser: RequestParser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
|
||||
parser.add_argument(
|
||||
"response_mode",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["streaming", "blocking"],
|
||||
default="blocking",
|
||||
location="json",
|
||||
)
|
||||
args: ParseResult = parser.parse_args()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
try:
|
||||
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
|
||||
streaming=args.get("response_mode") == "streaming",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except Exception as ex:
|
||||
raise PipelineRunError(description=str(ex))
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/pipeline/file-upload")
|
||||
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
|
||||
"""Resource for uploading a file to a knowledgebase pipeline."""
|
||||
|
||||
@service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
|
||||
@service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
400: "Bad request - no file or invalid file",
|
||||
401: "Unauthorized - invalid API token",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def post(self, tenant_id: str):
|
||||
"""Upload a file for use in conversations.
|
||||
|
||||
Accepts a single file upload via multipart/form-data.
|
||||
"""
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("Invalid user account")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
||||
@ -193,6 +193,47 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[T, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
|
||||
# First try to get from kwargs (explicit parameter)
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
# If not in kwargs, try to extract from positional args
|
||||
if not dataset_id and args:
|
||||
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
|
||||
# Check if first arg is likely a class instance (has __dict__ or __class__)
|
||||
if len(args) > 1 and hasattr(args[0], "__dict__"):
|
||||
# This is a class method, dataset_id should be in args[1]
|
||||
potential_id = args[1]
|
||||
# Validate it's a string-like UUID, not another object
|
||||
try:
|
||||
# Try to convert to string and check if it's a valid UUID format
|
||||
str_id = str(potential_id)
|
||||
# Basic check: UUIDs are 36 chars with hyphens
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except:
|
||||
pass
|
||||
elif len(args) > 0:
|
||||
# Not a class method, check if args[0] looks like a UUID
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except:
|
||||
pass
|
||||
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
|
||||
7
api/controllers/trigger/__init__.py
Normal file
7
api/controllers/trigger/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from flask import Blueprint
|
||||
|
||||
# Create trigger blueprint
|
||||
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
|
||||
|
||||
# Import routes after blueprint creation to avoid circular imports
|
||||
from . import trigger, webhook
|
||||
41
api/controllers/trigger/trigger.py
Normal file
41
api/controllers/trigger/trigger.py
Normal file
@ -0,0 +1,41 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from flask import jsonify, request
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
UUID_MATCHER = re.compile(UUID_PATTERN)
|
||||
|
||||
|
||||
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def trigger_endpoint(endpoint_id: str):
|
||||
"""
|
||||
Handle endpoint trigger calls.
|
||||
"""
|
||||
# endpoint_id must be UUID
|
||||
if not UUID_MATCHER.match(endpoint_id):
|
||||
raise NotFound("Invalid endpoint ID")
|
||||
handling_chain = [
|
||||
TriggerService.process_endpoint,
|
||||
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
|
||||
]
|
||||
try:
|
||||
for handler in handling_chain:
|
||||
response = handler(endpoint_id, request)
|
||||
if response:
|
||||
break
|
||||
if not response:
|
||||
raise NotFound("Endpoint not found")
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for {endpoint_id}")
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
46
api/controllers/trigger/webhook.py
Normal file
46
api/controllers/trigger/webhook.py
Normal file
@ -0,0 +1,46 @@
|
||||
import logging
|
||||
|
||||
from flask import jsonify
|
||||
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.webhook_service import WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook(webhook_id: str):
|
||||
"""
|
||||
Handle webhook trigger calls.
|
||||
|
||||
This endpoint receives webhook calls and processes them according to the
|
||||
configured webhook trigger settings.
|
||||
"""
|
||||
try:
|
||||
# Get webhook trigger, workflow, and node configuration
|
||||
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
|
||||
|
||||
# Extract request data
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Validate request against node configuration
|
||||
validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
|
||||
if not validation_result["valid"]:
|
||||
return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400
|
||||
|
||||
# Process webhook call (send to Celery)
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Return configured response
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except RequestEntityTooLarge:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for %s", webhook_id)
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
@ -11,6 +11,7 @@ from controllers.common.errors import (
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import build_file_model
|
||||
from services.file_service import FileService
|
||||
|
||||
@ -68,7 +69,7 @@ class FileApi(WebApiResource):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
|
||||
@ -14,6 +14,7 @@ from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
|
||||
from services.file_service import FileService
|
||||
|
||||
@ -119,7 +120,7 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
|
||||
@ -21,6 +21,7 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -112,6 +113,11 @@ class WorkflowTaskStopApi(WebApiResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
|
||||
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
|
||||
return_resource=app_config.additional_features.show_retrieve_source,
|
||||
return_resource=(
|
||||
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
|
||||
),
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
|
||||
@ -4,8 +4,8 @@ from typing import Any
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
|
||||
@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
|
||||
hide: bool = False
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
@ -129,6 +129,16 @@ class VariableEntity(BaseModel):
|
||||
return v or []
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
"""
|
||||
Rag Pipeline Variable Entity.
|
||||
"""
|
||||
|
||||
tooltips: str | None = None
|
||||
placeholder: str | None = None
|
||||
belong_to_node_id: str
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
External Data Variable Entity.
|
||||
@ -288,7 +298,7 @@ class AppConfig(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
additional_features: AppAdditionalFeatures
|
||||
additional_features: AppAdditionalFeatures | None = None
|
||||
variables: list[VariableEntity] = []
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
|
||||
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
@ -20,3 +22,48 @@ class WorkflowVariablesConfigManager:
|
||||
variables.append(VariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@classmethod
|
||||
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
|
||||
"""
|
||||
Convert workflow start variables to variables
|
||||
|
||||
:param workflow: workflow instance
|
||||
"""
|
||||
variables = []
|
||||
|
||||
# get second step node
|
||||
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||
if not rag_pipeline_variables:
|
||||
return []
|
||||
variables_map = {item["variable"]: item for item in rag_pipeline_variables}
|
||||
|
||||
# get datasource node data
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("id") == start_node_id:
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
break
|
||||
if datasource_node_data:
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
variables_map.pop(last_part, None)
|
||||
if value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
variables_map.pop(last_part, None)
|
||||
|
||||
all_second_step_variables = list(variables_map.values())
|
||||
|
||||
for item in all_second_step_variables:
|
||||
if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
|
||||
variables.append(RagPipelineVariableEntity.model_validate(item))
|
||||
|
||||
return variables
|
||||
|
||||
@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# init application generate entity
|
||||
@ -420,7 +420,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
db.session.refresh(conversation)
|
||||
|
||||
# get conversation dialogue count
|
||||
self._dialogue_count = get_thread_messages_length(conversation.id)
|
||||
# NOTE: dialogue_count should not start from 0,
|
||||
# because during the first conversation, dialogue_count should be 1.
|
||||
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@ -467,7 +469,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
from models.workflow import ConversationVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -78,23 +79,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -146,16 +153,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
@ -167,11 +185,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks,
|
||||
)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
@ -65,8 +60,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
@ -387,9 +382,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
@ -434,32 +427,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
@ -751,8 +718,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
@ -800,8 +765,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
@ -814,17 +777,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# For unhandled events, we continue (original behavior)
|
||||
return
|
||||
|
||||
@ -848,11 +800,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
yield from self._handle_workflow_started_event(event)
|
||||
|
||||
case QueueTextChunkEvent():
|
||||
yield from self._handle_text_chunk_event(
|
||||
event, tts_publisher=tts_publisher, queue_message=queue_message
|
||||
)
|
||||
|
||||
case QueueErrorEvent():
|
||||
yield from self._handle_error_event(event)
|
||||
break
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
@ -14,6 +14,7 @@ from core.workflow.repositories.draft_variable_repository import (
|
||||
)
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from models import Account, EndUser
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -44,9 +45,9 @@ class BaseAppGenerator:
|
||||
mapping=v,
|
||||
tenant_id=tenant_id,
|
||||
config=FileUploadConfig(
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
|
||||
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
|
||||
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
|
||||
),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
@ -59,9 +60,9 @@ class BaseAppGenerator:
|
||||
mappings=v,
|
||||
tenant_id=tenant_id,
|
||||
config=FileUploadConfig(
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
|
||||
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
|
||||
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
|
||||
),
|
||||
)
|
||||
for k, v in user_inputs.items()
|
||||
@ -182,8 +183,9 @@ class BaseAppGenerator:
|
||||
|
||||
@final
|
||||
@staticmethod
|
||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
|
||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
assert isinstance(account, Account)
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
@ -200,6 +202,7 @@ class BaseAppGenerator:
|
||||
node_type=node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
user=account,
|
||||
)
|
||||
else:
|
||||
|
||||
|
||||
@ -127,6 +127,21 @@ class AppQueueManager:
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag without user permission check.
|
||||
This method allows stopping workflows without user context.
|
||||
|
||||
:param task_id: The task ID to stop
|
||||
:return:
|
||||
"""
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
|
||||
@ -164,7 +164,9 @@ class ChatAppRunner(AppRunner):
|
||||
config=app_config.dataset,
|
||||
query=query,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source,
|
||||
show_retrieve_source=(
|
||||
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
|
||||
),
|
||||
hit_callback=hit_callback,
|
||||
memory=memory,
|
||||
message_id=message.id,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
@ -36,24 +31,23 @@ from core.app.entities.task_entities import (
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
Account,
|
||||
EndUser,
|
||||
)
|
||||
from services.variable_truncator import VariableTruncator
|
||||
|
||||
|
||||
class WorkflowResponseConverter:
|
||||
@ -65,6 +59,7 @@ class WorkflowResponseConverter:
|
||||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._user = user
|
||||
self._truncator = VariableTruncator.default()
|
||||
|
||||
def workflow_start_to_stream_response(
|
||||
self,
|
||||
@ -156,7 +151,8 @@ class WorkflowResponseConverter:
|
||||
title=workflow_node_execution.title,
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
inputs=workflow_node_execution.get_response_inputs(),
|
||||
inputs_truncated=workflow_node_execution.inputs_truncated,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
@ -171,11 +167,19 @@ class WorkflowResponseConverter:
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
|
||||
self._application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
|
||||
return response
|
||||
@ -183,11 +187,7 @@ class WorkflowResponseConverter:
|
||||
def workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> NodeFinishStreamResponse | None:
|
||||
@ -210,9 +210,12 @@ class WorkflowResponseConverter:
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
|
||||
inputs=workflow_node_execution.get_response_inputs(),
|
||||
inputs_truncated=workflow_node_execution.inputs_truncated,
|
||||
process_data=workflow_node_execution.get_response_process_data(),
|
||||
process_data_truncated=workflow_node_execution.process_data_truncated,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
|
||||
outputs_truncated=workflow_node_execution.outputs_truncated,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
@ -221,9 +224,6 @@ class WorkflowResponseConverter:
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
@ -255,9 +255,12 @@ class WorkflowResponseConverter:
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
|
||||
inputs=workflow_node_execution.get_response_inputs(),
|
||||
inputs_truncated=workflow_node_execution.inputs_truncated,
|
||||
process_data=workflow_node_execution.get_response_process_data(),
|
||||
process_data_truncated=workflow_node_execution.process_data_truncated,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
|
||||
outputs_truncated=workflow_node_execution.outputs_truncated,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
@ -275,50 +278,6 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunStartedEvent,
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
@ -326,6 +285,7 @@ class WorkflowResponseConverter:
|
||||
workflow_execution_id: str,
|
||||
event: QueueIterationStartEvent,
|
||||
) -> IterationNodeStartStreamResponse:
|
||||
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
@ -333,13 +293,12 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
inputs=new_inputs,
|
||||
inputs_truncated=truncated,
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -357,15 +316,10 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
@ -377,6 +331,11 @@ class WorkflowResponseConverter:
|
||||
event: QueueIterationCompletedEvent,
|
||||
) -> IterationNodeCompletedStreamResponse:
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
|
||||
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
|
||||
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
|
||||
json_converter.to_json_encodable(event.outputs) or {}
|
||||
)
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
@ -384,28 +343,29 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=json_converter.to_json_encodable(event.outputs),
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
inputs=new_inputs,
|
||||
inputs_truncated=inputs_truncated,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_loop_start_to_stream_response(
|
||||
self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
|
||||
) -> LoopNodeStartStreamResponse:
|
||||
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
|
||||
return LoopNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
@ -413,10 +373,11 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
inputs=new_inputs,
|
||||
inputs_truncated=truncated,
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
@ -437,15 +398,16 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.output,
|
||||
# The `pre_loop_output` field is not utilized by the frontend.
|
||||
# Previously, it was assigned the value of `event.output`.
|
||||
pre_loop_output={},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
@ -456,6 +418,11 @@ class WorkflowResponseConverter:
|
||||
workflow_execution_id: str,
|
||||
event: QueueLoopCompletedEvent,
|
||||
) -> LoopNodeCompletedStreamResponse:
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
|
||||
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
|
||||
json_converter.to_json_encodable(event.outputs) or {}
|
||||
)
|
||||
return LoopNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
@ -463,17 +430,19 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
inputs=new_inputs,
|
||||
inputs_truncated=inputs_truncated,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
|
||||
@ -124,7 +124,9 @@ class CompletionAppRunner(AppRunner):
|
||||
config=dataset_config,
|
||||
query=query or "",
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source
|
||||
if app_config.additional_features
|
||||
else False,
|
||||
hit_callback=hit_callback,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
|
||||
95
api/core/app/apps/pipeline/generate_response_converter.py
Normal file
95
api/core/app/apps/pipeline/generate_response_converter.py
Normal file
@ -0,0 +1,95 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
:return:
|
||||
"""
|
||||
for chunk in stream_response:
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict, data))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
:return:
|
||||
"""
|
||||
for chunk in stream_response:
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict, data))
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
yield response_chunk
|
||||
66
api/core/app/apps/pipeline/pipeline_config_manager.py
Normal file
66
api/core/app/apps/pipeline/pipeline_config_manager.py
Normal file
@ -0,0 +1,66 @@
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
|
||||
from models.dataset import Pipeline
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class PipelineConfig(WorkflowUIBasedAppConfig):
|
||||
"""
|
||||
Pipeline Config Entity.
|
||||
"""
|
||||
|
||||
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
|
||||
pass
|
||||
|
||||
|
||||
class PipelineConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
|
||||
pipeline_config = PipelineConfig(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
workflow_id=workflow.id,
|
||||
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
|
||||
workflow=workflow, start_node_id=start_node_id
|
||||
),
|
||||
)
|
||||
|
||||
return pipeline_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
"""
|
||||
Validate for pipeline config
|
||||
|
||||
:param tenant_id: tenant id
|
||||
:param config: app model config args
|
||||
:param only_structure_validate: only validate the structure of the config
|
||||
"""
|
||||
related_config_keys = []
|
||||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# text_to_speech
|
||||
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
related_config_keys = list(set(related_config_keys))
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
851
api/core/app/apps/pipeline/pipeline_generator.py
Normal file
851
api/core/app/apps/pipeline/pipeline_generator.py
Normal file
@ -0,0 +1,851 @@
|
||||
import contextvars
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, cast, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
)
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
# Add null check for dataset
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
dataset = pipeline.retrieve_dataset(session)
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
start_node_id: str = args["start_node_id"]
|
||||
datasource_type: str = args["datasource_type"]
|
||||
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
|
||||
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
|
||||
)
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
|
||||
)
|
||||
documents: list[Document] = []
|
||||
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
for datasource_info in datasource_info_list:
|
||||
position = DocumentService.get_documents_position(dataset.id)
|
||||
document = self._build_document(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
built_in_field_enabled=dataset.built_in_field_enabled,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
created_from="rag-pipeline",
|
||||
position=position,
|
||||
account=user,
|
||||
batch=batch,
|
||||
document_form=dataset.chunk_structure,
|
||||
)
|
||||
db.session.add(document)
|
||||
documents.append(document)
|
||||
db.session.commit()
|
||||
|
||||
# run in child thread
|
||||
rag_pipeline_invoke_entities = []
|
||||
for i, datasource_info in enumerate(datasource_info_list):
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
document_id = args.get("original_document_id") or None
|
||||
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
|
||||
document_id = document_id or documents[i].id
|
||||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document_id,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=json.dumps(datasource_info),
|
||||
datasource_node_id=start_node_id,
|
||||
input_data=inputs,
|
||||
pipeline_id=pipeline.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
db.session.commit()
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
dataset_id=dataset.id,
|
||||
original_document_id=args.get("original_document_id"),
|
||||
start_node_id=start_node_id,
|
||||
batch=batch,
|
||||
document_id=document_id,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=pipeline_config.rag_pipeline_variables,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=workflow_triggered_from,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
|
||||
)
|
||||
if invoke_from == InvokeFrom.DEBUGGER or is_retry:
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context=contextvars.copy_context(),
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
else:
|
||||
rag_pipeline_invoke_entities.append(
|
||||
RagPipelineInvokeEntity(
|
||||
pipeline_id=pipeline.id,
|
||||
user_id=user.id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
workflow_id=workflow.id,
|
||||
streaming=streaming,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
application_generate_entity=application_generate_entity.model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
if rag_pipeline_invoke_entities:
|
||||
# store the rag_pipeline_invoke_entities to object storage
|
||||
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
|
||||
name = "rag_pipeline_invoke_entities.json"
|
||||
# Convert list to proper JSON string
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
|
||||
|
||||
if redis_client.get(tenant_pipeline_task_key):
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
else:
|
||||
priority_rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
# return batch, dataset, documents
|
||||
return {
|
||||
"batch": batch,
|
||||
"dataset": PipelineDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
chunk_structure=dataset.chunk_structure,
|
||||
).model_dump(),
|
||||
"documents": [
|
||||
PipelineDocument(
|
||||
id=document.id,
|
||||
position=document.position,
|
||||
data_source_type=document.data_source_type,
|
||||
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
|
||||
name=document.name,
|
||||
indexing_status=document.indexing_status,
|
||||
error=document.error,
|
||||
enabled=document.enabled,
|
||||
).model_dump()
|
||||
for document in documents
|
||||
],
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
pipeline: Pipeline,
|
||||
workflow_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param pipeline: Pipeline
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param workflow_execution_repository: repository for workflow execution
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
# init queue manager
|
||||
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
queue_manager = PipelineQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
)
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"queue_manager": queue_manager,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(
|
||||
invoke_from,
|
||||
user,
|
||||
)
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=streaming,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def single_iteration_generate(
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param node_id: the node id
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param streaming: is streamed
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
dataset = pipeline.retrieve_dataset(session)
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
|
||||
# init application generate entity - use RagPipelineGenerateEntity instead
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args.get("datasource_type", ""),
|
||||
datasource_info=args.get("datasource_info", {}),
|
||||
dataset_id=dataset.id,
|
||||
batch=args.get("batch", ""),
|
||||
document_id=args.get("document_id"),
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param node_id: the node id
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param streaming: is streamed
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
dataset = pipeline.retrieve_dataset(session)
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args.get("datasource_type", ""),
|
||||
datasource_info=args.get("datasource_info", {}),
|
||||
batch=args.get("batch", ""),
|
||||
document_id=args.get("document_id"),
|
||||
dataset_id=dataset.id,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
Workflow.app_id == application_generate_entity.app_config.app_id,
|
||||
Workflow.id == application_generate_entity.app_config.workflow_id,
|
||||
)
|
||||
)
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
InvokeFrom.SERVICE_API,
|
||||
}
|
||||
|
||||
if is_external_api_call:
|
||||
# For external API calls, use end user's session ID
|
||||
end_user = session.scalar(
|
||||
select(EndUser).where(EndUser.id == application_generate_entity.user_id)
|
||||
)
|
||||
system_user_id = end_user.session_id if end_user else ""
|
||||
else:
|
||||
# For internal calls, use the original user ID
|
||||
system_user_id = application_generate_entity.user_id
|
||||
# workflow app
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_response(
|
||||
self,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param workflow: workflow
|
||||
:param queue_manager: queue manager
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(
|
||||
"Fails to process generate task pipeline, task_id: %r",
|
||||
application_generate_entity.task_id,
|
||||
)
|
||||
raise e
|
||||
|
||||
def _build_document(
|
||||
self,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
built_in_field_enabled: bool,
|
||||
datasource_type: str,
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
account: Union[Account, EndUser],
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
if datasource_type == "local_file":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
elif datasource_type == "online_document":
|
||||
name = datasource_info.get("page", {}).get("page_name", "untitled")
|
||||
elif datasource_type == "website_crawl":
|
||||
name = datasource_info.get("title", "untitled")
|
||||
elif datasource_type == "online_drive":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=position,
|
||||
data_source_type=datasource_type,
|
||||
data_source_info=json.dumps(datasource_info),
|
||||
batch=batch,
|
||||
name=name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
doc_form=document_form,
|
||||
)
|
||||
doc_metadata = {}
|
||||
if built_in_field_enabled:
|
||||
doc_metadata = {
|
||||
BuiltInField.document_name: name,
|
||||
BuiltInField.uploader: account.name,
|
||||
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.source: datasource_type,
|
||||
}
|
||||
if doc_metadata:
|
||||
document.doc_metadata = doc_metadata
|
||||
return document
|
||||
|
||||
def _format_datasource_info_list(
|
||||
self,
|
||||
datasource_type: str,
|
||||
datasource_info_list: list[Mapping[str, Any]],
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
start_node_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
) -> list[Mapping[str, Any]]:
|
||||
"""
|
||||
Format datasource info list.
|
||||
"""
|
||||
if datasource_type == "online_drive":
|
||||
all_files: list[Mapping[str, Any]] = []
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("id") == start_node_id:
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
break
|
||||
if not datasource_node_data:
|
||||
raise ValueError("Datasource node data not found")
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
|
||||
datasource_name=datasource_node_data.get("datasource_name"),
|
||||
tenant_id=pipeline.tenant_id,
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
provider=datasource_node_data.get("provider_name"),
|
||||
plugin_id=datasource_node_data.get("plugin_id"),
|
||||
credential_id=datasource_node_data.get("credential_id"),
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
|
||||
for datasource_info in datasource_info_list:
|
||||
if datasource_info.get("id") and datasource_info.get("type") == "folder":
|
||||
# get all files in the folder
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime,
|
||||
datasource_info.get("id", ""),
|
||||
datasource_info.get("bucket", None),
|
||||
user.id,
|
||||
all_files,
|
||||
datasource_info,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
all_files.append(
|
||||
{
|
||||
"id": datasource_info.get("id", ""),
|
||||
"name": datasource_info.get("name", "untitled"),
|
||||
"bucket": datasource_info.get("bucket", None),
|
||||
}
|
||||
)
|
||||
return all_files
|
||||
else:
|
||||
return datasource_info_list
|
||||
|
||||
def _get_files_in_folder(
|
||||
self,
|
||||
datasource_runtime: OnlineDriveDatasourcePlugin,
|
||||
prefix: str,
|
||||
bucket: str | None,
|
||||
user_id: str,
|
||||
all_files: list,
|
||||
datasource_info: Mapping[str, Any],
|
||||
next_page_parameters: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Get files in a folder.
|
||||
"""
|
||||
result_generator = datasource_runtime.online_drive_browse_files(
|
||||
user_id=user_id,
|
||||
request=OnlineDriveBrowseFilesRequest(
|
||||
bucket=bucket,
|
||||
prefix=prefix,
|
||||
max_keys=20,
|
||||
next_page_parameters=next_page_parameters,
|
||||
),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
is_truncated = False
|
||||
for result in result_generator:
|
||||
for files in result.result:
|
||||
for file in files.files:
|
||||
if file.type == "folder":
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime,
|
||||
file.id,
|
||||
bucket,
|
||||
user_id,
|
||||
all_files,
|
||||
datasource_info,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
all_files.append(
|
||||
{
|
||||
"id": file.id,
|
||||
"name": file.name,
|
||||
"bucket": bucket,
|
||||
}
|
||||
)
|
||||
is_truncated = files.is_truncated
|
||||
next_page_parameters = files.next_page_parameters
|
||||
|
||||
if is_truncated:
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
|
||||
)
|
||||
45
api/core/app/apps/pipeline/pipeline_queue_manager.py
Normal file
45
api/core/app/apps/pipeline/pipeline_queue_manager.py
Normal file
@ -0,0 +1,45 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
|
||||
|
||||
class PipelineQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(
|
||||
event,
|
||||
QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent
|
||||
| QueueWorkflowPartialSuccessEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedError()
|
||||
280
api/core/app/apps/pipeline/pipeline_runner.py
Normal file
280
api/core/app/apps/pipeline/pipeline_runner.py
Normal file
@ -0,0 +1,280 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
)
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document, Pipeline
|
||||
from models.enums import UserFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
Pipeline Application Runner
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
|
||||
def _get_app_id(self) -> str:
|
||||
return self.application_generate_entity.app_config.app_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(PipelineConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = SystemVariable(
|
||||
files=files,
|
||||
user_id=user_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
document_id=self.application_generate_entity.document_id,
|
||||
original_document_id=self.application_generate_entity.original_document_id,
|
||||
batch=self.application_generate_entity.batch,
|
||||
dataset_id=self.application_generate_entity.dataset_id,
|
||||
datasource_type=self.application_generate_entity.datasource_type,
|
||||
datasource_info=self.application_generate_entity.datasource_info,
|
||||
invoke_from=self.application_generate_entity.invoke_from.value,
|
||||
)
|
||||
|
||||
rag_pipeline_variables = []
|
||||
if workflow.rag_pipeline_variables:
|
||||
for v in workflow.rag_pipeline_variables:
|
||||
rag_pipeline_variable = RAGPipelineVariable(**v)
|
||||
if (
|
||||
rag_pipeline_variable.belong_to_node_id
|
||||
in (self.application_generate_entity.start_node_id, "shared")
|
||||
) and rag_pipeline_variable.variable in inputs:
|
||||
rag_pipeline_variables.append(
|
||||
RAGPipelineVariableInput(
|
||||
variable=rag_pipeline_variable,
|
||||
value=inputs[rag_pipeline_variable.variable],
|
||||
)
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=rag_pipeline_variables,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_rag_pipeline_graph(
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
start_node_id=self.application_generate_entity.start_node_id,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._update_document_status(
|
||||
event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
|
||||
)
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _init_rag_pipeline_graph(
|
||||
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
|
||||
) -> Graph:
|
||||
"""
|
||||
Init pipeline graph
|
||||
"""
|
||||
graph_config = workflow.graph_dict
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
# nodes = graph_config.get("nodes", [])
|
||||
# edges = graph_config.get("edges", [])
|
||||
# real_run_nodes = []
|
||||
# real_edges = []
|
||||
# exclude_node_ids = []
|
||||
# for node in nodes:
|
||||
# node_id = node.get("id")
|
||||
# node_type = node.get("data", {}).get("type", "")
|
||||
# if node_type == "datasource":
|
||||
# if start_node_id != node_id:
|
||||
# exclude_node_ids.append(node_id)
|
||||
# continue
|
||||
# real_run_nodes.append(node)
|
||||
|
||||
# for edge in edges:
|
||||
# if edge.get("source") in exclude_node_ids:
|
||||
# continue
|
||||
# real_edges.append(edge)
|
||||
# graph_config = dict(graph_config)
|
||||
# graph_config["nodes"] = real_run_nodes
|
||||
# graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
|
||||
"""
|
||||
Update document status
|
||||
"""
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
if document_id and dataset_id:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -53,7 +53,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@ -67,7 +68,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@ -81,7 +83,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
@ -94,7 +97,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@ -123,24 +127,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_id=app_model.id,
|
||||
user_id=user.id if isinstance(user, Account) else user.session_id,
|
||||
)
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
|
||||
# start node get inputs
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
inputs=inputs,
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@ -159,7 +165,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
if triggered_from is not None:
|
||||
# Use explicitly provided triggered_from (for async triggers)
|
||||
workflow_triggered_from = triggered_from
|
||||
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
@ -186,7 +195,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
@ -200,8 +209,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -214,7 +223,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param workflow_execution_repository: repository for workflow execution
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
@ -237,16 +245,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
"root_node_id": root_node_id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(
|
||||
invoke_from,
|
||||
)
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
@ -434,14 +440,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
):
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -474,10 +479,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,9 +32,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
root_node_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -41,9 +42,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._root_node_id = root_node_id
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
@ -52,24 +53,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -92,15 +99,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
root_node_id=self._root_node_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
@ -112,10 +132,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Union
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@ -349,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
@ -370,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
@ -617,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
@ -633,7 +599,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
def _dispatch_event(
|
||||
self,
|
||||
event: Any,
|
||||
event: AppQueueEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
tts_publisher: AppGeneratorTTSPublisher | None = None,
|
||||
@ -660,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
@ -674,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# Handle workflow failed and stop events with isinstance check
|
||||
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
|
||||
yield from self._handle_workflow_failed_and_stop_events(
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@ -13,14 +14,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
@ -28,42 +24,39 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeInLoopFailedEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
@ -79,7 +72,15 @@ class WorkflowBasedAppRunner:
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
@ -91,8 +92,28 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -104,6 +125,7 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
@ -140,13 +162,30 @@ class WorkflowBasedAppRunner:
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -201,6 +240,7 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
@ -237,13 +277,30 @@ class WorkflowBasedAppRunner:
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -310,39 +367,32 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error=event.error,
|
||||
execution_metadata=execution_metadata,
|
||||
retry_index=event.retry_index,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
@ -350,44 +400,29 @@ class WorkflowBasedAppRunner:
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
start_at=event.start_at,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
else:
|
||||
inputs = {}
|
||||
process_data = {}
|
||||
outputs = {}
|
||||
execution_metadata = {}
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
start_at=event.start_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
@ -396,34 +431,18 @@ class WorkflowBasedAppRunner:
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
@ -434,93 +453,21 @@ class WorkflowBasedAppRunner:
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInLoopFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInLoopFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
@ -533,10 +480,10 @@ class WorkflowBasedAppRunner:
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
elif isinstance(event, NodeRunAgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.id,
|
||||
id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
@ -547,51 +494,13 @@ class WorkflowBasedAppRunner:
|
||||
node_id=event.node_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
@ -599,55 +508,41 @@ class WorkflowBasedAppRunner:
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self._publish_event(
|
||||
QueueLoopStartEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
@ -655,42 +550,32 @@ class WorkflowBasedAppRunner:
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self._publish_event(
|
||||
QueueLoopNextEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_loop_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueLoopCompletedEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
@ -35,6 +38,7 @@ class InvokeFrom(StrEnum):
|
||||
# DEBUGGER indicates that this invocation is from
|
||||
# the workflow (or chatflow) edit page.
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED = "published"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
@ -113,8 +117,7 @@ class AppGenerateEntity(BaseModel):
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
# Using Any to avoid circular import with TraceQueueManager
|
||||
trace_manager: Any | None = None
|
||||
trace_manager: Optional["TraceQueueManager"] = None
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
@ -240,3 +243,34 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
inputs: dict
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
|
||||
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
"""
|
||||
RAG Pipeline Application Generate Entity.
|
||||
"""
|
||||
|
||||
# pipeline config
|
||||
pipeline_config: WorkflowUIBasedAppConfig
|
||||
datasource_type: str
|
||||
datasource_info: Mapping[str, Any]
|
||||
dataset_id: str
|
||||
batch: str
|
||||
document_id: str | None = None
|
||||
original_document_id: str | None = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
# Import TraceQueueManager at runtime to resolve forward references
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
# Rebuild models that use forward references
|
||||
AppGenerateEntity.model_rebuild()
|
||||
EasyUIBasedAppGenerateEntity.model_rebuild()
|
||||
ConversationAppGenerateEntity.model_rebuild()
|
||||
ChatAppGenerateEntity.model_rebuild()
|
||||
CompletionAppGenerateEntity.model_rebuild()
|
||||
AgentChatAppGenerateEntity.model_rebuild()
|
||||
AdvancedChatAppGenerateEntity.model_rebuild()
|
||||
WorkflowAppGenerateEntity.model_rebuild()
|
||||
RagPipelineGenerateEntity.model_rebuild()
|
||||
|
||||
@ -3,15 +3,13 @@ from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
|
||||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
MESSAGE_FILE = "message_file"
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
AGENT_LOG = "agent_log"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
@ -80,21 +75,13 @@ class QueueIterationStartEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueIterationNextEvent(AppQueueEvent):
|
||||
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_title: str
|
||||
node_run_index: int
|
||||
output: Any | None = None # output for the current iteration
|
||||
duration: float | None = None
|
||||
output: Any = None # output for the current iteration
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
@ -134,21 +110,13 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
error: str | None = None
|
||||
@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -175,9 +143,9 @@ class QueueLoopStartEvent(AppQueueEvent):
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueLoopNextEvent(AppQueueEvent):
|
||||
@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -203,8 +171,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Any | None = None # output for the current loop
|
||||
duration: float | None = None
|
||||
output: Any = None # output for the current loop
|
||||
|
||||
|
||||
class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -229,9 +196,9 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
error: str | None = None
|
||||
@ -332,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
||||
outputs: dict[str, Any] | None = None
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||
@ -352,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
||||
exceptions_count: int
|
||||
outputs: dict[str, Any] | None = None
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueNodeStartedEvent(AppQueueEvent):
|
||||
@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_run_index: int = 1
|
||||
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
||||
predecessor_node_id: str | None = None
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
provider_id: str
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -411,16 +374,12 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str | None = None
|
||||
"""single iteration duration map"""
|
||||
iteration_duration_map: dict[str, float] | None = None
|
||||
"""single loop duration map"""
|
||||
loop_duration_map: dict[str, float] | None = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
@ -436,7 +395,7 @@ class QueueAgentLogEvent(AppQueueEvent):
|
||||
error: str | None = None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
node_id: str
|
||||
|
||||
|
||||
@ -445,81 +404,15 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.RETRY
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
retry_index: int # retry index
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInLoopFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeExceptionEvent entity
|
||||
@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -545,9 +437,9 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
@ -563,24 +455,16 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
@ -610,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.ERROR
|
||||
error: Any | None = None
|
||||
error: Any = None
|
||||
|
||||
|
||||
class QueuePingEvent(AppQueueEvent):
|
||||
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunStartedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunSucceededEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
error: str
|
||||
|
||||
14
api/core/app/entities/rag_pipeline_invoke_entities.py
Normal file
14
api/core/app/entities/rag_pipeline_invoke_entities.py
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RagPipelineInvokeEntity(BaseModel):
|
||||
pipeline_id: str
|
||||
application_generate_entity: dict[str, Any]
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
workflow_id: str
|
||||
streaming: bool
|
||||
workflow_execution_id: str | None = None
|
||||
workflow_thread_pool_id: str | None = None
|
||||
@ -1,13 +1,13 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@ -55,32 +55,30 @@ class StreamEvent(StrEnum):
|
||||
Stream event
|
||||
"""
|
||||
|
||||
PING = auto()
|
||||
ERROR = auto()
|
||||
MESSAGE = auto()
|
||||
MESSAGE_END = auto()
|
||||
TTS_MESSAGE = auto()
|
||||
TTS_MESSAGE_END = auto()
|
||||
MESSAGE_FILE = auto()
|
||||
MESSAGE_REPLACE = auto()
|
||||
AGENT_THOUGHT = auto()
|
||||
AGENT_MESSAGE = auto()
|
||||
WORKFLOW_STARTED = auto()
|
||||
WORKFLOW_FINISHED = auto()
|
||||
NODE_STARTED = auto()
|
||||
NODE_FINISHED = auto()
|
||||
NODE_RETRY = auto()
|
||||
PARALLEL_BRANCH_STARTED = auto()
|
||||
PARALLEL_BRANCH_FINISHED = auto()
|
||||
ITERATION_STARTED = auto()
|
||||
ITERATION_NEXT = auto()
|
||||
ITERATION_COMPLETED = auto()
|
||||
LOOP_STARTED = auto()
|
||||
LOOP_NEXT = auto()
|
||||
LOOP_COMPLETED = auto()
|
||||
TEXT_CHUNK = auto()
|
||||
TEXT_REPLACE = auto()
|
||||
AGENT_LOG = auto()
|
||||
PING = "ping"
|
||||
ERROR = "error"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
TTS_MESSAGE = "tts_message"
|
||||
TTS_MESSAGE_END = "tts_message_end"
|
||||
MESSAGE_FILE = "message_file"
|
||||
MESSAGE_REPLACE = "message_replace"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
AGENT_MESSAGE = "agent_message"
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
NODE_RETRY = "node_retry"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
LOOP_STARTED = "loop_started"
|
||||
LOOP_NEXT = "loop_next"
|
||||
LOOP_COMPLETED = "loop_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
@ -138,7 +136,7 @@ class MessageEndStreamResponse(StreamResponse):
|
||||
|
||||
event: StreamEvent = StreamEvent.MESSAGE_END
|
||||
id: str
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
files: Sequence[Mapping[str, Any]] | None = None
|
||||
|
||||
|
||||
@ -175,7 +173,7 @@ class AgentThoughtStreamResponse(StreamResponse):
|
||||
thought: str | None = None
|
||||
observation: str | None = None
|
||||
tool: str | None = None
|
||||
tool_labels: dict | None = None
|
||||
tool_labels: Mapping[str, object] = Field(default_factory=dict)
|
||||
tool_input: str | None = None
|
||||
message_files: list[str] | None = None
|
||||
|
||||
@ -228,7 +226,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
created_by: dict | None = None
|
||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
finished_at: int
|
||||
exceptions_count: int | None = 0
|
||||
@ -256,8 +254,9 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
index: int
|
||||
predecessor_node_id: str | None = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parent_parallel_id: str | None = None
|
||||
@ -313,8 +312,11 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
index: int
|
||||
predecessor_node_id: str | None = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs_truncated: bool = False
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = True
|
||||
status: str
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
@ -382,8 +384,11 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
index: int
|
||||
predecessor_node_id: str | None = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs_truncated: bool = False
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = False
|
||||
status: str
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
@ -436,54 +441,6 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchFinishedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
status: str
|
||||
error: str | None = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
@ -502,8 +459,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||
extras: dict = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
inputs_truncated: bool = False
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||
workflow_run_id: str
|
||||
@ -526,12 +482,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_iteration_output: Any | None = None
|
||||
extras: dict = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parallel_mode_run_id: str | None = None
|
||||
duration: float | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
@ -553,18 +504,18 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
execution_metadata: Mapping | None = None
|
||||
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||
workflow_run_id: str
|
||||
@ -589,6 +540,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
||||
extras: dict = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
|
||||
@ -613,12 +565,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_loop_output: Any | None = None
|
||||
extras: dict = Field(default_factory=dict)
|
||||
pre_loop_output: Any = None
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parallel_mode_run_id: str | None = None
|
||||
duration: float | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||
workflow_run_id: str
|
||||
@ -640,14 +591,16 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
execution_metadata: Mapping | None = None
|
||||
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: str | None = None
|
||||
@ -757,7 +710,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
answer: str
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
|
||||
data: Data
|
||||
@ -777,7 +730,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
|
||||
mode: str
|
||||
message_id: str
|
||||
answer: str
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
|
||||
data: Data
|
||||
@ -825,7 +778,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
error: str | None = None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
node_id: str
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
|
||||
@ -138,6 +138,8 @@ class MessageCycleManager:
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
if not self._application_generate_entity.app_config.additional_features:
|
||||
raise ValueError("Additional features not found")
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata.retriever_resources = event.retriever_resources
|
||||
|
||||
|
||||
@ -109,7 +109,9 @@ class AppGeneratorTTSPublisher:
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
if message.event.outputs is None:
|
||||
continue
|
||||
self.msg_text += message.event.outputs.get("output", "")
|
||||
output = message.event.outputs.get("output", "")
|
||||
if isinstance(output, str):
|
||||
self.msg_text += output
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.max_sentence, 7):
|
||||
@ -119,7 +121,7 @@ class AppGeneratorTTSPublisher:
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
if isinstance(text_tmp, str):
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user