mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 03:59:30 +08:00
Compare commits
218 Commits
fix/versio
...
e-0154
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f63cd52a2 | |||
| 5b357fdbf0 | |||
| 9283a5414f | |||
| 8923e64b8d | |||
| 2a2a0e9be9 | |||
| 061a765b7d | |||
| acd7fead87 | |||
| 64e9d96d84 | |||
| d27de3818c | |||
| bbb080d5b2 | |||
| 8c025abb3b | |||
| c01d8a70f3 | |||
| 98606ca558 | |||
| adf3e18ebd | |||
| 1ca15989e0 | |||
| 8b5a3a9424 | |||
| 42ddcf1edd | |||
| 21561df10f | |||
| 4327ec8c4c | |||
| bbc5ec8301 | |||
| 4a51a72c1d | |||
| 4b6adffa8e | |||
| c7fd73d330 | |||
| 8a709e445a | |||
| f02b77b99f | |||
| abc625bcce | |||
| b6bc1f8bc4 | |||
| b8f9037cd3 | |||
| 02606ba3c7 | |||
| 79311d3fb5 | |||
| 31086a1fbf | |||
| 6ae5d052e5 | |||
| c794ecf101 | |||
| d887aae012 | |||
| 1b1e96eff7 | |||
| eecd091063 | |||
| d38f2cb380 | |||
| 56aaee5558 | |||
| d72b4752c9 | |||
| ea769c6483 | |||
| ec194fa3d4 | |||
| b877039859 | |||
| 54634f26d2 | |||
| 3bef91a2cd | |||
| 7da45ba589 | |||
| e0232c67cc | |||
| 1dc4a229d4 | |||
| 0e0bada1f3 | |||
| 5366a814f9 | |||
| f1240a22db | |||
| 66f35c2b7e | |||
| 766ee48531 | |||
| 083045f45c | |||
| fe237802c9 | |||
| 00b923651f | |||
| 24fce3cc64 | |||
| 8ba969f67d | |||
| 6844d59371 | |||
| fe5529db85 | |||
| d89034d913 | |||
| 360fbeb108 | |||
| e7c2fa1cfa | |||
| 735f09d977 | |||
| f83a5e3e49 | |||
| 01a8d4efcc | |||
| fdb1e649d4 | |||
| 0856792a57 | |||
| 0e33a3aa5f | |||
| d3895bcd6b | |||
| eeb390650b | |||
| ca19bd31d4 | |||
| 413dfd5628 | |||
| f9515901cc | |||
| 3f42fabff8 | |||
| 1caa578771 | |||
| b7c11c1818 | |||
| 3eb3db0663 | |||
| be46f32056 | |||
| 6e5c915f96 | |||
| 04d13a8116 | |||
| e638ede3f2 | |||
| 2348abe4bf | |||
| f7e7a399d9 | |||
| ba91f34636 | |||
| 16865d43a8 | |||
| 0d13aee15c | |||
| 49b4144ffd | |||
| 186e2d972e | |||
| 40dd63ecef | |||
| 6d66d6da15 | |||
| 03ec3513f3 | |||
| 87763fc234 | |||
| f6c44cae2e | |||
| da2ee04fce | |||
| 7673c36af3 | |||
| 9457b2af2f | |||
| 7203991032 | |||
| 5a685f7156 | |||
| a6a25030ad | |||
| 00458a31d5 | |||
| c6ddf6d6cc | |||
| 34b21b3065 | |||
| 8fbb355cd2 | |||
| e8b3b7e578 | |||
| 59ca44f493 | |||
| 9e1457c2c3 | |||
| fac83e14bc | |||
| a97cec57e4 | |||
| 38c10b47d3 | |||
| 1a2523fd15 | |||
| 03243cb422 | |||
| 2ad7ee0344 | |||
| 55ce3618ce | |||
| e9e34c1ab2 | |||
| d4c916b496 | |||
| 8fbc9c9342 | |||
| 1b6fd9dfe8 | |||
| 304467e3f5 | |||
| 7452032d81 | |||
| 87e2048f1b | |||
| d876084392 | |||
| 840729afa5 | |||
| 941ad03f3c | |||
| d73d191f99 | |||
| c2664e0283 | |||
| ee61cede4e | |||
| b47669b80b | |||
| c0d0c63592 | |||
| b09c39c8dc | |||
| b4b09ddc3c | |||
| d0a21086bd | |||
| d44882c1b5 | |||
| 23c68efa2d | |||
| 560c5de1b7 | |||
| 5d91dbd000 | |||
| 6c31ee36cd | |||
| edc29780ed | |||
| aad7e4dd1c | |||
| a6a727e8a4 | |||
| d1fc65fabc | |||
| d4be5ef9de | |||
| 1374be5a31 | |||
| b2bbc28580 | |||
| 59b3e672aa | |||
| a2f8bce8f5 | |||
| a2b9adb3a2 | |||
| 28067640b5 | |||
| da67916843 | |||
| e54ce479ad | |||
| 6024d8a42d | |||
| f565f08aa0 | |||
| fd4afe09f8 | |||
| dd0904f95c | |||
| 4c3076f2a4 | |||
| 1e73f63ff8 | |||
| d167d5b1be | |||
| 71fa14f791 | |||
| 8dd1873e76 | |||
| f91f5c7401 | |||
| c62b7cc679 | |||
| 3ee213ddca | |||
| 8429877b02 | |||
| 05a0faff6a | |||
| e09f6e4987 | |||
| e23f4b0265 | |||
| f582d4a13e | |||
| 2f41bd495d | |||
| 162a8c4393 | |||
| 3d1ce4c53f | |||
| 6db3ae9b8e | |||
| 6d0cb9dc33 | |||
| 46e95e8309 | |||
| a7b9375877 | |||
| 0c6a8a130e | |||
| 9903f1e703 | |||
| 6fad719e42 | |||
| 9aaee8ee47 | |||
| 166221d784 | |||
| 925d69a2ee | |||
| 5ff08e241a | |||
| 3defd24087 | |||
| 9d86147d20 | |||
| 80801ac4ab | |||
| 210926cd91 | |||
| 677a69deed | |||
| 8dfdee21ce | |||
| 6ea77ab4cd | |||
| e3c996688d | |||
| bc3a570dda | |||
| 0800021a2d | |||
| 435eddd867 | |||
| 6e0fb055d1 | |||
| 1e9ac7ffeb | |||
| b4873ecb43 | |||
| 1859d57784 | |||
| 69d58fbb50 | |||
| cb34991663 | |||
| c700364e1c | |||
| 9a6b1dc3a1 | |||
| 54b5b80a07 | |||
| 831459b895 | |||
| 4e101604c3 | |||
| a6455269f0 | |||
| cd257b91c5 | |||
| d8f57bf899 | |||
| 989fb11fd7 | |||
| 140965b738 | |||
| 14ee51aead | |||
| 2e97ba5700 | |||
| f549d53b68 | |||
| a085ad4719 | |||
| f230a9232e | |||
| e84bf35e2a | |||
| 20f090537f | |||
| dbe7a7c4fd | |||
| b7a4e3903e | |||
| b4c1c2f731 | |||
| 1b940e7daa |
@ -1,12 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
npm add -g pnpm@9.12.2
|
||||
cd web && pnpm install
|
||||
cd web && npm install
|
||||
pipx install poetry
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
|
||||
|
||||
|
||||
7
.github/workflows/api-tests.yml
vendored
7
.github/workflows/api-tests.yml
vendored
@ -4,7 +4,6 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/**
|
||||
- docker/**
|
||||
@ -48,9 +47,15 @@ jobs:
|
||||
- name: Run Unit tests
|
||||
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run ModelRuntime
|
||||
run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
|
||||
|
||||
- name: Run dify config tests
|
||||
run: poetry run -P api python dev/pytest/pytest_config_tests.py
|
||||
|
||||
- name: Run Tool
|
||||
run: poetry run -P api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run mypy
|
||||
run: |
|
||||
poetry run -C api python -m mypy --install-types --non-interactive .
|
||||
|
||||
3
.github/workflows/build-push.yml
vendored
3
.github/workflows/build-push.yml
vendored
@ -5,7 +5,8 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "plugins/beta"
|
||||
- "deploy/enterprise"
|
||||
- "e-0154"
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
|
||||
1
.github/workflows/db-migration-test.yml
vendored
1
.github/workflows/db-migration-test.yml
vendored
@ -4,7 +4,6 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/migrations/**
|
||||
- .github/workflows/db-migration-test.yml
|
||||
|
||||
29
.github/workflows/deploy-enterprise.yml
vendored
Normal file
29
.github/workflows/deploy-enterprise.yml
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
name: Deploy Enterprise
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/enterprise"
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/enterprise'
|
||||
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.ENTERPRISE_SSH_HOST }}
|
||||
username: ${{ secrets.ENTERPRISE_SSH_USER }}
|
||||
password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
|
||||
script: |
|
||||
${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}
|
||||
13
.github/workflows/style.yml
vendored
13
.github/workflows/style.yml
vendored
@ -4,7 +4,6 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
|
||||
concurrency:
|
||||
group: style-${{ github.head_ref || github.run_id }}
|
||||
@ -67,23 +66,17 @@ jobs:
|
||||
with:
|
||||
files: web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 20
|
||||
cache: pnpm
|
||||
cache: yarn
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: yarn install --frozen-lockfile
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@ -141,7 +134,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
DEFAULT_BRANCH: plugins/beta
|
||||
DEFAULT_BRANCH: main
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
IGNORE_GENERATED_FILES: true
|
||||
IGNORE_GITIGNORED_FILES: true
|
||||
|
||||
6
.github/workflows/tool-test-sdks.yaml
vendored
6
.github/workflows/tool-test-sdks.yaml
vendored
@ -32,10 +32,10 @@ jobs:
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: ''
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
cache-dependency-path: 'yarn.lock'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: pnpm install
|
||||
run: yarn install
|
||||
|
||||
- name: Test
|
||||
run: pnpm test
|
||||
run: yarn test
|
||||
|
||||
@ -38,11 +38,11 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: yarn install --frozen-lockfile
|
||||
|
||||
- name: Run npm script
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: pnpm run auto-gen-i18n
|
||||
run: npm run auto-gen-i18n
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -34,13 +34,13 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 20
|
||||
cache: pnpm
|
||||
cache: yarn
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: yarn install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm test
|
||||
run: yarn test
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -175,7 +175,6 @@ docker/volumes/pgvector/data/*
|
||||
docker/volumes/pgvecto_rs/data/*
|
||||
docker/volumes/couchbase/*
|
||||
docker/volumes/oceanbase/*
|
||||
docker/volumes/plugin_daemon/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
@ -194,9 +193,3 @@ api/.vscode
|
||||
|
||||
.idea/
|
||||
.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
.env
|
||||
*.env.*
|
||||
|
||||
storage/generate_files/*
|
||||
storage/privkeys/*
|
||||
storage/tools/*
|
||||
storage/upload_files/*
|
||||
|
||||
# Logs
|
||||
logs
|
||||
@ -12,8 +9,6 @@ logs
|
||||
|
||||
# jetbrains
|
||||
.idea
|
||||
.mypy_cache
|
||||
.ruff_cache
|
||||
|
||||
# venv
|
||||
.venv
|
||||
@ -409,6 +409,7 @@ MAX_VARIABLE_SIZE=204800
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
|
||||
@ -421,22 +422,6 @@ POSITION_PROVIDER_PINS=
|
||||
POSITION_PROVIDER_INCLUDES=
|
||||
POSITION_PROVIDER_EXCLUDES=
|
||||
|
||||
# Plugin configuration
|
||||
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
|
||||
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
|
||||
PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ RUN \
|
||||
# basic environment
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
# expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
fonts-noto-cjk \
|
||||
# install libmagic to support the use of python-magic guess MIMETYPE
|
||||
@ -71,10 +71,6 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
# Download nltk data
|
||||
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
|
||||
|
||||
# Copy source code
|
||||
COPY . /app/api/
|
||||
|
||||
|
||||
@ -25,8 +25,6 @@ from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.provider import Provider, ProviderModel
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@ -526,7 +524,7 @@ def add_qdrant_doc_id_index(field: str):
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
|
||||
|
||||
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
|
||||
@ -595,7 +593,7 @@ def upgrade_db():
|
||||
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.exception("Failed to execute database migration")
|
||||
finally:
|
||||
lock.release()
|
||||
@ -641,7 +639,7 @@ where sites.id is null limit 1000"""
|
||||
account = accounts[0]
|
||||
print("Fixing missing site for app {}".format(app.id))
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
failed_app_ids.append(app_id)
|
||||
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
|
||||
logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
|
||||
@ -651,68 +649,3 @@ where sites.id is null limit 1000"""
|
||||
break
|
||||
|
||||
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
||||
|
||||
|
||||
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
|
||||
def migrate_data_for_plugin():
|
||||
"""
|
||||
Migrate data for plugin.
|
||||
"""
|
||||
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
|
||||
|
||||
PluginDataMigration.migrate()
|
||||
|
||||
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("extract-plugins", help="Extract plugins.")
|
||||
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
|
||||
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
|
||||
def extract_plugins(output_file: str, workers: int):
|
||||
"""
|
||||
Extract plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting extract plugins.", fg="white"))
|
||||
|
||||
PluginMigration.extract_plugins(output_file, workers)
|
||||
|
||||
click.echo(click.style("Extract plugins completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
|
||||
@click.option(
|
||||
"--output_file",
|
||||
prompt=True,
|
||||
help="The file to store the extracted unique identifiers.",
|
||||
default="unique_identifiers.json",
|
||||
)
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
def extract_unique_plugins(output_file: str, input_file: str):
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting extract unique plugins.", fg="white"))
|
||||
|
||||
PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
|
||||
|
||||
click.echo(click.style("Extract unique plugins completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-plugins", help="Install plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
def install_plugins(input_file: str, output_file: str):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting install plugins.", fg="white"))
|
||||
|
||||
PluginMigration.install_plugins(input_file, output_file)
|
||||
|
||||
click.echo(click.style("Install plugins completed.", fg="green"))
|
||||
|
||||
@ -134,60 +134,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
"""
|
||||
|
||||
PLUGIN_DAEMON_URL: HttpUrl = Field(
|
||||
description="Plugin API URL",
|
||||
default="http://localhost:5002",
|
||||
)
|
||||
|
||||
PLUGIN_DAEMON_KEY: str = Field(
|
||||
description="Plugin API key",
|
||||
default="plugin-api-key",
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||
description="Plugin Remote Install Host",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
|
||||
description="Plugin Remote Install Port",
|
||||
default=5003,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for plugin packages in bytes",
|
||||
default=15728640,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for plugin bundles in bytes",
|
||||
default=15728640 * 12,
|
||||
)
|
||||
|
||||
|
||||
class MarketplaceConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for marketplace
|
||||
"""
|
||||
|
||||
MARKETPLACE_ENABLED: bool = Field(
|
||||
description="Enable or disable marketplace",
|
||||
default=True,
|
||||
)
|
||||
|
||||
MARKETPLACE_API_URL: HttpUrl = Field(
|
||||
description="Marketplace API URL",
|
||||
default="https://marketplace.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -214,10 +160,6 @@ class EndpointConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
ENDPOINT_URL_TEMPLATE: str = Field(
|
||||
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
|
||||
)
|
||||
|
||||
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
@ -556,11 +498,6 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field(
|
||||
description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.",
|
||||
default=86400,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -851,8 +788,6 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="1.0.0",
|
||||
default="0.15.4",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -1,19 +1,9 @@
|
||||
from contextvars import ContextVar
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
|
||||
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
|
||||
|
||||
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
|
||||
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
|
||||
|
||||
@ -2,7 +2,7 @@ from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
|
||||
from .app.app_import import AppImportApi, AppImportConfirmApi
|
||||
from .explore.audio import ChatAudioApi, ChatTextApi
|
||||
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from .explore.conversation import (
|
||||
@ -40,7 +40,6 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
# Import App
|
||||
api.add_resource(AppImportApi, "/apps/imports")
|
||||
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
|
||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
@ -167,15 +166,4 @@ api.add_resource(
|
||||
from .tag import tags
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import (
|
||||
account,
|
||||
agent_providers,
|
||||
endpoint,
|
||||
load_balancing_config,
|
||||
members,
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
workspace,
|
||||
)
|
||||
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace
|
||||
|
||||
@ -2,8 +2,6 @@ from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -56,8 +54,7 @@ class InsertExploreAppListApi(Resource):
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
|
||||
app = App.query.filter(App.id == args["app_id"]).first()
|
||||
if not app:
|
||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
||||
|
||||
@ -73,10 +70,7 @@ class InsertExploreAppListApi(Resource):
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||
|
||||
if not recommended_app:
|
||||
recommended_app = RecommendedApp(
|
||||
@ -116,27 +110,17 @@ class InsertExploreAppApi(Resource):
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
app = App.query.filter(App.id == recommended_app.app_id).first()
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
with Session(db.engine) as session:
|
||||
installed_apps = session.execute(
|
||||
select(InstalledApp).filter(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
).all()
|
||||
installed_apps = InstalledApp.query.filter(
|
||||
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
|
||||
).all()
|
||||
|
||||
for installed_app in installed_apps:
|
||||
db.session.delete(installed_app)
|
||||
|
||||
@ -3,8 +3,6 @@ from typing import Any
|
||||
import flask_restful # type: ignore
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -28,16 +26,7 @@ api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="it
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
if resource_model == App:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
|
||||
|
||||
if resource is None:
|
||||
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
|
||||
|
||||
@ -5,16 +5,14 @@ from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from fields.app_fields import app_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
|
||||
|
||||
@ -90,20 +88,3 @@ class AppImportConfirmApi(Resource):
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class AppImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
def get(self, app_model: App):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -2,7 +2,6 @@ from datetime import UTC, datetime
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
@ -51,37 +50,33 @@ class AppSite(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
site = session.query(Site).filter(Site.app_id == app_model.id).first()
|
||||
site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
for attr_name in [
|
||||
"title",
|
||||
"icon_type",
|
||||
"icon",
|
||||
"icon_background",
|
||||
"description",
|
||||
"default_language",
|
||||
"chat_color_theme",
|
||||
"chat_color_theme_inverted",
|
||||
"customize_domain",
|
||||
"copyright",
|
||||
"privacy_policy",
|
||||
"custom_disclaimer",
|
||||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
for attr_name in [
|
||||
"title",
|
||||
"icon_type",
|
||||
"icon",
|
||||
"icon_background",
|
||||
"description",
|
||||
"default_language",
|
||||
"chat_color_theme",
|
||||
"chat_color_theme_inverted",
|
||||
"customize_domain",
|
||||
"copyright",
|
||||
"privacy_policy",
|
||||
"custom_disclaimer",
|
||||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
session.commit()
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
@ -97,9 +96,6 @@ class DraftWorkflowApi(Resource):
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
@ -143,9 +139,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
||||
@ -167,7 +160,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -185,9 +178,6 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
@ -204,7 +194,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -222,9 +212,6 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
@ -241,7 +228,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -259,9 +246,6 @@ class DraftWorkflowRunApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
@ -310,20 +294,13 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
@ -362,9 +339,6 @@ class PublishedWorkflowApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
|
||||
|
||||
@ -402,17 +376,12 @@ class DefaultBlockConfigApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
filters = None
|
||||
if q:
|
||||
if args.get("q"):
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
except json.JSONDecodeError:
|
||||
@ -438,9 +407,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
|
||||
@ -59,9 +59,3 @@ class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_code_account_deletion_rate_limit_exceeded"
|
||||
description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
|
||||
code = 429
|
||||
|
||||
|
||||
class EmailPasswordResetLimitError(BaseHTTPException):
|
||||
error_code = "email_password_reset_limit"
|
||||
description = "Too many failed password reset attempts. Please try again in 24 hours."
|
||||
code = 429
|
||||
|
||||
@ -3,18 +3,10 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError
|
||||
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
@ -45,8 +37,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = Account.query.filter_by(email=args["email"]).first()
|
||||
token = None
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
@ -71,10 +62,6 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
user_email = args["email"]
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
@ -83,10 +70,8 @@ class ForgotPasswordCheckApi(Resource):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
return {"is_valid": True, "email": token_data.get("email")}
|
||||
|
||||
|
||||
@ -119,8 +104,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
|
||||
account = Account.query.filter_by(email=reset_data.get("email")).first()
|
||||
if account:
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
@ -141,7 +125,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
pass
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -5,8 +5,6 @@ from typing import Optional
|
||||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restful import Resource # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -137,8 +135,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Optional[Account] = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
account = Account.query.filter_by(email=user_info.email).first()
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@ -4,8 +4,6 @@ import json
|
||||
from flask import request
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@ -78,10 +76,7 @@ class DataSourceApi(Resource):
|
||||
def patch(self, binding_id, action):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
).scalar_one_or_none()
|
||||
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
# enable binding
|
||||
@ -113,53 +108,47 @@ class DataSourceNotionListApi(Resource):
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if dataset.data_source_type != "notion_import":
|
||||
raise ValueError("Dataset is not notion type.")
|
||||
|
||||
documents = session.execute(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = session.scalars(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
)
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if dataset.data_source_type != "notion_import":
|
||||
raise ValueError("Dataset is not notion type.")
|
||||
documents = Document.query.filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
@ -169,17 +158,14 @@ class DataSourceNotionApi(Resource):
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
|
||||
@ -14,7 +14,6 @@ from controllers.console.wraps import account_initialization_required, enterpris
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
@ -73,9 +72,7 @@ class DatasetListApi(Resource):
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality":
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
@ -623,7 +620,6 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
|
||||
@ -7,6 +7,7 @@ from flask import request
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy import asc, desc
|
||||
from transformers.hf_argparser import string_to_bool # type: ignore
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -39,7 +40,6 @@ from core.indexing_runner import IndexingRunner
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -150,20 +150,8 @@ class DatasetDocumentListApi(Resource):
|
||||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
if isinstance(fetch_val, bool):
|
||||
fetch = fetch_val
|
||||
else:
|
||||
if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
|
||||
fetch = True
|
||||
elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
|
||||
fetch = False
|
||||
else:
|
||||
raise ArgumentTypeError(
|
||||
f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
|
||||
f"(case insensitive)."
|
||||
)
|
||||
except (ArgumentTypeError, ValueError, Exception):
|
||||
fetch = string_to_bool(request.args.get("fetch", default="false"))
|
||||
except (ArgumentTypeError, ValueError, Exception) as e:
|
||||
fetch = False
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -441,8 +429,6 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except PluginDaemonClientSideError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except Exception as e:
|
||||
raise IndexingEstimateError(str(e))
|
||||
|
||||
@ -543,8 +529,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except PluginDaemonClientSideError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except Exception as e:
|
||||
raise IndexingEstimateError(str(e))
|
||||
|
||||
|
||||
@ -2,11 +2,8 @@ import os
|
||||
|
||||
from flask import session
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
@ -45,11 +42,7 @@ class InitValidateAPI(Resource):
|
||||
def get_init_validate_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
if session.get("is_init_validated"):
|
||||
return True
|
||||
|
||||
with Session(db.engine) as db_session:
|
||||
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
||||
return session.get("is_init_validated") or DifySetup.query.first()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
|
||||
from configs import dify_config
|
||||
from libs.helper import StrLen, email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup, db
|
||||
from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import api
|
||||
@ -52,9 +52,8 @@ class SetupApi(Resource):
|
||||
|
||||
def get_setup_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return db.session.query(DifySetup).first()
|
||||
else:
|
||||
return True
|
||||
return DifySetup.query.first()
|
||||
return True
|
||||
|
||||
|
||||
api.add_resource(SetupApi, "/setup")
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission)
|
||||
.filter(
|
||||
TenantPluginPermission.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not permission:
|
||||
# no permission set, allow access for everyone
|
||||
return view(*args, **kwargs)
|
||||
|
||||
if install_required:
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
if debug_required:
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return interceptor
|
||||
|
||||
@ -1,36 +0,0 @@
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource # type: ignore
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
class AgentProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
|
||||
|
||||
class AgentProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
user = current_user
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
||||
|
||||
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
|
||||
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")
|
||||
@ -1,205 +0,0 @@
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
class EndpointCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True)
|
||||
parser.add_argument("settings", type=dict, required=True)
|
||||
parser.add_argument("name", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_unique_identifier = args["plugin_unique_identifier"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class EndpointListForSinglePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
plugin_id = args["plugin_id"]
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class EndpointDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
parser.add_argument("settings", type=dict, required=True)
|
||||
parser.add_argument("name", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointDisableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
|
||||
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
|
||||
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
|
||||
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
|
||||
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
|
||||
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
|
||||
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")
|
||||
@ -112,10 +112,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
# Load Balancing Config
|
||||
api.add_resource(
|
||||
LoadBalancingCredentialsValidateApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
|
||||
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
LoadBalancingConfigCredentialsValidateApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
|
||||
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
|
||||
)
|
||||
|
||||
@ -79,7 +79,7 @@ class ModelProviderValidateApi(Resource):
|
||||
response = {"result": "success" if result else "error"}
|
||||
|
||||
if not result:
|
||||
response["error"] = error or "Unknown error"
|
||||
response["error"] = error
|
||||
|
||||
return response
|
||||
|
||||
@ -125,10 +125,9 @@ class ModelProviderIconApi(Resource):
|
||||
Get model provider icon
|
||||
"""
|
||||
|
||||
def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
|
||||
def get(self, provider: str, icon_type: str, lang: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
icon, mimetype = model_provider_service.get_model_provider_icon(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
icon_type=icon_type,
|
||||
lang=lang,
|
||||
@ -184,17 +183,53 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
return data
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=False, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_qualification_verify(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
|
||||
api.add_resource(
|
||||
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
|
||||
)
|
||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
|
||||
api.add_resource(
|
||||
ModelProviderIconApi,
|
||||
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
|
||||
ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderFreeQuotaQualificationVerifyApi,
|
||||
"/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
|
||||
)
|
||||
|
||||
@ -325,7 +325,7 @@ class ModelProviderModelValidateApi(Resource):
|
||||
response = {"result": "success" if result else "error"}
|
||||
|
||||
if not result:
|
||||
response["error"] = error or ""
|
||||
response["error"] = error
|
||||
|
||||
return response
|
||||
|
||||
@ -362,26 +362,26 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
|
||||
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
|
||||
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
|
||||
api.add_resource(
|
||||
ModelProviderModelEnableApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable",
|
||||
"/workspaces/current/model-providers/<string:provider>/models/enable",
|
||||
endpoint="model-provider-model-enable",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelDisableApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/disable",
|
||||
"/workspaces/current/model-providers/<string:provider>/models/disable",
|
||||
endpoint="model-provider-model-disable",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
|
||||
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
|
||||
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
|
||||
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
|
||||
)
|
||||
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
|
||||
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")
|
||||
|
||||
@ -1,475 +0,0 @@
|
||||
import io
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from libs.login import login_required
|
||||
from models.account import TenantPluginPermission
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {
|
||||
"key": PluginService.get_debugging_key(tenant_id),
|
||||
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
|
||||
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
|
||||
}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
try:
|
||||
plugins = PluginService.list(tenant_id)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
class PluginIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
req.add_argument("filename", type=str, required=True, location="args")
|
||||
args = req.parse_args()
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
class PluginUploadFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
file = request.files["pkg"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_pkg(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginUploadFromBundleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
file = request.files["bundle"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_bundle(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_github(
|
||||
tenant_id,
|
||||
args["plugin_unique_identifier"],
|
||||
args["repo"],
|
||||
args["version"],
|
||||
args["package"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"manifest": PluginService.fetch_plugin_manifest(
|
||||
tenant_id, args["plugin_unique_identifier"]
|
||||
).model_dump()
|
||||
}
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginFetchInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginDeleteInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self, task_id: str, identifier: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
PluginService.upgrade_plugin_with_marketplace(
|
||||
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginUpgradeFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
PluginService.upgrade_plugin_with_github(
|
||||
tenant_id,
|
||||
args["original_plugin_unique_identifier"],
|
||||
args["new_plugin_unique_identifier"],
|
||||
args["repo"],
|
||||
args["version"],
|
||||
args["package"],
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginUninstallApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("install_permission", type=str, required=True, location="json")
|
||||
req.add_argument("debug_permission", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||
|
||||
|
||||
class PluginFetchPermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
if not permission:
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
}
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"install_permission": permission.install_permission,
|
||||
"debug_permission": permission.debug_permission,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
|
||||
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
|
||||
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
|
||||
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
|
||||
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
|
||||
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
|
||||
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
|
||||
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
|
||||
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
|
||||
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
|
||||
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
|
||||
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
|
||||
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
|
||||
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
|
||||
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
|
||||
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
|
||||
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
|
||||
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
||||
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
|
||||
|
||||
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
||||
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
||||
@ -25,10 +25,8 @@ class ToolProviderListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument(
|
||||
@ -49,43 +47,28 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
|
||||
|
||||
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||
user_id,
|
||||
@ -99,13 +82,11 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
@ -150,13 +131,11 @@ class ToolApiProviderAddApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
@ -189,11 +168,6 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||
@ -201,8 +175,8 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
user_id,
|
||||
tenant_id,
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
args["url"],
|
||||
)
|
||||
|
||||
@ -212,10 +186,8 @@ class ToolApiProviderListToolsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -237,13 +209,11 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
@ -278,13 +248,11 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -304,10 +272,8 @@ class ToolApiProviderGetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -327,11 +293,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@ -382,13 +344,11 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
@ -421,13 +381,11 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
@ -463,13 +421,11 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
@ -488,10 +444,8 @@ class ToolWorkflowProviderGetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
@ -522,10 +476,8 @@ class ToolWorkflowProviderListToolApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
||||
@ -546,10 +498,8 @@ class ToolBuiltinListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
@ -567,10 +517,8 @@ class ToolApiListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
@ -588,10 +536,8 @@ class ToolWorkflowListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
@ -617,18 +563,16 @@ class ToolLabelsApi(Resource):
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
||||
# builtin tool provider
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
|
||||
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
|
||||
)
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
|
||||
|
||||
# api tool provider
|
||||
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
|
||||
|
||||
@ -7,7 +7,6 @@ from flask_login import current_user # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
from services.operation_service import OperationService
|
||||
@ -40,6 +39,17 @@ def only_edition_cloud(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_enterprise_edition(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
@ -135,13 +145,9 @@ def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
and os.environ.get("INIT_PASSWORD")
|
||||
and not db.session.query(DifySetup).first()
|
||||
):
|
||||
if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first():
|
||||
raise NotInitValidateError()
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first():
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -6,4 +6,4 @@ bp = Blueprint("files", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import image_preview, tool_files, upload
|
||||
from . import image_preview, tool_files
|
||||
|
||||
@ -1,69 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restful import Resource, marshal_with # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.files import api
|
||||
from controllers.files.error import UnsupportedFileTypeError
|
||||
from controllers.inner_api.plugin.wraps import get_user
|
||||
from controllers.service_api.app.error import FileTooLargeError
|
||||
from core.file.helpers import verify_plugin_file_signature
|
||||
from fields.file_fields import file_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class PluginUploadFileApi(Resource):
|
||||
@setup_required
|
||||
@marshal_with(file_fields)
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
sign = request.args.get("sign")
|
||||
tenant_id = request.args.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
user_id = request.args.get("user_id")
|
||||
user = get_user(tenant_id, user_id)
|
||||
|
||||
filename = file.filename
|
||||
mimetype = file.mimetype
|
||||
|
||||
if not filename or not mimetype:
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
if not timestamp or not nonce or not sign:
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
if not verify_plugin_file_signature(
|
||||
filename=filename,
|
||||
mimetype=mimetype,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=file.read(),
|
||||
mimetype=mimetype,
|
||||
user=user,
|
||||
source=None,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
|
||||
@ -5,5 +5,4 @@ from libs.external_api import ExternalApi
|
||||
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from .plugin import plugin
|
||||
from .workspace import workspace
|
||||
|
||||
@ -1,293 +0,0 @@
|
||||
from flask_restful import Resource # type: ignore
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from core.file.helpers import get_signed_file_url_for_plugin
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
|
||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeApp,
|
||||
RequestInvokeEncrypt,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeParameterExtractorNode,
|
||||
RequestInvokeQuestionClassifierNode,
|
||||
RequestInvokeRerank,
|
||||
RequestInvokeSpeech2Text,
|
||||
RequestInvokeSummary,
|
||||
RequestInvokeTextEmbedding,
|
||||
RequestInvokeTool,
|
||||
RequestInvokeTTS,
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from libs.helper import compact_generate_response
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
class PluginInvokeLLMApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_text_embedding(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeRerankApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeRerank)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_rerank(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeTTSApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTTS)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_tts(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeSpeech2TextApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_speech2text(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeModerationApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeModeration)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_moderation(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeToolApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
PluginToolBackwardsInvocation.invoke_tool(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
tool_type=ToolProviderType.value_of(payload.tool_type),
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
),
|
||||
)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
parameters=payload.parameters,
|
||||
model_config=payload.model,
|
||||
instruction=payload.instruction,
|
||||
query=payload.query,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginNodeBackwardsInvocation.invoke_question_classifier(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
query=payload.query,
|
||||
model_config=payload.model,
|
||||
classes=payload.classes,
|
||||
instruction=payload.instruction,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeAppApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id=payload.app_id,
|
||||
user_id=user_model.id,
|
||||
tenant_id=tenant_model.id,
|
||||
conversation_id=payload.conversation_id,
|
||||
query=payload.query,
|
||||
stream=payload.response_mode == "streaming",
|
||||
inputs=payload.inputs,
|
||||
files=payload.files,
|
||||
)
|
||||
|
||||
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
|
||||
"""
|
||||
encrypt or decrypt data
|
||||
"""
|
||||
try:
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
|
||||
).model_dump()
|
||||
except Exception as e:
|
||||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
class PluginInvokeSummaryApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSummary)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
|
||||
try:
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data={
|
||||
"summary": PluginModelBackwardsInvocation.invoke_summary(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
}
|
||||
).model_dump()
|
||||
except Exception as e:
|
||||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
class PluginUploadFileRequestApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||
# generate signed url
|
||||
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
|
||||
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
|
||||
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
|
||||
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
|
||||
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
|
||||
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
|
||||
api.add_resource(PluginInvokeAppApi, "/invoke/app")
|
||||
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
|
||||
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
|
||||
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
|
||||
@ -1,116 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from flask_restful import reqparse # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
if user_id == "DEFAULT-USER":
|
||||
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
else:
|
||||
user_model = AccountService.load_user(user_id)
|
||||
if not user_model:
|
||||
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user_model:
|
||||
raise ValueError("user not found")
|
||||
except Exception:
|
||||
raise ValueError("user not found")
|
||||
|
||||
return user_model
|
||||
|
||||
|
||||
def get_user_tenant(view: Optional[Callable] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
# fetch json body
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||
parser.add_argument("user_id", type=str, required=True, location="json")
|
||||
|
||||
kwargs = parser.parse_args()
|
||||
|
||||
user_id = kwargs.get("user_id")
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
del kwargs["tenant_id"]
|
||||
del kwargs["user_id"]
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.filter(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func):
|
||||
def decorated_view(*args, **kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
raise ValueError("invalid json")
|
||||
|
||||
try:
|
||||
payload = payload_type(**data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid payload: {str(e)}")
|
||||
|
||||
kwargs["payload"] = payload
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from controllers.inner_api.wraps import inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
from models.account import Account
|
||||
from services.account_service import TenantService
|
||||
@ -12,7 +12,7 @@ from services.account_service import TenantService
|
||||
|
||||
class EnterpriseWorkspace(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
@ -33,7 +33,7 @@ class EnterpriseWorkspace(Resource):
|
||||
|
||||
class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
|
||||
@ -10,7 +10,7 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view):
|
||||
def inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
@ -18,7 +18,7 @@ def enterprise_inner_api_only(view):
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||
abort(401)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
@ -26,7 +26,7 @@ def enterprise_inner_api_only(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth(view):
|
||||
def inner_api_user_auth(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
@ -60,19 +60,3 @@ def enterprise_inner_api_user_auth(view):
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
@ -31,16 +32,19 @@ from core.model_runtime.entities import (
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -58,9 +62,11 @@ class BaseAgentRunner(AppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
model_instance: ModelInstance,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
@ -73,6 +79,8 @@ class BaseAgentRunner(AppRunner):
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
@ -133,10 +141,11 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
assert tool_entity.entity.description
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.entity.description.llm,
|
||||
description=tool_entity.description.llm if tool_entity.description else "",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -144,7 +153,7 @@ class BaseAgentRunner(AppRunner):
|
||||
},
|
||||
)
|
||||
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
@ -177,11 +186,9 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
"""
|
||||
assert tool.entity.description
|
||||
|
||||
prompt_tool = PromptMessageTool(
|
||||
name=tool.entity.identity.name,
|
||||
description=tool.entity.description.llm,
|
||||
name=tool.identity.name if tool.identity else "unknown",
|
||||
description=tool.description.llm if tool.description else "",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -227,7 +234,8 @@ class BaseAgentRunner(AppRunner):
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
if dataset_tool.identity is not None:
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
@ -312,24 +320,24 @@ class BaseAgentRunner(AppRunner):
|
||||
def save_agent_thought(
|
||||
self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str | None,
|
||||
tool_input: Union[str, dict, None],
|
||||
thought: str | None,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict, None],
|
||||
tool_invoke_meta: Union[str, dict, None],
|
||||
answer: str | None,
|
||||
answer: str,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage | None = None,
|
||||
):
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
updated_agent_thought = (
|
||||
queried_thought = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
)
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
agent_thought = updated_agent_thought
|
||||
if not queried_thought:
|
||||
raise ValueError(f"Agent thought {agent_thought.id} not found")
|
||||
agent_thought = queried_thought
|
||||
|
||||
if thought:
|
||||
agent_thought.thought = thought
|
||||
@ -341,39 +349,39 @@ class BaseAgentRunner(AppRunner):
|
||||
if isinstance(tool_input, dict):
|
||||
try:
|
||||
tool_input = json.dumps(tool_input, ensure_ascii=False)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
tool_input = json.dumps(tool_input)
|
||||
|
||||
updated_agent_thought.tool_input = tool_input
|
||||
agent_thought.tool_input = tool_input
|
||||
|
||||
if observation:
|
||||
if isinstance(observation, dict):
|
||||
try:
|
||||
observation = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
updated_agent_thought.observation = observation
|
||||
agent_thought.observation = observation
|
||||
|
||||
if answer:
|
||||
agent_thought.answer = answer
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
updated_agent_thought.message_files = json.dumps(messages_ids)
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
if llm_usage:
|
||||
updated_agent_thought.message_token = llm_usage.prompt_tokens
|
||||
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
updated_agent_thought.answer_token = llm_usage.completion_tokens
|
||||
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
updated_agent_thought.tokens = llm_usage.total_tokens
|
||||
updated_agent_thought.total_price = llm_usage.total_price
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
agent_thought.answer_token = llm_usage.completion_tokens
|
||||
agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = updated_agent_thought.tool_labels or {}
|
||||
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(";") if agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
@ -384,20 +392,42 @@ class BaseAgentRunner(AppRunner):
|
||||
else:
|
||||
labels[tool] = {"en_US": tool, "zh_Hans": tool}
|
||||
|
||||
updated_agent_thought.tool_labels_str = json.dumps(labels)
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
if tool_invoke_meta is not None:
|
||||
if isinstance(tool_invoke_meta, dict):
|
||||
try:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta)
|
||||
|
||||
updated_agent_thought.tool_meta_str = tool_invoke_meta
|
||||
agent_thought.tool_meta_str = tool_invoke_meta
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
queried_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not queried_variables:
|
||||
return
|
||||
|
||||
db_variables = queried_variables
|
||||
|
||||
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize agent history
|
||||
@ -434,11 +464,11 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
tool_responses = dict.fromkeys(tools, agent_thought.observation)
|
||||
|
||||
for tool in tools:
|
||||
@ -485,11 +515,7 @@ class BaseAgentRunner(AppRunner):
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
if not files:
|
||||
return UserPromptMessage(content=message.query)
|
||||
if message.app_model_config:
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
else:
|
||||
file_extra_config = None
|
||||
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
if not file_extra_config:
|
||||
return UserPromptMessage(content=message.query)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
@ -27,11 +27,11 @@ from models.model import Message
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
_historic_prompt_messages: list[PromptMessage] | None = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
|
||||
_instruction: str = "" # FIXME this must be str for now
|
||||
_query: str | None = None
|
||||
_prompt_messages_tools: list[PromptMessageTool] = []
|
||||
|
||||
def run(
|
||||
self,
|
||||
@ -42,7 +42,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
@ -55,19 +54,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template or ""
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
@ -119,7 +116,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {}
|
||||
if not isinstance(chunks, Generator):
|
||||
raise ValueError("Expected streaming response from LLM")
|
||||
|
||||
# check llm result
|
||||
if not chunks:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
@ -139,25 +143,25 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += chunk
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
if self._agent_scratchpad is not None:
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
@ -252,6 +256,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
)
|
||||
if self.variables_pool is not None and self.db_variables_pool is not None:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -269,7 +275,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: Mapping[str, Tool],
|
||||
tool_instances: dict[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
@ -309,7 +315,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
)
|
||||
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as is not None and self.variables_pool:
|
||||
# FIXME the save_as type is confusing, it should be a string or not
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
@ -332,7 +342,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
return instruction
|
||||
@ -369,7 +379,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
self, current_session_messages: Optional[list[PromptMessage]] = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
@ -381,7 +391,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
if not isinstance(message.content, str | None):
|
||||
raise NotImplementedError("expected str type")
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
@ -400,8 +411,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
except:
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
if not current_scratchpad:
|
||||
continue
|
||||
if isinstance(message.content, str):
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
|
||||
@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
if not self.app_config.agent:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if not prompt_entity:
|
||||
@ -83,10 +83,8 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
|
||||
@ -1,21 +1,18 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
|
||||
provider_type: ToolProviderType
|
||||
provider_type: Literal["builtin", "api", "workflow"]
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
plugin_unique_identifier: str | None = None
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
@ -69,7 +66,7 @@ class AgentEntity(BaseModel):
|
||||
Agent Entity.
|
||||
"""
|
||||
|
||||
class Strategy(StrEnum):
|
||||
class Strategy(Enum):
|
||||
"""
|
||||
Agent Strategy.
|
||||
"""
|
||||
@ -81,13 +78,5 @@ class AgentEntity(BaseModel):
|
||||
model: str
|
||||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: Optional[list[AgentToolEntity]] = None
|
||||
tools: list[AgentToolEntity] | None = None
|
||||
max_iteration: int = 5
|
||||
|
||||
|
||||
class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
Agent Invoke Message.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@ -46,20 +46,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@ -109,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if isinstance(chunks, Generator):
|
||||
if self.stream_tool_call and isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
@ -126,7 +124,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
@ -142,7 +140,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
@ -153,7 +151,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
@ -185,6 +183,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
usage=result.usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
@ -243,12 +243,15 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=self.message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
if self.variables_pool:
|
||||
self.variables_pool.set_file(
|
||||
tool_name=tool_call_name, value=message_file_id, name=save_as
|
||||
)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
@ -300,6 +303,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
if self.variables_pool and self.db_variables_pool:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -330,7 +335,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
def extract_tool_calls(
|
||||
self, llm_result_chunk: LLMResultChunk
|
||||
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@ -353,7 +360,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
@ -376,7 +383,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
def _init_system_message(
|
||||
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
|
||||
@ -1,89 +0,0 @@
|
||||
import enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolProviderIdentity,
|
||||
)
|
||||
|
||||
|
||||
class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||
"""
|
||||
Inherits from ToolProviderIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyParameter(PluginParameter):
|
||||
class AgentStrategyParameterType(enum.StrEnum):
|
||||
"""
|
||||
Keep all the types from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class AgentStrategyProviderEntity(BaseModel):
|
||||
identity: AgentStrategyProviderIdentity
|
||||
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
||||
|
||||
|
||||
class AgentStrategyIdentity(ToolIdentity):
|
||||
"""
|
||||
Inherits from ToolIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyEntity(BaseModel):
|
||||
identity: AgentStrategyIdentity
|
||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
|
||||
strategies: list[AgentStrategyEntity] = Field(default_factory=list)
|
||||
@ -1,42 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
|
||||
|
||||
class BaseAgentStrategy(ABC):
|
||||
"""
|
||||
Agent Strategy
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
"""
|
||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
"""
|
||||
Get the parameters for the agent strategy.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
pass
|
||||
@ -1,59 +0,0 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
from core.plugin.manager.agent import PluginAgentManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class PluginAgentStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
Agent Strategy
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
declaration: AgentStrategyEntity
|
||||
|
||||
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
|
||||
self.tenant_id = tenant_id
|
||||
self.declaration = declaration
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Initialize the parameters for the agent strategy.
|
||||
"""
|
||||
for parameter in self.declaration.parameters:
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
|
||||
initialized_params = self.initialize_parameters(params)
|
||||
params = convert_parameters_to_plugin_format(initialized_params)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
agent_provider=self.declaration.identity.provider,
|
||||
agent_strategy=self.declaration.identity.name,
|
||||
agent_params=params,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
@ -4,8 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
@ -64,14 +63,14 @@ class ModelConfigConverter:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.mode
|
||||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT.value
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
|
||||
|
||||
model_mode = mode_enum.value
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
@ -2,9 +2,8 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
@ -54,18 +53,9 @@ class ModelConfigManager:
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if "provider" not in config["model"]:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
if "/" not in config["model"]["provider"]:
|
||||
config["model"]["provider"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
||||
)
|
||||
|
||||
if config["model"]["provider"] not in model_provider_names:
|
||||
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
|
||||
@ -37,6 +37,17 @@ logger = logging.getLogger(__name__)
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
_dialogue_count: int
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@ -54,31 +65,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -156,8 +156,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
@ -169,14 +167,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
def single_iteration_generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -213,8 +205,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
@ -234,7 +224,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ def _process_future(
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
|
||||
def __init__(self, tenant_id: str, voice: str):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ""
|
||||
@ -67,7 +67,7 @@ class AppGeneratorTTSPublisher:
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices(language=language)
|
||||
self.voices = self.model_instance.get_tts_voices()
|
||||
values = [voice.get("value") for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
|
||||
@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
@ -57,7 +58,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
) -> Generator[str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -83,12 +84,12 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
) -> Generator[str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -122,4 +123,4 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@ -17,7 +17,6 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
@ -220,9 +219,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
and features_dict["text_to_speech"].get("enabled")
|
||||
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
|
||||
):
|
||||
tts_publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
|
||||
)
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
@ -250,7 +247,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to listen audio message, task_id: {task_id}")
|
||||
break
|
||||
if tts_publisher:
|
||||
@ -643,10 +640,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
session.commit()
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
@ -30,6 +29,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@ -41,17 +51,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@ -61,7 +60,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -71,7 +70,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -183,7 +182,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": contextvars.copy_context(),
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
@ -208,7 +206,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
@ -223,9 +220,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param message_id: message ID
|
||||
:return:
|
||||
"""
|
||||
for var, val in context.items():
|
||||
var.set(val)
|
||||
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# get conversation and message
|
||||
|
||||
@ -8,16 +8,18 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationError
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
from models.model import App, Conversation, Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -62,8 +64,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
)
|
||||
|
||||
@ -84,8 +86,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
@ -97,8 +99,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=dict(inputs),
|
||||
query=query or "",
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationError as e:
|
||||
@ -154,9 +156,9 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
@ -171,7 +173,16 @@ class AgentChatAppRunner(AppRunner):
|
||||
return
|
||||
|
||||
agent_entity = app_config.agent
|
||||
assert agent_entity is not None
|
||||
if not agent_entity:
|
||||
raise ValueError("Agent entity not found")
|
||||
|
||||
# load tool variables
|
||||
tool_conversation_variables = self._load_tool_variables(
|
||||
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
|
||||
)
|
||||
|
||||
# convert db variables to tool variables
|
||||
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
|
||||
|
||||
# init model instance
|
||||
model_instance = ModelInstance(
|
||||
@ -182,9 +193,9 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
@ -232,6 +243,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
@ -248,3 +261,73 @@ class AgentChatAppRunner(AppRunner):
|
||||
stream=application_generate_entity.stream,
|
||||
agent=True,
|
||||
)
|
||||
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables | None = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if tool_variables:
|
||||
# save tool variables to session, so that we can update it later
|
||||
db.session.add(tool_variables)
|
||||
else:
|
||||
# create new tool variables
|
||||
tool_variables = ToolConversationVariables(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
variables_str="[]",
|
||||
)
|
||||
db.session.add(tool_variables)
|
||||
db.session.commit()
|
||||
|
||||
return tool_variables
|
||||
|
||||
def _convert_db_variables_to_tool_variables(
|
||||
self, db_variables: ToolConversationVariables
|
||||
) -> ToolRuntimeVariablePool:
|
||||
"""
|
||||
convert db variables to tool variables
|
||||
"""
|
||||
return ToolRuntimeVariablePool(
|
||||
**{
|
||||
"conversation_id": db_variables.conversation_id,
|
||||
"user_id": db_variables.user_id,
|
||||
"tenant_id": db_variables.tenant_id,
|
||||
"pool": db_variables.variables,
|
||||
}
|
||||
)
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, message: Message
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
Get usage of all agent thoughts
|
||||
:param model_config: model config
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
agent_thoughts = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
|
||||
)
|
||||
|
||||
all_message_tokens = 0
|
||||
all_answer_tokens = 0
|
||||
for agent_thought in agent_thoughts:
|
||||
all_message_tokens += agent_thought.message_tokens
|
||||
all_answer_tokens += agent_thought.answer_tokens
|
||||
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
return model_type_instance._calc_response_usage(
|
||||
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
|
||||
)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -51,9 +51,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
def convert_stream_full_response( # type: ignore[override]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -79,12 +80,13 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
def convert_stream_simple_response( # type: ignore[override]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -116,4 +118,4 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@ -14,15 +14,21 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
cls,
|
||||
response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
|
||||
invoke_from: InvokeFrom,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
|
||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_full_response(response)
|
||||
def _generate_full_response() -> Generator[str, Any, None]:
|
||||
for chunk in cls.convert_stream_full_response(response):
|
||||
if chunk == "ping":
|
||||
yield f"event: {chunk}\n\n"
|
||||
else:
|
||||
yield f"data: {chunk}\n\n"
|
||||
|
||||
return _generate_full_response()
|
||||
else:
|
||||
@ -30,8 +36,12 @@ class AppGenerateResponseConverter(ABC):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
|
||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_simple_response(response)
|
||||
def _generate_simple_response() -> Generator[str, Any, None]:
|
||||
for chunk in cls.convert_stream_simple_response(response):
|
||||
if chunk == "ping":
|
||||
yield f"event: {chunk}\n\n"
|
||||
else:
|
||||
yield f"data: {chunk}\n\n"
|
||||
|
||||
return _generate_simple_response()
|
||||
|
||||
@ -49,14 +59,14 @@ class AppGenerateResponseConverter(ABC):
|
||||
@abstractmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.file import File, FileUploadConfig
|
||||
@ -139,21 +138,3 @@ class BaseAppGenerator:
|
||||
if isinstance(value, str):
|
||||
return value.replace("\x00", "")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
|
||||
"""
|
||||
Convert messages into event stream
|
||||
"""
|
||||
if isinstance(generator, dict):
|
||||
return generator
|
||||
else:
|
||||
|
||||
def gen():
|
||||
for message in generator:
|
||||
if isinstance(message, (Mapping, dict)):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
else:
|
||||
yield f"event: {message}\n\n"
|
||||
|
||||
return gen()
|
||||
|
||||
@ -2,7 +2,7 @@ import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
@ -115,7 +115,7 @@ class AppQueueManager:
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -58,7 +58,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -67,7 +67,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -52,8 +52,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -79,12 +80,13 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -116,4 +118,4 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str | Mapping[str, Any], None, None]: ...
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -56,8 +56,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -51,8 +51,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls,
|
||||
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -77,12 +78,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls,
|
||||
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -113,4 +115,4 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@ -36,13 +36,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -50,12 +50,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@ -64,26 +64,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
):
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
@ -124,10 +124,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
@ -149,18 +146,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@ -199,10 +185,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
user: Account,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -238,8 +224,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
@ -36,8 +36,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls,
|
||||
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -61,12 +62,13 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls,
|
||||
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -92,4 +94,4 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@ -13,7 +13,6 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
@ -191,9 +190,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
and features_dict["text_to_speech"].get("enabled")
|
||||
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
|
||||
):
|
||||
tts_publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
|
||||
)
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
@ -530,10 +527,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@ -28,7 +27,6 @@ from core.app.entities.queue_entities import (
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
@ -241,7 +239,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
@ -376,19 +373,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
|
||||
@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
@ -41,7 +41,6 @@ class QueueEvent(StrEnum):
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
AGENT_LOG = "agent_log"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
@ -281,7 +280,6 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
@ -317,22 +315,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAgentLogEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.AGENT_LOG
|
||||
id: str
|
||||
label: str
|
||||
node_execution_id: str
|
||||
parent_id: str | None
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
"""QueueNodeRetryEvent entity"""
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -61,7 +60,6 @@ class StreamEvent(Enum):
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
@ -249,7 +247,6 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
parallel_run_id: Optional[str] = None
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
@ -699,26 +696,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
AgentLogStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_execution_id: str
|
||||
id: str
|
||||
label: str
|
||||
parent_id: str | None
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
data: Data
|
||||
|
||||
@ -24,8 +24,6 @@ class HostingModerationFeature:
|
||||
if isinstance(prompt_message.content, str):
|
||||
text += prompt_message.content + "\n"
|
||||
|
||||
moderation_result = moderation.check_moderation(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
|
||||
)
|
||||
moderation_result = moderation.check_moderation(model_config, text)
|
||||
|
||||
return moderation_result
|
||||
|
||||
@ -215,9 +215,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
and text_to_speech_dict.get("autoPlay") == "enabled"
|
||||
and text_to_speech_dict.get("enabled")
|
||||
):
|
||||
publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
|
||||
)
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listen_audio_msg(publisher, task_id)
|
||||
|
||||
@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@ -25,7 +24,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
@ -322,8 +320,9 @@ class WorkflowCycleManage:
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata_dict = dict(event.execution_metadata or {})
|
||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
@ -541,7 +540,6 @@ class WorkflowCycleManage:
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
),
|
||||
)
|
||||
|
||||
@ -845,24 +843,3 @@ class WorkflowCycleManage:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return cached_workflow_node_execution
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
Handle agent log
|
||||
:param task_id: task id
|
||||
:param event: agent log event
|
||||
:return:
|
||||
"""
|
||||
return AgentLogStreamResponse(
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=event.node_execution_id,
|
||||
id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
label=event.label,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
),
|
||||
)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from collections.abc import Iterable, Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, TextIO, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -57,7 +57,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Mapping[str, Any],
|
||||
tool_outputs: Iterable[ToolInvokeMessage] | str,
|
||||
tool_outputs: Sequence[ToolInvokeMessage] | str,
|
||||
message_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
|
||||
@ -1,26 +1,5 @@
|
||||
from collections.abc import Generator, Iterable, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
|
||||
|
||||
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def on_tool_execution(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Mapping[str, Any],
|
||||
tool_outputs: Iterable[ToolInvokeMessage],
|
||||
message_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
for tool_output in tool_outputs:
|
||||
print_text("\n[on_tool_execution]\n", color=self.color)
|
||||
print_text("Tool: " + tool_name + "\n", color=self.color)
|
||||
print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color)
|
||||
print_text("\n")
|
||||
yield tool_output
|
||||
|
||||
@ -1 +0,0 @@
|
||||
DEFAULT_PLUGIN_ID = "langgenius"
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CommonParameterType(StrEnum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
SYSTEM_FILES = "system-files"
|
||||
BOOLEAN = "boolean"
|
||||
APP_SELECTOR = "app-selector"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
|
||||
|
||||
class AppSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CHAT = "chat"
|
||||
WORKFLOW = "workflow"
|
||||
COMPLETION = "completion"
|
||||
|
||||
|
||||
class ModelSelectorScope(StrEnum):
|
||||
LLM = "llm"
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
TTS = "tts"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
VISION = "vision"
|
||||
|
||||
|
||||
class ToolSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CUSTOM = "custom"
|
||||
BUILTIN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
@ -2,14 +2,13 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@ -19,15 +18,16 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from extensions.ext_database import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
@ -99,10 +99,9 @@ class ProviderConfiguration(BaseModel):
|
||||
continue
|
||||
|
||||
restrict_models = quota_configuration.restrict_models
|
||||
|
||||
copy_credentials = (
|
||||
self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
|
||||
)
|
||||
if self.system_configuration.credentials is None:
|
||||
return None
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_models:
|
||||
for restrict_model in restrict_models:
|
||||
if (
|
||||
@ -141,9 +140,6 @@ class ProviderConfiguration(BaseModel):
|
||||
if current_quota_configuration is None:
|
||||
return None
|
||||
|
||||
if not current_quota_configuration:
|
||||
return SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
return (
|
||||
SystemConfigurationStatus.ACTIVE
|
||||
if current_quota_configuration.is_valid
|
||||
@ -157,7 +153,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
|
||||
def get_custom_credentials(self, obfuscated: bool = False):
|
||||
"""
|
||||
Get custom credentials.
|
||||
|
||||
@ -179,7 +175,7 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@ -223,7 +219,6 @@ class ProviderConfiguration(BaseModel):
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
@ -251,13 +246,13 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_record = Provider()
|
||||
provider_record.tenant_id = self.tenant_id
|
||||
provider_record.provider_name = self.provider.provider
|
||||
provider_record.provider_type = ProviderType.CUSTOM.value
|
||||
provider_record.encrypted_config = json.dumps(credentials)
|
||||
provider_record.is_valid = True
|
||||
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(provider_record)
|
||||
db.session.commit()
|
||||
|
||||
@ -332,7 +327,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def custom_model_credentials_validate(
|
||||
self, model_type: ModelType, model: str, credentials: dict
|
||||
) -> tuple[ProviderModel | None, dict]:
|
||||
) -> tuple[Optional[ProviderModel], dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
@ -375,7 +370,6 @@ class ProviderConfiguration(BaseModel):
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
@ -406,13 +400,14 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_model_record = ProviderModel()
|
||||
provider_model_record.tenant_id = self.tenant_id
|
||||
provider_model_record.provider_name = self.provider.provider
|
||||
provider_model_record.model_name = model
|
||||
provider_model_record.model_type = model_type.to_origin_model_type()
|
||||
provider_model_record.encrypted_config = json.dumps(credentials)
|
||||
provider_model_record.is_valid = True
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(provider_model_record)
|
||||
db.session.commit()
|
||||
|
||||
@ -479,12 +474,13 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.enabled = True
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
@ -513,12 +509,13 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.enabled = False
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=False,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
@ -579,12 +576,13 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.load_balancing_enabled = True
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
@ -613,17 +611,25 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.load_balancing_enabled = False
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
return model_setting
|
||||
|
||||
def get_provider_instance(self) -> ModelProvider:
|
||||
"""
|
||||
Get provider instance.
|
||||
:return:
|
||||
"""
|
||||
return model_provider_factory.get_provider_instance(self.provider.provider)
|
||||
|
||||
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
|
||||
"""
|
||||
Get current model type instance.
|
||||
@ -631,19 +637,11 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
# Get provider instance
|
||||
provider_instance = self.get_provider_instance()
|
||||
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
|
||||
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
return model_provider_factory.get_model_schema(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
return provider_instance.get_model_instance(model_type)
|
||||
|
||||
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
|
||||
"""
|
||||
@ -670,10 +668,11 @@ class ProviderConfiguration(BaseModel):
|
||||
if preferred_model_provider:
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
else:
|
||||
preferred_model_provider = TenantPreferredModelProvider()
|
||||
preferred_model_provider.tenant_id = self.tenant_id
|
||||
preferred_model_provider.provider_name = self.provider.provider
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
preferred_model_provider = TenantPreferredModelProvider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
preferred_provider_type=provider_type.value,
|
||||
)
|
||||
db.session.add(preferred_model_provider)
|
||||
|
||||
db.session.commit()
|
||||
@ -738,14 +737,13 @@ class ProviderConfiguration(BaseModel):
|
||||
:param only_active: only active models
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
|
||||
provider_instance = self.get_provider_instance()
|
||||
|
||||
model_types: list[ModelType] = []
|
||||
model_types = []
|
||||
if model_type:
|
||||
model_types.append(model_type)
|
||||
else:
|
||||
model_types = list(provider_schema.supported_model_types)
|
||||
model_types = list(provider_instance.get_provider_schema().supported_model_types)
|
||||
|
||||
# Group model settings by model type and model
|
||||
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
|
||||
@ -754,11 +752,11 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
provider_models = self._get_system_provider_models(
|
||||
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
|
||||
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
||||
)
|
||||
else:
|
||||
provider_models = self._get_custom_provider_models(
|
||||
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
|
||||
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
||||
)
|
||||
|
||||
if only_active:
|
||||
@ -769,26 +767,23 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_system_provider_models(
|
||||
self,
|
||||
model_types: Sequence[ModelType],
|
||||
provider_schema: ProviderEntity,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get system provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_schema: provider schema
|
||||
:param provider_instance: provider instance
|
||||
:param model_setting_map: model setting map
|
||||
:return:
|
||||
"""
|
||||
provider_models = []
|
||||
for model_type in model_types:
|
||||
for m in provider_schema.models:
|
||||
if m.model_type != model_type:
|
||||
continue
|
||||
|
||||
for m in provider_instance.models(model_type):
|
||||
status = ModelStatus.ACTIVE
|
||||
if m.model in model_setting_map:
|
||||
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
||||
model_setting = model_setting_map[m.model_type][m.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
@ -809,7 +804,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in provider_schema.configurate_methods:
|
||||
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
should_use_custom_model = False
|
||||
@ -830,22 +825,18 @@ class ProviderConfiguration(BaseModel):
|
||||
]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
copy_credentials = (
|
||||
self.system_configuration.credentials.copy()
|
||||
if self.system_configuration.credentials
|
||||
else {}
|
||||
)
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
if self.system_configuration.credentials is not None:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
|
||||
try:
|
||||
custom_model_schema = self.get_model_schema(
|
||||
model_type=restrict_model.model_type,
|
||||
model=restrict_model.model,
|
||||
credentials=copy_credentials,
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
try:
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
restrict_model.model_type
|
||||
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
continue
|
||||
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
@ -890,15 +881,15 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_custom_provider_models(
|
||||
self,
|
||||
model_types: Sequence[ModelType],
|
||||
provider_schema: ProviderEntity,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get custom provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_schema: provider schema
|
||||
:param provider_instance: provider instance
|
||||
:param model_setting_map: model setting map
|
||||
:return:
|
||||
"""
|
||||
@ -912,10 +903,8 @@ class ProviderConfiguration(BaseModel):
|
||||
if model_type not in self.provider.supported_model_types:
|
||||
continue
|
||||
|
||||
for m in provider_schema.models:
|
||||
if m.model_type != model_type:
|
||||
continue
|
||||
|
||||
models = provider_instance.models(model_type)
|
||||
for m in models:
|
||||
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
||||
load_balancing_enabled = False
|
||||
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
||||
@ -947,10 +936,10 @@ class ProviderConfiguration(BaseModel):
|
||||
continue
|
||||
|
||||
try:
|
||||
custom_model_schema = self.get_model_schema(
|
||||
model_type=model_configuration.model_type,
|
||||
model=model_configuration.model,
|
||||
credentials=model_configuration.credentials,
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
model_configuration.model_type
|
||||
).get_customizable_model_schema_from_credentials(
|
||||
model_configuration.model, model_configuration.credentials
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
@ -978,7 +967,7 @@ class ProviderConfiguration(BaseModel):
|
||||
label=custom_model_schema.label,
|
||||
model_type=custom_model_schema.model_type,
|
||||
features=custom_model_schema.features,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
fetch_from=custom_model_schema.fetch_from,
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
@ -1051,9 +1040,6 @@ class ProviderConfigurations(BaseModel):
|
||||
return list(self.values())
|
||||
|
||||
def __getitem__(self, key):
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
@ -1065,11 +1051,8 @@ class ProviderConfigurations(BaseModel):
|
||||
def values(self) -> Iterator[ProviderConfiguration]:
|
||||
return iter(self.configurations.values())
|
||||
|
||||
def get(self, key, default=None) -> ProviderConfiguration | None:
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations.get(key, default) # type: ignore
|
||||
def get(self, key, default=None):
|
||||
return self.configurations.get(key, default)
|
||||
|
||||
|
||||
class ProviderModelBundle(BaseModel):
|
||||
@ -1078,6 +1061,7 @@ class ProviderModelBundle(BaseModel):
|
||||
"""
|
||||
|
||||
configuration: ProviderConfiguration
|
||||
provider_instance: ModelProvider
|
||||
model_type_instance: AIModel
|
||||
|
||||
# pydantic configs
|
||||
|
||||
@ -1,34 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.entities.parameter_entities import (
|
||||
AppSelectorScope,
|
||||
CommonParameterType,
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
PAID = "paid"
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
@ -132,55 +108,3 @@ class ModelSettings(BaseModel):
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class BasicProviderConfig(BaseModel):
|
||||
"""
|
||||
Base model class for common provider settings like credentials
|
||||
"""
|
||||
|
||||
class Type(Enum):
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
type: Type = Field(..., description="The type of the credentials")
|
||||
name: str = Field(..., description="The name of the credentials")
|
||||
|
||||
|
||||
class ProviderConfig(BasicProviderConfig):
|
||||
"""
|
||||
Model class for common provider settings like credentials
|
||||
"""
|
||||
|
||||
class Option(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
url: Optional[str] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
||||
def to_basic_provider_config(self) -> BasicProviderConfig:
|
||||
return BasicProviderConfig(type=self.type, name=self.name)
|
||||
|
||||
@ -20,41 +20,6 @@ def get_signed_file_url(upload_file_id: str) -> str:
|
||||
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
|
||||
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
|
||||
url = f"{dify_config.FILES_URL}/files/upload/for-plugin"
|
||||
|
||||
if user_id is None:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
key = dify_config.SECRET_KEY.encode()
|
||||
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"
|
||||
|
||||
|
||||
def verify_plugin_file_signature(
|
||||
*, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str
|
||||
) -> bool:
|
||||
if user_id is None:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
|
||||
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -124,17 +124,6 @@ class File(BaseModel):
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"mime_type": self.mime_type,
|
||||
"filename": self.filename,
|
||||
"extension": self.extension,
|
||||
"size": self.size,
|
||||
"type": self.type,
|
||||
"url": self.generate_url(),
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_after(self):
|
||||
match self.transfer_method:
|
||||
|
||||
@ -1,69 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.url_signer import UrlSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.exception(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file: UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
result = UrlSigner.verify(
|
||||
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
|
||||
)
|
||||
|
||||
# verify signature
|
||||
if not result:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -1,17 +0,0 @@
|
||||
from core.helper import ssrf_proxy
|
||||
|
||||
|
||||
def download_with_size_limit(url, max_download_size: int, **kwargs):
|
||||
response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
|
||||
if response.status_code == 404:
|
||||
raise ValueError("file not found")
|
||||
|
||||
total_size = 0
|
||||
chunks = []
|
||||
for chunk in response.iter_bytes():
|
||||
total_size += len(chunk)
|
||||
if total_size > max_download_size:
|
||||
raise ValueError("Max file size reached")
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
return content
|
||||
@ -1,35 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
import requests
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str):
|
||||
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
|
||||
unique_identifier=plugin_unique_identifier
|
||||
)
|
||||
|
||||
|
||||
def download_plugin_pkg(plugin_unique_identifier: str):
|
||||
url = str(get_plugin_pkg_url(plugin_unique_identifier))
|
||||
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
|
||||
if len(plugin_ids) == 0:
|
||||
return []
|
||||
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
|
||||
response = requests.post(url, json={"plugin_ids": plugin_ids})
|
||||
response.raise_for_status()
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
|
||||
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
||||
@ -1,35 +1,28 @@
|
||||
import logging
|
||||
import random
|
||||
from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
|
||||
def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
|
||||
moderation_config = hosting_configuration.moderation_config
|
||||
openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
|
||||
if (
|
||||
moderation_config
|
||||
and moderation_config.enabled is True
|
||||
and openai_provider_name in hosting_configuration.provider_map
|
||||
and hosting_configuration.provider_map[openai_provider_name].enabled is True
|
||||
and "openai" in hosting_configuration.provider_map
|
||||
and hosting_configuration.provider_map["openai"].enabled is True
|
||||
):
|
||||
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
|
||||
provider_name = model_config.provider
|
||||
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
|
||||
hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]
|
||||
|
||||
if hosting_openai_config.credentials is None:
|
||||
return False
|
||||
hosting_openai_config = hosting_configuration.provider_map["openai"]
|
||||
assert hosting_openai_config is not None
|
||||
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
@ -41,20 +34,15 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
|
||||
text_chunk = random.choice(text_chunks)
|
||||
|
||||
try:
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
|
||||
# Get model instance of LLM
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(
|
||||
provider=openai_provider_name, model_type=ModelType.MODERATION
|
||||
)
|
||||
model_type_instance = cast(ModerationModel, model_type_instance)
|
||||
model_type_instance = OpenAIModerationModel()
|
||||
# FIXME, for type hint using assert or raise ValueError is better here?
|
||||
moderation_result = model_type_instance.invoke(
|
||||
model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
|
||||
model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
|
||||
)
|
||||
|
||||
if moderation_result is True:
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as ex:
|
||||
logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
|
||||
raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
)
|
||||
|
||||
retries = 0
|
||||
stream = kwargs.pop("stream", False)
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
|
||||
@ -8,7 +8,6 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
class ToolProviderCredentialsCacheType(Enum):
|
||||
PROVIDER = "tool_provider"
|
||||
ENDPOINT = "endpoint"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache:
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class SignedUrlParams(BaseModel):
|
||||
sign_key: str = Field(..., description="The sign key")
|
||||
timestamp: str = Field(..., description="Timestamp")
|
||||
nonce: str = Field(..., description="Nonce")
|
||||
sign: str = Field(..., description="Signature")
|
||||
|
||||
|
||||
class UrlSigner:
|
||||
@classmethod
|
||||
def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
|
||||
signed_url_params = cls.get_signed_url_params(sign_key, prefix)
|
||||
return (
|
||||
f"{url}?timestamp={signed_url_params.timestamp}"
|
||||
f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
sign = cls._sign(sign_key, timestamp, nonce, prefix)
|
||||
|
||||
return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)
|
||||
|
||||
@classmethod
|
||||
def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
|
||||
recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)
|
||||
|
||||
return sign == recalculated_sign
|
||||
|
||||
@classmethod
|
||||
def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
|
||||
if not dify_config.SECRET_KEY:
|
||||
raise Exception("SECRET_KEY is not set")
|
||||
|
||||
data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return encoded_sign
|
||||
@ -4,9 +4,9 @@ from flask import Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
|
||||
from core.entities.provider_entities import QuotaUnit, RestrictModel
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class HostingQuota(BaseModel):
|
||||
@ -48,12 +48,12 @@ class HostingConfiguration:
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
return
|
||||
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
|
||||
self.provider_map["azure_openai"] = self.init_azure_openai()
|
||||
self.provider_map["openai"] = self.init_openai()
|
||||
self.provider_map["anthropic"] = self.init_anthropic()
|
||||
self.provider_map["minimax"] = self.init_minimax()
|
||||
self.provider_map["spark"] = self.init_spark()
|
||||
self.provider_map["zhipuai"] = self.init_zhipuai()
|
||||
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
@ -240,14 +240,7 @@ class HostingConfiguration:
|
||||
@staticmethod
|
||||
def init_moderation_config() -> HostedModerationConfig:
|
||||
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
|
||||
providers = dify_config.HOSTED_MODERATION_PROVIDERS.split(",")
|
||||
hosted_providers = []
|
||||
for provider in providers:
|
||||
if "/" not in provider:
|
||||
provider = f"{DEFAULT_PLUGIN_ID}/{provider}/{provider}"
|
||||
hosted_providers.append(provider)
|
||||
|
||||
return HostedModerationConfig(enabled=True, providers=hosted_providers)
|
||||
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))
|
||||
|
||||
return HostedModerationConfig(enabled=False)
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
@ -618,8 +618,10 @@ class IndexingRunner:
|
||||
|
||||
tokens = 0
|
||||
if embedding_model_instance:
|
||||
page_content_list = [document.page_content for document in chunk_documents]
|
||||
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
|
||||
tokens += sum(
|
||||
embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
|
||||
for document in chunk_documents
|
||||
)
|
||||
|
||||
# load index
|
||||
index_processor.load(dataset, chunk_documents, with_keywords=False)
|
||||
|
||||
@ -48,7 +48,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
|
||||
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
|
||||
),
|
||||
)
|
||||
answer = cast(str, response.message.content)
|
||||
@ -101,7 +101,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={"max_tokens": 256, "temperature": 0},
|
||||
stream=False,
|
||||
),
|
||||
@ -110,7 +110,7 @@ class LLMGenerator:
|
||||
questions = output_parser.parse(cast(str, response.message.content))
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.exception("Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
@ -150,7 +150,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
|
||||
@ -200,7 +200,7 @@ class LLMGenerator:
|
||||
prompt_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
except InvokeError as e:
|
||||
@ -236,7 +236,7 @@ class LLMGenerator:
|
||||
parameter_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
|
||||
@ -248,7 +248,7 @@ class LLMGenerator:
|
||||
statement_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
rule_config["opening_statement"] = cast(str, statement_content.message.content)
|
||||
@ -301,7 +301,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user