mirror of https://github.com/langgenius/dify.git
synced 2026-01-20 12:09:27 +08:00

Compare commits (70 commits): feat/trigg... vs deploy/rag
| SHA1 |
|---|
| 6bf55e1cba |
| d0e9fccc9d |
| 74d938a8d2 |
| 4cebaa331e |
| 9b36059292 |
| a4acc64afd |
| 25c69ac540 |
| 96a0b9991e |
| 2913d17fe2 |
| d9e45a1abe |
| 24b4289d6c |
| fb6ccccc3d |
| 8b74ae683a |
| dd08957381 |
| b55c354139 |
| 500836ba25 |
| 407323f817 |
| 2e2c87c5a1 |
| f4522fd695 |
| 760a2c656c |
| 8940decd1b |
| 0c4193bd91 |
| cd40cde790 |
| c60c754ac9 |
| ef80d3b707 |
| 24e8d21b3f |
| d823da18db |
| 4174462190 |
| 1e3df09fc6 |
| 75a10c276c |
| 50050527eb |
| a39b185627 |
| 15270f09af |
| f6a5ac0698 |
| 2b79da722b |
| 71d69e43cd |
| 5bc6e8a433 |
| 68076f2e22 |
| 8c38363038 |
| 345ac8333c |
| 2375047ef0 |
| 857a48012e |
| 208fe3d7de |
| 92cddbcc02 |
| 599b53c9cb |
| 062b173c66 |
| db690013fd |
| e93bfe3d41 |
| ab910c736c |
| 4047a6bb12 |
| df2478dc26 |
| 4cc3f6045b |
| 1550316b8d |
| 87394d2512 |
| bad59c95bc |
| 9f138ef246 |
| 6453fc4973 |
| f62f926537 |
| b3dafd913b |
| b2d8a7eaf1 |
| 3e54414191 |
| a173546c8d |
| aa69d90489 |
| 4ba1292455 |
| bb01c31f30 |
| cd90b2ca9e |
| 9a65350cf7 |
| 680eb7a9f6 |
| 878420463c |
| 4692e20daf |
@@ -1,16 +1,16 @@
 #!/bin/bash

+WORKSPACE_ROOT=$(pwd)
+
 npm add -g pnpm@10.15.0
 corepack enable
 cd web && pnpm install
 pipx install uv

-echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
-echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
-echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
-echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
-echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
+echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
+echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
+echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
+echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
+echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc

 source /home/vscode/.bashrc
.github/workflows/autofix.yml (2 changes, vendored)

@@ -2,8 +2,6 @@ name: autofix.ci
 on:
   pull_request:
-    branches: ["main"]
   push:
-    branches: ["main"]

 permissions:
   contents: read
.github/workflows/build-push.yml (3 changes, vendored)

@@ -8,8 +8,7 @@ on:
       - "deploy/enterprise"
       - "build/**"
       - "release/e-*"
-      - "deploy/rag-dev"
-      - "feat/rag-2"
+      - "hotfix/**"
     tags:
       - "*"
.github/workflows/deploy-dev.yml (4 changes, vendored)

@@ -4,7 +4,7 @@ on:
   workflow_run:
     workflows: ["Build and Push API & Web"]
     branches:
-      - "deploy/rag-dev"
+      - "deploy/dev"
    types:
      - completed

@@ -13,7 +13,7 @@ jobs:
    runs-on: ubuntu-latest
    if: |
      github.event.workflow_run.conclusion == 'success' &&
-      github.event.workflow_run.head_branch == 'deploy/rag-dev'
+      github.event.workflow_run.head_branch == 'deploy/dev'
    steps:
      - name: Deploy to server
        uses: appleboy/ssh-action@v0.1.8
.gitignore (4 changes, vendored)

@@ -231,5 +231,7 @@ api/.env.backup
 # Benchmark
 scripts/stress-test/setup/config/
 scripts/stress-test/reports/
+
 # mcp
-.serena
+.playwright-mcp/
+.serena/
@@ -85,4 +85,3 @@ pnpm test # Run Jest tests

 - All async tasks use Celery with Redis as broker
 - **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit the corresponding module files in the `en-US/` directory for translations.
 - **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead.
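A minimal sketch of the logging rule above (module and message names illustrative):

import logging

logger = logging.getLogger(__name__)

try:
    raise ValueError("upstream failure")
except ValueError as e:
    # Wrong per the guideline: str(e) flattens the error into the message.
    # logger.exception("Operation failed: %s", str(e))
    # Right: pass the exception object and let logging render the traceback.
    logger.exception("Operation failed", exc_info=e)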
Makefile (5 changes)

@@ -61,8 +61,9 @@ check:
 	@echo "✅ Code check complete"

 lint:
-	@echo "🔧 Running ruff format and check with fixes..."
-	@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
+	@echo "🔧 Running ruff format, check with fixes, and import linter..."
+	@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
+	@uv run --directory api --dev lint-imports
 	@echo "✅ Linting complete"

 type-check:
@@ -304,6 +304,8 @@ BAIDU_VECTOR_DB_API_KEY=dify
 BAIDU_VECTOR_DB_DATABASE=dify
 BAIDU_VECTOR_DB_SHARD=1
 BAIDU_VECTOR_DB_REPLICAS=3
+BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
+BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE

 # Upstash configuration
 UPSTASH_VECTOR_URL=your-server-url

@@ -436,9 +438,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
 HTTP_REQUEST_NODE_SSL_VERIFY=True

-# Webhook request configuration
-WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
-
 # Respect X-* headers to redirect clients
 RESPECT_XFORWARD_HEADERS_ENABLED=false

@@ -517,12 +516,6 @@ ENABLE_CLEAN_MESSAGES=false
 ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
 ENABLE_DATASETS_QUEUE_MONITOR=false
 ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
-ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
-# Interval time in minutes for polling scheduled workflows(default: 1 min)
-WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
-WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
-# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
-WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0

 # Position configuration
 POSITION_TOOL_PINS=
@@ -30,6 +30,7 @@ select = [
     "RUF022", # unsorted-dunder-all
     "S506", # unsafe-yaml-load
     "SIM", # flake8-simplify rules
+    "T201", # print-found
     "TRY400", # error-instead-of-exception
     "TRY401", # verbose-log-message
     "UP", # pyupgrade rules

@@ -91,11 +92,18 @@ ignore = [
 "configs/*" = [
     "N802", # invalid-function-name
 ]
+"core/model_runtime/callbacks/base_callback.py" = [
+    "T201",
+]
+"core/workflow/callbacks/workflow_logging_callback.py" = [
+    "T201",
+]
 "libs/gmpy2_pkcs10aep_cipher.py" = [
     "N803", # invalid-argument-name
 ]
 "tests/*" = [
     "F811", # redefined-while-unused
+    "T201", # allow print in tests
 ]

 [lint.pyflakes]
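With T201 newly enabled, any bare print outside the ignored paths fails lint; intentional console output carries an inline waiver, as the gevent patch file further below does. A tiny illustration (file contents hypothetical):

def greet() -> None:
    print("hello")  # trips T201 (`print` found) and would fail `make lint`

def greet_allowed() -> None:
    print("hello", flush=True)  # noqa: T201 -- deliberate console output, waived inline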
api/.vscode/launch.json.example (2 changes, vendored)

@@ -54,7 +54,7 @@
         "--loglevel",
         "DEBUG",
         "-Q",
-        "dataset,generation,mail,ops_trace,app_deletion,workflow"
+        "dataset,generation,mail,ops_trace,app_deletion"
       ]
     }
   ]
@@ -1,20 +1,11 @@
-import logging
-
 import psycogreen.gevent as pscycogreen_gevent  # type: ignore
 from grpc.experimental import gevent as grpc_gevent  # type: ignore

-_logger = logging.getLogger(__name__)
-
-
-def _log(message: str):
-    print(message, flush=True)
-
-
 # grpc gevent
 grpc_gevent.init_gevent()
-_log("gRPC patched with gevent.")
+print("gRPC patched with gevent.", flush=True)  # noqa: T201
 pscycogreen_gevent.patch_psycopg()
-_log("psycopg2 patched with gevent.")
+print("psycopg2 patched with gevent.", flush=True)  # noqa: T201


 from app import app, celery
api/commands.py (422 changes)

@@ -10,16 +10,17 @@ from flask import current_app
 from pydantic import TypeAdapter
 from sqlalchemy import select
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import sessionmaker

 from configs import dify_config
 from constants.languages import languages
 from core.helper import encrypter
-from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.plugin import PluginInstaller
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.models.document import Document
+from core.tools.entities.tool_entities import CredentialType
 from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
 from events.app_event import app_was_created
 from extensions.ext_database import db
@@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
     if str(new_password).strip() != str(password_confirm).strip():
         click.echo(click.style("Passwords do not match.", fg="red"))
         return
-    account = db.session.query(Account).where(Account.email == email).one_or_none()
-    if not account:
-        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
-        return
-    try:
-        valid_password(new_password)
-    except:
-        click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
-        return
-    # generate password salt
-    salt = secrets.token_bytes(16)
-    base64_salt = base64.b64encode(salt).decode()
-    # encrypt password with salt
-    password_hashed = hash_password(new_password, salt)
-    base64_password_hashed = base64.b64encode(password_hashed).decode()
-    account.password = base64_password_hashed
-    account.password_salt = base64_salt
-    db.session.commit()
-    AccountService.reset_login_error_rate_limit(email)
-    click.echo(click.style("Password reset successfully.", fg="green"))
+    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+        account = session.query(Account).where(Account.email == email).one_or_none()
+        if not account:
+            click.echo(click.style(f"Account not found for email: {email}", fg="red"))
+            return
+        try:
+            valid_password(new_password)
+        except:
+            click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
+            return
+        # generate password salt
+        salt = secrets.token_bytes(16)
+        base64_salt = base64.b64encode(salt).decode()
+        # encrypt password with salt
+        password_hashed = hash_password(new_password, salt)
+        base64_password_hashed = base64.b64encode(password_hashed).decode()
+        account.password = base64_password_hashed
+        account.password_salt = base64_salt
+        AccountService.reset_login_error_rate_limit(email)
+        click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
|
||||
@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
account.email = new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command(

@@ -139,25 +138,24 @@ def reset_encrypt_key_pair():
     if dify_config.EDITION != "SELF_HOSTED":
         click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
         return
-    tenants = db.session.query(Tenant).all()
-    for tenant in tenants:
-        if not tenant:
-            click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
-            return
-
-        tenant.encrypt_public_key = generate_key_pair(tenant.id)
-
-        db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
-        db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
-        db.session.commit()
-
-        click.echo(
-            click.style(
-                f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
-                fg="green",
-            )
-        )
+    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+        tenants = session.query(Tenant).all()
+        for tenant in tenants:
+            if not tenant:
+                click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
+                return
+
+            tenant.encrypt_public_key = generate_key_pair(tenant.id)
+
+            session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
+            session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
+
+            click.echo(
+                click.style(
+                    f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
+                    fg="green",
+                )
+            )
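These command refactors all replace ad-hoc db.session usage with an explicitly scoped session. A runnable sketch of the pattern with a stand-in model (names illustrative, not Dify's):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Account(Base):  # stand-in for the real model
    __tablename__ = "accounts"
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str]

engine = create_engine("sqlite:///:memory:")  # stand-in for db.engine
Base.metadata.create_all(engine)

# .begin() opens a transaction that commits when the block exits cleanly and
# rolls back if an exception escapes, so no explicit session.commit() is needed.
# expire_on_commit=False keeps loaded attributes readable after the block.
with sessionmaker(engine, expire_on_commit=False).begin() as session:
    session.add(Account(email="admin@example.com"))

with sessionmaker(engine, expire_on_commit=False).begin() as session:
    account = session.scalar(select(Account).where(Account.email == "admin@example.com"))
    assert account is not None
    account.email = "new@example.com"  # flushed and committed automatically

This is why the refactored blocks above drop their db.session.commit() calls.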
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@ -182,14 +180,15 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
# get apps info
|
||||
per_page = 50
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@@ -203,26 +202,27 @@ def migrate_annotation_vector_database():
             )
             try:
                 click.echo(f"Creating app annotation index: {app.id}")
-                app_annotation_setting = (
-                    db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
-                )
-                if not app_annotation_setting:
-                    skipped_count = skipped_count + 1
-                    click.echo(f"App annotation setting disabled: {app.id}")
-                    continue
-                # get dataset_collection_binding info
-                dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
-                    .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
-                    .first()
-                )
-                if not dataset_collection_binding:
-                    click.echo(f"App annotation collection binding not found: {app.id}")
-                    continue
-                annotations = db.session.scalars(
-                    select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
-                ).all()
+                with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+                    app_annotation_setting = (
+                        session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
+                    )
+                    if not app_annotation_setting:
+                        skipped_count = skipped_count + 1
+                        click.echo(f"App annotation setting disabled: {app.id}")
+                        continue
+                    # get dataset_collection_binding info
+                    dataset_collection_binding = (
+                        session.query(DatasetCollectionBinding)
+                        .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
+                        .first()
+                    )
+                    if not dataset_collection_binding:
+                        click.echo(f"App annotation collection binding not found: {app.id}")
+                        continue
+                    annotations = session.scalars(
+                        select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
+                    ).all()
                 dataset = Dataset(
                     id=app.id,
                     tenant_id=app.tenant_id,
@@ -739,18 +739,18 @@ where sites.id is null limit 1000"""
         try:
             app = db.session.query(App).where(App.id == app_id).first()
             if not app:
-                print(f"App {app_id} not found")
+                logger.info("App %s not found", app_id)
                 continue

             tenant = app.tenant
             if tenant:
                 accounts = tenant.get_accounts()
                 if not accounts:
-                    print(f"Fix failed for app {app.id}")
+                    logger.info("Fix failed for app %s", app.id)
                     continue

                 account = accounts[0]
-                print(f"Fixing missing site for app {app.id}")
+                logger.info("Fixing missing site for app %s", app.id)
                 app_was_created.send(app, account=account)
         except Exception:
             failed_app_ids.append(app_id)
@@ -1227,55 +1227,6 @@ def setup_system_tool_oauth_client(provider, client_params):
     click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


-@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
-@click.option("--provider", prompt=True, help="Provider name")
-@click.option("--client-params", prompt=True, help="Client Params")
-def setup_system_trigger_oauth_client(provider, client_params):
-    """
-    Setup system trigger oauth client
-    """
-    from models.provider_ids import TriggerProviderID
-    from models.trigger import TriggerOAuthSystemClient
-
-    provider_id = TriggerProviderID(provider)
-    provider_name = provider_id.provider_name
-    plugin_id = provider_id.plugin_id
-
-    try:
-        # json validate
-        click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
-        client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
-        click.echo(click.style("Client params validated successfully.", fg="green"))
-
-        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
-        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
-        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
-        click.echo(click.style("Client params encrypted successfully.", fg="green"))
-    except Exception as e:
-        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
-        return
-
-    deleted_count = (
-        db.session.query(TriggerOAuthSystemClient)
-        .filter_by(
-            provider=provider_name,
-            plugin_id=plugin_id,
-        )
-        .delete()
-    )
-    if deleted_count > 0:
-        click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
-
-    oauth_client = TriggerOAuthSystemClient(
-        provider=provider_name,
-        plugin_id=plugin_id,
-        encrypted_oauth_params=oauth_client_params,
-    )
-    db.session.add(oauth_client)
-    db.session.commit()
-    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


 def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
     """
     Find draft variables that reference non-existent apps.
@@ -1497,41 +1448,52 @@ def transform_datasource_credentials():
                 notion_credentials_tenant_mapping[tenant_id] = []
             notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
         for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
-            # check notion plugin is installed
-            installed_plugins = installer_manager.list_plugins(tenant_id)
-            installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
-            if notion_plugin_id not in installed_plugins_ids:
-                if notion_plugin_unique_identifier:
-                    # install notion plugin
-                    PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
-            auth_count = 0
-            for notion_tenant_credential in notion_tenant_credentials:
-                auth_count += 1
-                # get credential oauth params
-                access_token = notion_tenant_credential.access_token
-                # notion info
-                notion_info = notion_tenant_credential.source_info
-                workspace_id = notion_info.get("workspace_id")
-                workspace_name = notion_info.get("workspace_name")
-                workspace_icon = notion_info.get("workspace_icon")
-                new_credentials = {
-                    "integration_secret": encrypter.encrypt_token(tenant_id, access_token),
-                    "workspace_id": workspace_id,
-                    "workspace_name": workspace_name,
-                    "workspace_icon": workspace_icon,
-                }
-                datasource_provider = DatasourceProvider(
-                    provider="notion_datasource",
-                    tenant_id=tenant_id,
-                    plugin_id=notion_plugin_id,
-                    auth_type=oauth_credential_type.value,
-                    encrypted_credentials=new_credentials,
-                    name=f"Auth {auth_count}",
-                    avatar_url=workspace_icon or "default",
-                    is_default=False,
-                )
-                db.session.add(datasource_provider)
-                deal_notion_count += 1
+            tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
+            if not tenant:
+                continue
+            try:
+                # check notion plugin is installed
+                installed_plugins = installer_manager.list_plugins(tenant_id)
+                installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+                if notion_plugin_id not in installed_plugins_ids:
+                    if notion_plugin_unique_identifier:
+                        # install notion plugin
+                        PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
+                auth_count = 0
+                for notion_tenant_credential in notion_tenant_credentials:
+                    auth_count += 1
+                    # get credential oauth params
+                    access_token = notion_tenant_credential.access_token
+                    # notion info
+                    notion_info = notion_tenant_credential.source_info
+                    workspace_id = notion_info.get("workspace_id")
+                    workspace_name = notion_info.get("workspace_name")
+                    workspace_icon = notion_info.get("workspace_icon")
+                    new_credentials = {
+                        "integration_secret": encrypter.encrypt_token(tenant_id, access_token),
+                        "workspace_id": workspace_id,
+                        "workspace_name": workspace_name,
+                        "workspace_icon": workspace_icon,
+                    }
+                    datasource_provider = DatasourceProvider(
+                        provider="notion_datasource",
+                        tenant_id=tenant_id,
+                        plugin_id=notion_plugin_id,
+                        auth_type=oauth_credential_type.value,
+                        encrypted_credentials=new_credentials,
+                        name=f"Auth {auth_count}",
+                        avatar_url=workspace_icon or "default",
+                        is_default=False,
+                    )
+                    db.session.add(datasource_provider)
+                    deal_notion_count += 1
+            except Exception as e:
+                click.echo(
+                    click.style(
+                        f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
+                    )
+                )
+                continue
         db.session.commit()
         # deal firecrawl credentials
         deal_firecrawl_count = 0
@@ -1544,37 +1506,48 @@ def transform_datasource_credentials():
                 firecrawl_credentials_tenant_mapping[tenant_id] = []
             firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
         for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
-            # check firecrawl plugin is installed
-            installed_plugins = installer_manager.list_plugins(tenant_id)
-            installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
-            if firecrawl_plugin_id not in installed_plugins_ids:
-                if firecrawl_plugin_unique_identifier:
-                    # install firecrawl plugin
-                    PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
-
-            auth_count = 0
-            for firecrawl_tenant_credential in firecrawl_tenant_credentials:
-                auth_count += 1
-                # get credential api key
-                credentials_json = json.loads(firecrawl_tenant_credential.credentials)
-                api_key = credentials_json.get("config", {}).get("api_key")
-                base_url = credentials_json.get("config", {}).get("base_url")
-                new_credentials = {
-                    "firecrawl_api_key": api_key,
-                    "base_url": base_url,
-                }
-                datasource_provider = DatasourceProvider(
-                    provider="firecrawl",
-                    tenant_id=tenant_id,
-                    plugin_id=firecrawl_plugin_id,
-                    auth_type=api_key_credential_type.value,
-                    encrypted_credentials=new_credentials,
-                    name=f"Auth {auth_count}",
-                    avatar_url="default",
-                    is_default=False,
-                )
-                db.session.add(datasource_provider)
-                deal_firecrawl_count += 1
+            tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
+            if not tenant:
+                continue
+            try:
+                # check firecrawl plugin is installed
+                installed_plugins = installer_manager.list_plugins(tenant_id)
+                installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+                if firecrawl_plugin_id not in installed_plugins_ids:
+                    if firecrawl_plugin_unique_identifier:
+                        # install firecrawl plugin
+                        PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
+
+                auth_count = 0
+                for firecrawl_tenant_credential in firecrawl_tenant_credentials:
+                    auth_count += 1
+                    # get credential api key
+                    credentials_json = json.loads(firecrawl_tenant_credential.credentials)
+                    api_key = credentials_json.get("config", {}).get("api_key")
+                    base_url = credentials_json.get("config", {}).get("base_url")
+                    new_credentials = {
+                        "firecrawl_api_key": api_key,
+                        "base_url": base_url,
+                    }
+                    datasource_provider = DatasourceProvider(
+                        provider="firecrawl",
+                        tenant_id=tenant_id,
+                        plugin_id=firecrawl_plugin_id,
+                        auth_type=api_key_credential_type.value,
+                        encrypted_credentials=new_credentials,
+                        name=f"Auth {auth_count}",
+                        avatar_url="default",
+                        is_default=False,
+                    )
+                    db.session.add(datasource_provider)
+                    deal_firecrawl_count += 1
+            except Exception as e:
+                click.echo(
+                    click.style(
+                        f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
+                    )
+                )
+                continue
         db.session.commit()
         # deal jina credentials
         deal_jina_count = 0
@@ -1587,36 +1560,45 @@ def transform_datasource_credentials():
                 jina_credentials_tenant_mapping[tenant_id] = []
             jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
         for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
-            # check jina plugin is installed
-            installed_plugins = installer_manager.list_plugins(tenant_id)
-            installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
-            if jina_plugin_id not in installed_plugins_ids:
-                if jina_plugin_unique_identifier:
-                    # install jina plugin
-                    print(jina_plugin_unique_identifier)
-                    PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
-            auth_count = 0
-            for jina_tenant_credential in jina_tenant_credentials:
-                auth_count += 1
-                # get credential api key
-                credentials_json = json.loads(jina_tenant_credential.credentials)
-                api_key = credentials_json.get("config", {}).get("api_key")
-                new_credentials = {
-                    "integration_secret": api_key,
-                }
-                datasource_provider = DatasourceProvider(
-                    provider="jina",
-                    tenant_id=tenant_id,
-                    plugin_id=jina_plugin_id,
-                    auth_type=api_key_credential_type.value,
-                    encrypted_credentials=new_credentials,
-                    name=f"Auth {auth_count}",
-                    avatar_url="default",
-                    is_default=False,
-                )
-                db.session.add(datasource_provider)
-                deal_jina_count += 1
+            tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
+            if not tenant:
+                continue
+            try:
+                # check jina plugin is installed
+                installed_plugins = installer_manager.list_plugins(tenant_id)
+                installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+                if jina_plugin_id not in installed_plugins_ids:
+                    if jina_plugin_unique_identifier:
+                        # install jina plugin
+                        logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
+                        PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
+
+                auth_count = 0
+                for jina_tenant_credential in jina_tenant_credentials:
+                    auth_count += 1
+                    # get credential api key
+                    credentials_json = json.loads(jina_tenant_credential.credentials)
+                    api_key = credentials_json.get("config", {}).get("api_key")
+                    new_credentials = {
+                        "integration_secret": api_key,
+                    }
+                    datasource_provider = DatasourceProvider(
+                        provider="jina",
+                        tenant_id=tenant_id,
+                        plugin_id=jina_plugin_id,
+                        auth_type=api_key_credential_type.value,
+                        encrypted_credentials=new_credentials,
+                        name=f"Auth {auth_count}",
+                        avatar_url="default",
+                        is_default=False,
+                    )
+                    db.session.add(datasource_provider)
+                    deal_jina_count += 1
+            except Exception as e:
+                click.echo(
+                    click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
+                )
+                continue
         db.session.commit()
     except Exception as e:
        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -18,3 +18,18 @@ class EnterpriseFeatureConfig(BaseSettings):
         description="Allow customization of the enterprise logo.",
         default=False,
     )
+
+    UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN: str = Field(
+        description="Token for uploading knowledge pipeline template.",
+        default="",
+    )
+
+    KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT: str = Field(
+        description="Knowledge pipeline template copyright.",
+        default="Copyright 2023 Dify",
+    )
+
+    KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY: str = Field(
+        description="Knowledge pipeline template privacy policy.",
+        default="https://dify.ai",
+    )
@@ -154,17 +154,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
     )


-class TriggerConfig(BaseSettings):
-    """
-    Configuration for trigger
-    """
-
-    WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
-        description="Maximum allowed size for webhook request bodies in bytes",
-        default=10485760,
-    )
-
-
 class PluginConfig(BaseSettings):
     """
     Plugin configs
@@ -961,22 +950,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
         description="Enable check upgradable plugin task",
         default=True,
     )
-    ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
-        description="Enable workflow schedule poller task",
-        default=True,
-    )
-    WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
-        description="Workflow schedule poller interval in minutes",
-        default=1,
-    )
-    WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
-        description="Maximum number of schedules to process in each poll batch",
-        default=100,
-    )
-    WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
-        description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
-        default=0,
-    )


 class PositionConfig(BaseSettings):
class PositionConfig(BaseSettings):
|
||||
@ -1100,7 +1073,6 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
|
||||
@@ -41,3 +41,13 @@ class BaiduVectorDBConfig(BaseSettings):
         description="Number of replicas for the Baidu Vector Database (default is 3)",
         default=3,
     )
+
+    BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
+        description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
+        default="DEFAULT_ANALYZER",
+    )
+
+    BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
+        description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
+        default="COARSE_MODE",
+    )
@@ -37,3 +37,15 @@ class OceanBaseVectorConfig(BaseSettings):
         "with older versions",
         default=False,
     )
+
+    OCEANBASE_FULLTEXT_PARSER: str | None = Field(
+        description=(
+            "Fulltext parser to use for text indexing. "
+            "Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
+            "'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
+            "'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
+            "External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
+            "'thai_ftparser' (Thai tokenizer). Default is 'ik'"
+        ),
+        default="ik",
+    )
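These vector-store options follow the pydantic-settings pattern used across the config classes: each field resolves from an environment variable of the same name. A minimal sketch (class trimmed to one field for illustration):

import os

from pydantic import Field
from pydantic_settings import BaseSettings

class OceanBaseVectorConfig(BaseSettings):
    OCEANBASE_FULLTEXT_PARSER: str | None = Field(
        description="Fulltext parser to use for text indexing.",
        default="ik",
    )

os.environ["OCEANBASE_FULLTEXT_PARSER"] = "ngram"  # overrides the default
print(OceanBaseVectorConfig().OCEANBASE_FULLTEXT_PARSER)  # -> "ngram"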
@@ -5,7 +5,7 @@ import logging
 import os
 import time

-import requests
+import httpx

 logger = logging.getLogger(__name__)

@@ -30,10 +30,10 @@ class NacosHttpClient:
         params = {}
         try:
             self._inject_auth_info(headers, params)
-            response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
+            response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
             response.raise_for_status()
             return response.text
-        except requests.RequestException as e:
+        except httpx.RequestError as e:
             return f"Request to Nacos failed: {e}"

     def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:

@@ -78,7 +78,7 @@ class NacosHttpClient:
         params = {"username": self.username, "password": self.password}
         url = "http://" + self.server + "/nacos/v1/auth/login"
         try:
-            resp = requests.request("POST", url, headers=None, params=params)
+            resp = httpx.request("POST", url, headers=None, params=params)
             resp.raise_for_status()
             response_data = resp.json()
             self.token = response_data.get("accessToken")
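The requests-to-httpx swap is near drop-in, but the exception trees differ: httpx.RequestError covers transport failures only, while raise_for_status() raises httpx.HTTPStatusError; both share the httpx.HTTPError base. A sketch of a variant that catches both (function name and URL are illustrative):

import httpx

def fetch_text(url: str) -> str:
    try:
        response = httpx.request("GET", url)
        response.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
        return response.text
    except httpx.HTTPError as e:  # base class of RequestError and HTTPStatusError
        return f"Request failed: {e}"

print(fetch_text("http://127.0.0.1:8848/nacos/"))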
@@ -9,8 +9,6 @@ if TYPE_CHECKING:
     from core.model_runtime.entities.model_entities import AIModelEntity
     from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
     from core.tools.plugin_tool.provider import PluginToolProviderController
-    from core.trigger.provider import PluginTriggerProviderController
     from core.workflow.entities.variable_pool import VariablePool


 """

@@ -43,11 +41,3 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
 datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
     ContextVar("datasource_plugin_providers_lock")
 )
-
-plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
-    ContextVar("plugin_trigger_providers")
-)
-
-plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
-    ContextVar("plugin_trigger_providers_lock")
-)
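For context, RecyclableContextVar in this module wraps contextvars.ContextVar so cached values can be recycled when worker threads are reused; the trigger provider pair is dropped along with the feature. A plain-ContextVar sketch of the lazy-cache lookup pattern these variables support (names illustrative):

from contextvars import ContextVar

providers: ContextVar[dict[str, str]] = ContextVar("providers")

def get_providers() -> dict[str, str]:
    # Populate lazily on first access within the current context.
    try:
        return providers.get()
    except LookupError:
        cache: dict[str, str] = {}
        providers.set(cache)
        return cache

print(get_providers())  # {} on first call in this context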
@@ -87,7 +87,6 @@ from .app import (
     workflow_draft_variable,
     workflow_run,
     workflow_statistic,
-    workflow_trigger,
 )

 # Import auth controllers

@@ -224,21 +223,6 @@ api.add_resource(

 api.add_namespace(console_ns)

-# Import workspace controllers
-from .workspace import (
-    account,
-    agent_providers,
-    endpoint,
-    load_balancing_config,
-    members,
-    model_providers,
-    models,
-    plugin,
-    tool_providers,
-    trigger_providers,
-    workspace,
-)
-
 __all__ = [
     "account",
     "activate",

@@ -304,7 +288,6 @@ __all__ = [
     "statistic",
     "tags",
     "tool_providers",
-    "trigger_providers",
     "version",
     "website",
     "workflow",
@@ -1,6 +1,7 @@
 from datetime import datetime

 import pytz  # pip install pytz
+import sqlalchemy as sa
 from flask_login import current_user
 from flask_restx import Resource, marshal_with, reqparse
 from flask_restx.inputs import int_range

@@ -70,7 +71,7 @@ class CompletionConversationApi(Resource):
         parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()

-        query = db.select(Conversation).where(
+        query = sa.select(Conversation).where(
             Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
         )

@@ -236,7 +237,7 @@ class ChatConversationApi(Resource):
             .subquery()
         )

-        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
+        query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

         if args["keyword"]:
             keyword_filter = f"%{args['keyword']}%"
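The query construction here moves from Flask-SQLAlchemy's db.select alias to sqlalchemy directly; both build the same Core statement. A self-contained sketch (model is a stand-in for the real one):

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Conversation(Base):  # illustrative stand-in
    __tablename__ = "conversations"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]

# sa.select is what db.select re-exports; using the library directly
# drops the dependency on the extension object in this module.
stmt = sa.select(Conversation).where(Conversation.app_id == "app-1")
print(stmt)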
@@ -12,7 +12,6 @@ from controllers.console.app.error import (
 )
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 from core.llm_generator.llm_generator import LLMGenerator

@@ -199,11 +198,13 @@ class InstructionGenerateApi(Resource):
         parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
         args = parser.parse_args()
-        providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
-        code_provider: type[CodeNodeProvider] | None = next(
-            (p for p in providers if p.is_accept_language(args["language"])), None
-        )
-        code_template = code_provider.get_default_code() if code_provider else ""
+        code_template = (
+            Python3CodeProvider.get_default_code()
+            if args["language"] == "python"
+            else (JavascriptCodeProvider.get_default_code())
+            if args["language"] == "javascript"
+            else ""
+        )
         try:
             # Generate from nothing for a workflow node
             if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
@@ -62,6 +62,9 @@ class ChatMessageListApi(Resource):
     @account_initialization_required
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
+            raise Forbidden()
+
         parser = reqparse.RequestParser()
         parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
         parser.add_argument("first_id", type=uuid_value, location="args")
@@ -50,8 +50,9 @@ class DailyMessageStatistic(Resource):
 FROM
     messages
 WHERE
-    app_id = :app_id"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

@@ -187,8 +188,9 @@ class DailyTerminalsStatistic(Resource):
 FROM
     messages
 WHERE
-    app_id = :app_id"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

@@ -259,8 +261,9 @@ class DailyTokenCostStatistic(Resource):
 FROM
     messages
 WHERE
-    app_id = :app_id"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

@@ -340,8 +343,9 @@ FROM
     messages m
 ON c.id = m.conversation_id
 WHERE
-    c.app_id = :app_id"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+    c.app_id = :app_id
+    AND m.invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

@@ -426,8 +430,9 @@ LEFT JOIN
     message_feedbacks mf
 ON mf.message_id=m.id AND mf.rating='like'
 WHERE
-    m.app_id = :app_id"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+    m.app_id = :app_id
+    AND m.invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

@@ -502,8 +507,9 @@ class AverageResponseTimeStatistic(Resource):
 FROM
     messages
 WHERE
-    app_id = :app_id"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

@@ -576,8 +582,9 @@ class TokensPerSecondStatistic(Resource):
 FROM
     messages
 WHERE
-    app_id = :app_id"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
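Each of these hunks applies the same filter: debugger-invoked messages are excluded by binding one extra parameter into the raw SQL. A sketch of the binding (query shape abridged; "debugger" assumes InvokeFrom.DEBUGGER.value, so check the enum before reusing):

from sqlalchemy import text

sql_query = """SELECT COUNT(*) AS message_count
FROM
    messages
WHERE
    app_id = :app_id
    AND invoke_from != :invoke_from"""

# Mirrors arg_dict in the hunks above.
params = {"app_id": "some-app-id", "invoke_from": "debugger"}
stmt = text(sql_query)
# Execution sketch, given an engine:
# with db.engine.begin() as conn:
#     rows = conn.execute(stmt, params).all()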
@@ -20,7 +20,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.file.models import File
 from core.helper.trace_id_helper import get_external_trace_id
-from core.model_runtime.utils.encoders import jsonable_encoder
 from core.workflow.graph_engine.manager import GraphEngineManager
 from extensions.ext_database import db
 from factories import file_factory, variable_factory

@@ -36,7 +35,6 @@ from models.workflow import Workflow
 from services.app_generate_service import AppGenerateService
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
-from services.trigger_debug_service import TriggerDebugService
 from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

 logger = logging.getLogger(__name__)

@@ -1006,165 +1004,3 @@ class DraftWorkflowNodeLastRunApi(Resource):
         if node_exec is None:
             raise NotFound("last run not found")
         return node_exec
-
-
-class DraftWorkflowTriggerNodeApi(Resource):
-    """
-    Single node debug - Polling API for trigger events
-    Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
-    """
-
-    @api.doc("poll_draft_workflow_trigger_node")
-    @api.doc(description="Poll for trigger events and execute single node when event arrives")
-    @api.doc(params={
-        "app_id": "Application ID",
-        "node_id": "Node ID"
-    })
-    @api.expect(
-        api.model(
-            "DraftWorkflowTriggerNodeRequest",
-            {
-                "trigger_name": fields.String(required=True, description="Trigger name"),
-                "subscription_id": fields.String(required=True, description="Subscription ID"),
-            }
-        )
-    )
-    @api.response(200, "Trigger event received and node executed successfully")
-    @api.response(403, "Permission denied")
-    @api.response(500, "Internal server error")
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW])
-    def post(self, app_model: App, node_id: str):
-        """
-        Poll for trigger events and execute single node when event arrives
-        """
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
-        parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
-        args = parser.parse_args()
-        trigger_name = args["trigger_name"]
-        subscription_id = args["subscription_id"]
-
-        event = TriggerDebugService.poll_event(
-            tenant_id=app_model.tenant_id,
-            user_id=current_user.id,
-            app_id=app_model.id,
-            subscription_id=subscription_id,
-            node_id=node_id,
-            trigger_name=trigger_name,
-        )
-        if not event:
-            return jsonable_encoder({"status": "waiting"})
-
-        try:
-            workflow_service = WorkflowService()
-            draft_workflow = workflow_service.get_draft_workflow(app_model)
-            if not draft_workflow:
-                raise ValueError("Workflow not found")
-
-            user_inputs = event.model_dump()
-            node_execution = workflow_service.run_draft_workflow_node(
-                app_model=app_model,
-                draft_workflow=draft_workflow,
-                node_id=node_id,
-                user_inputs=user_inputs,
-                account=current_user,
-                query="",
-                files=[],
-            )
-            return jsonable_encoder(node_execution)
-        except Exception:
-            logger.exception("Error running draft workflow trigger node")
-            return jsonable_encoder(
-                {
-                    "status": "error",
-                }
-            ), 500
-
-
-class DraftWorkflowTriggerRunApi(Resource):
-    """
-    Full workflow debug - Polling API for trigger events
-    Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
-    """
-
-    @api.doc("poll_draft_workflow_trigger_run")
-    @api.doc(description="Poll for trigger events and execute full workflow when event arrives")
-    @api.doc(params={"app_id": "Application ID"})
-    @api.expect(
-        api.model(
-            "DraftWorkflowTriggerRunRequest",
-            {
-                "node_id": fields.String(required=True, description="Node ID"),
-                "trigger_name": fields.String(required=True, description="Trigger name"),
-                "subscription_id": fields.String(required=True, description="Subscription ID"),
-            }
-        )
-    )
-    @api.response(200, "Trigger event received and workflow executed successfully")
-    @api.response(403, "Permission denied")
-    @api.response(500, "Internal server error")
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW])
-    def post(self, app_model: App):
-        """
-        Poll for trigger events and execute full workflow when event arrives
-        """
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
-        parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
-        parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
-        args = parser.parse_args()
-        node_id = args["node_id"]
-        trigger_name = args["trigger_name"]
-        subscription_id = args["subscription_id"]
-
-        event = TriggerDebugService.poll_event(
-            tenant_id=app_model.tenant_id,
-            user_id=current_user.id,
-            app_id=app_model.id,
-            subscription_id=subscription_id,
-            node_id=node_id,
-            trigger_name=trigger_name,
-        )
-        if not event:
-            return jsonable_encoder({"status": "waiting"})
-
-        workflow_args = {
-            "inputs": event.model_dump(),
-            "query": "",
-            "files": [],
-        }
-        external_trace_id = get_external_trace_id(request)
-        if external_trace_id:
-            workflow_args["external_trace_id"] = external_trace_id
-
-        try:
-            response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=workflow_args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=True,
-            )
-            return helper.compact_generate_response(response)
-        except InvokeRateLimitError as ex:
-            raise InvokeRateLimitHttpError(ex.description)
-        except Exception:
-            logger.exception("Error running draft workflow trigger run")
-            return jsonable_encoder(
-                {
-                    "status": "error",
-                }
-            ), 500
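The two handlers deleted above implemented a long-poll contract: return {"status": "waiting"} until a trigger event arrives, then the execution result. A hypothetical client loop against that removed route, for illustration only (URL shape, token handling, and payload names are inferred from the handler code, not a documented API):

import time

import httpx

def poll_trigger_node(base_url: str, app_id: str, node_id: str, payload: dict, token: str) -> dict:
    url = f"{base_url}/apps/{app_id}/workflows/draft/nodes/{node_id}/trigger"
    while True:
        resp = httpx.post(url, json=payload, headers={"Authorization": f"Bearer {token}"})
        resp.raise_for_status()
        data = resp.json()
        if data.get("status") != "waiting":
            return data  # node execution result, or {"status": "error"}
        time.sleep(1)  # back off between polls while no event has arrived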
@@ -1,249 +0,0 @@
-import logging
-
-from flask_restx import Resource, marshal_with, reqparse
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-from werkzeug.exceptions import Forbidden, NotFound
-
-from configs import dify_config
-from controllers.console import api
-from controllers.console.app.wraps import get_app_model
-from controllers.console.wraps import account_initialization_required, setup_required
-from core.model_runtime.utils.encoders import jsonable_encoder
-from extensions.ext_database import db
-from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
-from libs.login import current_user, login_required
-from models.model import Account, AppMode
-from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
-
-logger = logging.getLogger(__name__)
-
-from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
-
-
-class PluginTriggerApi(Resource):
-    """Workflow Plugin Trigger API"""
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
-    def post(self, app_model):
-        """Create plugin trigger"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=False, location="json")
-        parser.add_argument("provider_id", type=str, required=False, location="json")
-        parser.add_argument("trigger_name", type=str, required=False, location="json")
-        parser.add_argument("subscription_id", type=str, required=False, location="json")
-        args = parser.parse_args()
-
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
-            app_id=app_model.id,
-            tenant_id=current_user.current_tenant_id,
-            node_id=args["node_id"],
-            provider_id=args["provider_id"],
-            trigger_name=args["trigger_name"],
-            subscription_id=args["subscription_id"],
-        )
-
-        return jsonable_encoder(plugin_trigger)
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
-    def get(self, app_model):
-        """Get plugin trigger"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
-        args = parser.parse_args()
-
-        plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
-            app_id=app_model.id,
-            node_id=args["node_id"],
-        )
-
-        return jsonable_encoder(plugin_trigger)
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
-    def put(self, app_model):
-        """Update plugin trigger"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
-        parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
-        args = parser.parse_args()
-
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
-            app_id=app_model.id,
-            node_id=args["node_id"],
-            subscription_id=args["subscription_id"],
-        )
-
-        return jsonable_encoder(plugin_trigger)
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
-    def delete(self, app_model):
-        """Delete plugin trigger"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
-        args = parser.parse_args()
-
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        WorkflowPluginTriggerService.delete_plugin_trigger(
-            app_id=app_model.id,
-            node_id=args["node_id"],
-        )
-
-        return {"result": "success"}, 204
-
-
-class WebhookTriggerApi(Resource):
-    """Webhook Trigger API"""
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(webhook_trigger_fields)
-    def get(self, app_model):
-        """Get webhook trigger for a node"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
-        args = parser.parse_args()
-
-        node_id = args["node_id"]
-
-        with Session(db.engine) as session:
-            # Get webhook trigger for this app and node
-            webhook_trigger = (
-                session.query(WorkflowWebhookTrigger)
-                .filter(
-                    WorkflowWebhookTrigger.app_id == app_model.id,
-                    WorkflowWebhookTrigger.node_id == node_id,
-                )
-                .first()
-            )
-
-            if not webhook_trigger:
-                raise NotFound("Webhook trigger not found for this node")
-
-            # Add computed fields for marshal_with
-            base_url = dify_config.SERVICE_API_URL
-            webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}"  # type: ignore
-            webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}"  # type: ignore
-
-            return webhook_trigger
-
-
-class AppTriggersApi(Resource):
-    """App Triggers list API"""
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(triggers_list_fields)
-    def get(self, app_model):
-        """Get app triggers list"""
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
-
-        with Session(db.engine) as session:
-            # Get all triggers for this app using select API
-            triggers = (
-                session.execute(
-                    select(AppTrigger)
-                    .where(
-                        AppTrigger.tenant_id == current_user.current_tenant_id,
-                        AppTrigger.app_id == app_model.id,
-                    )
-                    .order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
-                )
-                .scalars()
-                .all()
-            )
-
-            # Add computed icon field for each trigger
-            url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
-            for trigger in triggers:
-                if trigger.trigger_type == "trigger-plugin":
-                    trigger.icon = url_prefix + trigger.provider_name + "/icon"  # type: ignore
-                else:
-                    trigger.icon = ""  # type: ignore
-
-            return {"data": triggers}
-
-
-class AppTriggerEnableApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(trigger_fields)
-    def post(self, app_model):
-        """Update app trigger (enable/disable)"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        trigger_id = args["trigger_id"]
-
-        with Session(db.engine) as session:
-            # Find the trigger using select
-            trigger = session.execute(
-                select(AppTrigger).where(
-                    AppTrigger.id == trigger_id,
-                    AppTrigger.tenant_id == current_user.current_tenant_id,
-                    AppTrigger.app_id == app_model.id,
-                )
-            ).scalar_one_or_none()
-
-            if not trigger:
-                raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
|
||||
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
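A quick client-side sketch of the trigger endpoints registered above. The base URL, auth header, and IDs are placeholders (the console API normally requires a logged-in session); only the routes and payload fields come from the code above:

    import httpx

    CONSOLE_API = "http://localhost:5001/console/api"  # assumed local dev address
    app_id = "00000000-0000-0000-0000-000000000000"    # placeholder workflow app id

    with httpx.Client(base_url=CONSOLE_API, headers={"Authorization": "Bearer <token>"}) as client:
        # Bind a plugin trigger to a workflow node
        client.post(
            f"/apps/{app_id}/workflows/triggers/plugin",
            json={
                "node_id": "trigger-node-1",
                "provider_id": "some_org/some_plugin/provider",  # placeholder provider id
                "trigger_name": "on_event",
                "subscription_id": "sub-123",
            },
        )
        # Inspect the webhook trigger bound to a node
        client.get(f"/apps/{app_id}/workflows/triggers/webhook", params={"node_id": "trigger-node-1"})
        # List the app's triggers, then enable the first one
        triggers = client.get(f"/apps/{app_id}/triggers").json()["data"]
        client.post(f"/apps/{app_id}/trigger-enable", json={"trigger_id": triggers[0]["id"], "enable_trigger": True})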
@@ -1,6 +1,6 @@
 import logging

-import requests
+import httpx
 from flask import current_app, redirect, request
 from flask_login import current_user
 from flask_restx import Resource, fields
@@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
             return {"error": "Invalid code"}, 400
         try:
             oauth_provider.get_access_token(code)
-        except requests.HTTPError as e:
+        except httpx.HTTPStatusError as e:
             logger.exception(
                 "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
             )
@@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
             return {"error": "Invalid provider"}, 400
         try:
             oauth_provider.sync_data_source(binding_id)
-        except requests.HTTPError as e:
+        except httpx.HTTPStatusError as e:
             logger.exception(
                 "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
             )

@@ -1,6 +1,6 @@
 import logging

-import requests
+import httpx
 from flask import current_app, redirect, request
 from flask_restx import Resource
 from sqlalchemy import select
@@ -101,8 +101,10 @@ class OAuthCallback(Resource):
         try:
             token = oauth_provider.get_access_token(code)
             user_info = oauth_provider.get_user_info(token)
-        except requests.RequestException as e:
-            error_text = e.response.text if e.response else str(e)
+        except httpx.HTTPError as e:  # HTTPStatusError is not a RequestError subclass; catch their shared base
+            error_text = str(e)
+            if isinstance(e, httpx.HTTPStatusError):
+                error_text = e.response.text
             logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
             return {"error": "OAuth process failed"}, 400


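A note on the httpx exception hierarchy these hunks rely on: httpx.HTTPStatusError (raised by Response.raise_for_status()) and httpx.RequestError (connect/read/timeout failures) are sibling subclasses of httpx.HTTPError, so a handler that must see both families has to catch the common base class. A minimal standalone sketch:

    import httpx

    def fetch_text(url: str) -> str:
        try:
            response = httpx.get(url, timeout=10)
            response.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
            return response.text
        except httpx.HTTPError as e:  # common base of both error families
            if isinstance(e, httpx.HTTPStatusError):
                return e.response.text  # the server answered with an error status
            return str(e)  # transport-level failure (httpx.RequestError)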
@@ -782,7 +782,6 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TIDB_VECTOR
                 | VectorType.CHROMA
                 | VectorType.PGVECTO_RS
-                | VectorType.BAIDU
                 | VectorType.VIKINGDB
                 | VectorType.UPSTASH
             ):
@@ -809,6 +808,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TENCENT
                 | VectorType.MATRIXONE
                 | VectorType.CLICKZETTA
+                | VectorType.BAIDU
             ):
                 return {
                     "retrieval_method": [
@@ -838,7 +838,6 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.TIDB_VECTOR
                 | VectorType.CHROMA
                 | VectorType.PGVECTO_RS
-                | VectorType.BAIDU
                 | VectorType.VIKINGDB
                 | VectorType.UPSTASH
             ):
@@ -863,6 +862,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.HUAWEI_CLOUD
                 | VectorType.MATRIXONE
                 | VectorType.CLICKZETTA
+                | VectorType.BAIDU
             ):
                 return {
                     "retrieval_method": [

@@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
 from collections.abc import Sequence
 from typing import Literal, cast

+import sqlalchemy as sa
 from flask import request
 from flask_login import current_user
 from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -211,13 +212,13 @@ class DatasetDocumentListApi(Resource):

         if sort == "hit_count":
             sub_query = (
-                db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
+                sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
                 .group_by(DocumentSegment.document_id)
                 .subquery()
             )

             query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
-                sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
+                sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
                 sort_logic(Document.position),
             )
         elif sort == "created_at":

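The hunk above swaps flask-sqlalchemy's db.select/db.func for plain sqlalchemy, keeping the same ordering pattern: aggregate child rows in a subquery, outer-join it back, and coalesce missing aggregates to zero so parents without segments still sort. The same pattern in isolation (illustrative table names, not the dify models):

    import sqlalchemy as sa

    metadata = sa.MetaData()
    documents = sa.Table("documents", metadata, sa.Column("id", sa.Integer, primary_key=True))
    segments = sa.Table(
        "segments",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("document_id", sa.Integer),
        sa.Column("hit_count", sa.Integer),
    )

    # Sum hit counts per document, then order documents by that total (missing -> 0)
    hits = (
        sa.select(segments.c.document_id, sa.func.sum(segments.c.hit_count).label("total_hit_count"))
        .group_by(segments.c.document_id)
        .subquery()
    )
    query = (
        sa.select(documents)
        .outerjoin(hits, hits.c.document_id == documents.c.id)
        .order_by(sa.desc(sa.func.coalesce(hits.c.total_hit_count, 0)))
    )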
@@ -14,7 +14,10 @@ from controllers.console.wraps import (
 from extensions.ext_database import db
 from libs.login import login_required
 from models.dataset import PipelineCustomizedTemplate
-from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
+from services.entities.knowledge_entities.rag_pipeline_entities import (
+    PipelineBuiltInTemplateEntity,
+    PipelineTemplateInfoEntity,
+)
 from services.rag_pipeline.rag_pipeline import RagPipelineService

 logger = logging.getLogger(__name__)
@@ -26,12 +29,6 @@ def _validate_name(name):
     return name


-def _validate_description_length(description):
-    if len(description) > 400:
-        raise ValueError("Description cannot exceed 400 characters.")
-    return description
-
-
 class PipelineTemplateListApi(Resource):
     @setup_required
     @login_required
@@ -146,6 +143,186 @@ class PublishCustomizedPipelineTemplateApi(Resource):
         return {"result": "success"}


+class PipelineTemplateInstallApi(Resource):
+    """API endpoint for installing built-in pipeline templates"""
+
+    def post(self):
+        """
+        Install a built-in pipeline template
+
+        Args:
+            template_id: Optional template ID from form data (updates an existing template)
+
+        Returns:
+            Success response or error with appropriate HTTP status
+        """
+        try:
+            # Extract and validate Bearer token
+            auth_token = self._extract_bearer_token()
+
+            # Parse and validate request parameters
+            template_args = self._parse_template_args()
+
+            # Process uploaded template file
+            file_content = self._process_template_file()
+
+            # Create template entity
+            pipeline_built_in_template_entity = PipelineBuiltInTemplateEntity(**template_args)
+
+            # Install the template
+            rag_pipeline_service = RagPipelineService()
+            rag_pipeline_service.install_built_in_pipeline_template(
+                pipeline_built_in_template_entity, file_content, auth_token
+            )
+
+            return {"result": "success", "message": "Template installed successfully"}, 200
+
+        except ValueError as e:
+            logger.exception("Validation error in template installation")
+            return {"error": str(e)}, 400
+        except Exception:
+            logger.exception("Unexpected error in template installation")
+            return {"error": "An unexpected error occurred during template installation"}, 500
+
+    def _extract_bearer_token(self) -> str:
+        """
+        Extract and validate the Bearer token from the Authorization header
+
+        Returns:
+            The extracted token string
+
+        Raises:
+            ValueError: If the token is missing or invalid
+        """
+        auth_header = request.headers.get("Authorization", "").strip()
+
+        if not auth_header:
+            raise ValueError("Authorization header is required")
+
+        if not auth_header.startswith("Bearer "):
+            raise ValueError("Authorization header must start with 'Bearer '")
+
+        token_parts = auth_header.split(" ", 1)
+        if len(token_parts) != 2:
+            raise ValueError("Invalid Authorization header format")
+
+        auth_token = token_parts[1].strip()
+        if not auth_token:
+            raise ValueError("Bearer token cannot be empty")
+
+        return auth_token
+
+    def _parse_template_args(self) -> dict:
+        """
+        Parse and validate template arguments from form data
+
+        Returns:
+            Dictionary of validated template arguments
+        """
+        # Use reqparse for consistent parameter parsing
+        parser = reqparse.RequestParser()
+
+        parser.add_argument(
+            "template_id",
+            type=str,
+            location="form",
+            required=False,
+            help="Template ID for updating existing template",
+        )
+        parser.add_argument(
+            "language",
+            type=str,
+            location="form",
+            required=True,
+            default="en-US",
+            choices=["en-US", "zh-CN", "ja-JP"],
+            help="Template language code",
+        )
+        parser.add_argument(
+            "name",
+            type=str,
+            location="form",
+            required=True,
+            default="New Pipeline Template",
+            help="Template name (1-200 characters)",
+        )
+        parser.add_argument(
+            "description",
+            type=str,
+            location="form",
+            required=False,
+            default="",
+            help="Template description (max 1000 characters)",
+        )
+
+        args = parser.parse_args()
+
+        # Additional validation
+        if args.get("name"):
+            args["name"] = self._validate_name(args["name"])
+
+        if args.get("description") and len(args["description"]) > 1000:
+            raise ValueError("Description must not exceed 1000 characters")
+
+        # Filter out None values
+        return {k: v for k, v in args.items() if v is not None}
+
+    def _validate_name(self, name: str) -> str:
+        """
+        Validate the template name
+
+        Args:
+            name: Template name to validate
+
+        Returns:
+            Validated and trimmed name
+
+        Raises:
+            ValueError: If the name is invalid
+        """
+        name = name.strip()
+        if not name or len(name) > 200:
+            raise ValueError("Template name must be between 1 and 200 characters")
+        return name
+
+    def _process_template_file(self) -> str:
+        """
+        Process and validate the uploaded template file
+
+        Returns:
+            File content as a string
+
+        Raises:
+            ValueError: If the file is missing or invalid
+        """
+        if "file" not in request.files:
+            raise ValueError("Template file is required")
+
+        file = request.files["file"]
+
+        # Validate file
+        if not file or not file.filename:
+            raise ValueError("No file selected")
+
+        filename = file.filename.strip()
+        if not filename:
+            raise ValueError("File name cannot be empty")
+
+        # Check file extension
+        if not filename.lower().endswith(".pipeline"):
+            raise ValueError("Template file must be a pipeline file (.pipeline)")
+
+        try:
+            file_content = file.read().decode("utf-8")
+        except UnicodeDecodeError:
+            raise ValueError("Template file must be valid UTF-8 text")
+
+        return file_content
+
+
 api.add_resource(
     PipelineTemplateListApi,
     "/rag/pipeline/templates",
@@ -162,3 +339,7 @@ api.add_resource(
     PublishCustomizedPipelineTemplateApi,
     "/rag/pipelines/<string:pipeline_id>/customized/publish",
 )
+api.add_resource(
+    PipelineTemplateInstallApi,
+    "/rag/pipeline/built-in/templates/install",
+)
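A client-side sketch of the new install route: a multipart POST carrying the form fields parsed above, a Bearer token, and a .pipeline file. Host, token, and file name are placeholders; the route string is the one registered in this hunk, assuming the usual console API prefix:

    import httpx

    with open("my_template.pipeline", "rb") as f:
        response = httpx.post(
            "http://localhost:5001/console/api/rag/pipeline/built-in/templates/install",  # assumed prefix
            headers={"Authorization": "Bearer <token>"},
            data={"language": "en-US", "name": "My Pipeline Template", "description": "demo template"},
            files={"file": ("my_template.pipeline", f, "text/plain")},
        )
    response.raise_for_status()
    print(response.json())  # expected: {"result": "success", "message": "Template installed successfully"}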
@@ -1,7 +1,7 @@
 import json
 import logging

-import requests
+import httpx
 from flask_restx import Resource, fields, reqparse
 from packaging import version

@@ -57,7 +57,11 @@ class VersionApi(Resource):
             return result

         try:
-            response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
+            response = httpx.get(
+                check_update_url,
+                params={"current_version": args["current_version"]},
+                timeout=httpx.Timeout(10.0, connect=3.0),  # httpx needs a default or all four phases; mirrors requests' (3, 10)
+            )
         except Exception as error:
             logger.warning("Check update version error: %s.", str(error))
             result["version"] = args["current_version"]

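Why the timeout changed shape: requests accepts a (connect, read) tuple, but httpx.Timeout must be given either a positional default or all four phases (connect, read, write, pool) explicitly; httpx.Timeout(connect=3, read=10) alone raises ValueError at construction time, which is why the hunk above uses a default plus a connect override. A quick illustration (placeholder URL):

    import httpx

    timeout = httpx.Timeout(10.0, connect=3.0)  # read/write/pool default to 10s, connect capped at 3s
    response = httpx.get("https://example.com/version", params={"current_version": "1.0.0"}, timeout=timeout)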
@@ -516,20 +516,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
         parser.add_argument("provider", type=str, required=True, location="args")
         parser.add_argument("action", type=str, required=True, location="args")
         parser.add_argument("parameter", type=str, required=True, location="args")
-        parser.add_argument("credential_id", type=str, required=False, location="args")
         parser.add_argument("provider_type", type=str, required=True, location="args")
         args = parser.parse_args()

         try:
             options = PluginParameterService.get_dynamic_select_options(
-                tenant_id=tenant_id,
-                user_id=user_id,
-                plugin_id=args["plugin_id"],
-                provider=args["provider"],
-                action=args["action"],
-                parameter=args["parameter"],
-                credential_id=args["credential_id"],
-                provider_type=args["provider_type"],
+                tenant_id,
+                user_id,
+                args["plugin_id"],
+                args["provider"],
+                args["action"],
+                args["parameter"],
+                args["provider_type"],
             )
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

@@ -21,8 +21,8 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
 from core.mcp.error import MCPAuthError, MCPError
 from core.mcp.mcp_client import MCPClient
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.oauth import OAuthHandler
+from core.tools.entities.tool_entities import CredentialType
 from libs.helper import StrLen, alphanumeric, uuid_value
 from libs.login import login_required
 from models.provider_ids import ToolProviderID

@@ -1,589 +0,0 @@
import logging

from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
from models.provider_ids import TriggerProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService

logger = logging.getLogger(__name__)


class TriggerProviderListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """List all trigger providers for the current tenant"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))


class TriggerProviderInfoApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Get info for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        return jsonable_encoder(
            TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
        )


class TriggerSubscriptionListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """List all trigger subscriptions for the current tenant's provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            return jsonable_encoder(
                TriggerProviderService.list_trigger_provider_subscriptions(
                    tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
                )
            )
        except Exception as e:
            logger.exception("Error listing trigger providers", exc_info=e)
            raise


class TriggerSubscriptionBuilderCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        """Add a new subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
            subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                credential_type=credential_type,
            )
            return jsonable_encoder({"subscription_builder": subscription_builder})
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error adding provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderGetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, subscription_builder_id):
        """Get a subscription instance for a trigger provider"""
        return jsonable_encoder(
            TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
        )


class TriggerSubscriptionBuilderVerifyApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Verify a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
                subscription_builder_updater=SubscriptionBuilderUpdater(
                    credentials=args.get("credentials", None),
                ),
            )
            return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
            )
        except Exception as e:
            logger.exception("Error verifying provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Update a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        parser = reqparse.RequestParser()
        # The name of the subscription builder
        parser.add_argument("name", type=str, required=False, nullable=True, location="json")
        # The parameters of the subscription builder
        parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
        # The properties of the subscription builder
        parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()
        try:
            return jsonable_encoder(
                TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
                    tenant_id=user.current_tenant_id,
                    provider_id=TriggerProviderID(provider),
                    subscription_builder_id=subscription_builder_id,
                    subscription_builder_updater=SubscriptionBuilderUpdater(
                        name=args.get("name", None),
                        parameters=args.get("parameters", None),
                        properties=args.get("properties", None),
                        credentials=args.get("credentials", None),
                    ),
                )
            )
        except Exception as e:
            logger.exception("Error updating provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderLogsApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, subscription_builder_id):
        """Get the request logs for a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        try:
            logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
            return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
        except Exception as e:
            logger.exception("Error getting request logs for subscription builder", exc_info=e)
            raise


class TriggerSubscriptionBuilderBuildApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Build a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        # The name of the subscription builder
        parser.add_argument("name", type=str, required=False, nullable=True, location="json")
        # The parameters of the subscription builder
        parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
        # The properties of the subscription builder
        parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()
        try:
            TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
                subscription_builder_updater=SubscriptionBuilderUpdater(
                    name=args.get("name", None),
                    parameters=args.get("parameters", None),
                    properties=args.get("properties", None),
                ),
            )
            TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
            )
            return 200
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error building provider credential", exc_info=e)
            raise


class TriggerSubscriptionDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, subscription_id):
        """Delete a subscription instance"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            with Session(db.engine) as session:
                # Delete trigger provider subscription
                TriggerProviderService.delete_trigger_provider(
                    session=session,
                    tenant_id=user.current_tenant_id,
                    subscription_id=subscription_id,
                )
                # Delete plugin triggers
                WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
                    session=session,
                    tenant_id=user.current_tenant_id,
                    subscription_id=subscription_id,
                )
                session.commit()
                return {"result": "success"}
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error deleting provider credential", exc_info=e)
            raise


class TriggerOAuthAuthorizeApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Initiate OAuth authorization flow for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        try:
            provider_id = TriggerProviderID(provider)
            plugin_id = provider_id.plugin_id
            provider_name = provider_id.provider_name
            tenant_id = user.current_tenant_id

            # Get OAuth client configuration
            oauth_client_params = TriggerProviderService.get_oauth_client(
                tenant_id=tenant_id,
                provider_id=provider_id,
            )

            if oauth_client_params is None:
                raise Forbidden("No OAuth client configuration found for this trigger provider")

            # Create subscription builder
            subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
                tenant_id=tenant_id,
                user_id=user.id,
                provider_id=provider_id,
                credential_type=CredentialType.OAUTH2,
            )

            # Create OAuth handler and proxy context
            oauth_handler = OAuthHandler()
            context_id = OAuthProxyService.create_proxy_context(
                user_id=user.id,
                tenant_id=tenant_id,
                plugin_id=plugin_id,
                provider=provider_name,
                extra_data={
                    "subscription_builder_id": subscription_builder.id,
                },
            )

            # Build redirect URI for callback
            redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"

            # Get authorization URL
            authorization_url_response = oauth_handler.get_authorization_url(
                tenant_id=tenant_id,
                user_id=user.id,
                plugin_id=plugin_id,
                provider=provider_name,
                redirect_uri=redirect_uri,
                system_credentials=oauth_client_params,
            )

            # Create response with cookie
            response = make_response(
                jsonable_encoder(
                    {
                        "authorization_url": authorization_url_response.authorization_url,
                        "subscription_builder_id": subscription_builder.id,
                        "subscription_builder": subscription_builder,
                    }
                )
            )
            response.set_cookie(
                "context_id",
                context_id,
                httponly=True,
                samesite="Lax",
                max_age=OAuthProxyService.__MAX_AGE__,
            )

            return response

        except Exception as e:
            logger.exception("Error initiating OAuth flow", exc_info=e)
            raise


class TriggerOAuthCallbackApi(Resource):
    @setup_required
    def get(self, provider):
        """Handle OAuth callback for trigger provider"""
        context_id = request.cookies.get("context_id")
        if not context_id:
            raise Forbidden("context_id not found")

        # Use and validate proxy context
        context = OAuthProxyService.use_proxy_context(context_id)
        if context is None:
            raise Forbidden("Invalid context_id")

        # Parse provider ID
        provider_id = TriggerProviderID(provider)
        plugin_id = provider_id.plugin_id
        provider_name = provider_id.provider_name
        user_id = context.get("user_id")
        tenant_id = context.get("tenant_id")
        subscription_builder_id = context.get("subscription_builder_id")

        # Get OAuth client configuration
        oauth_client_params = TriggerProviderService.get_oauth_client(
            tenant_id=tenant_id,
            provider_id=provider_id,
        )

        if oauth_client_params is None:
            raise Forbidden("No OAuth client configuration found for this trigger provider")

        # Get OAuth credentials from callback
        oauth_handler = OAuthHandler()
        redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"

        credentials_response = oauth_handler.get_credentials(
            tenant_id=tenant_id,
            user_id=user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            redirect_uri=redirect_uri,
            system_credentials=oauth_client_params,
            request=request,
        )

        credentials = credentials_response.credentials
        expires_at = credentials_response.expires_at

        if not credentials:
            raise Exception("Failed to get OAuth credentials")

        # Update subscription builder
        TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
            tenant_id=tenant_id,
            provider_id=provider_id,
            subscription_builder_id=subscription_builder_id,
            subscription_builder_updater=SubscriptionBuilderUpdater(
                credentials=credentials,
                credential_expires_at=expires_at,
            ),
        )
        # Redirect to OAuth callback page
        return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")


class TriggerOAuthClientManageApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Get OAuth client configuration for a provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            provider_id = TriggerProviderID(provider)

            # Get custom OAuth client params if exists
            custom_params = TriggerProviderService.get_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )

            # Check if custom client is enabled
            is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )

            # Check if there's a system OAuth client
            system_client = TriggerProviderService.get_oauth_client(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )
            provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
            redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
            return jsonable_encoder(
                {
                    "configured": bool(custom_params or system_client),
                    "oauth_client_schema": provider_controller.get_oauth_client_schema(),
                    "custom_configured": bool(custom_params),
                    "custom_enabled": is_custom_enabled,
                    "redirect_uri": redirect_uri,
                    "params": custom_params or {},
                }
            )

        except Exception as e:
            logger.exception("Error getting OAuth client", exc_info=e)
            raise

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        """Configure custom OAuth client for a provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            provider_id = TriggerProviderID(provider)
            return TriggerProviderService.save_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
                client_params=args.get("client_params"),
                enabled=args.get("enabled"),
            )

        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error configuring OAuth client", exc_info=e)
            raise

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider):
        """Remove custom OAuth client configuration"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            provider_id = TriggerProviderID(provider)

            return TriggerProviderService.delete_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error removing OAuth client", exc_info=e)
            raise


# Trigger Subscription
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
api.add_resource(
    TriggerSubscriptionDeleteApi,
    "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)

# Trigger Subscription Builder
api.add_resource(
    TriggerSubscriptionBuilderCreateApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
api.add_resource(
    TriggerSubscriptionBuilderGetApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderUpdateApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderVerifyApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderBuildApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderLogsApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)


# OAuth
api.add_resource(
    TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
@@ -9,9 +9,10 @@ from controllers.console.app.mcp_server import AppMCPServerStatus
 from controllers.mcp import mcp_ns
 from core.app.app_config.entities import VariableEntity
 from core.mcp import types as mcp_types
+from core.mcp.server.streamable_http import handle_mcp_request
 from extensions.ext_database import db
 from libs import helper
-from models.model import App, AppMCPServer, AppMode
+from models.model import App, AppMCPServer, AppMode, EndUser


 class MCPRequestError(Exception):
@@ -194,6 +195,50 @@ class MCPAppApi(Resource):
         except ValidationError as e:
             raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")

-        mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
-        response = mcp_server_handler.handle()
-        return helper.compact_generate_response(response)
+    def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
+        """Get end user from existing session - optimized query"""
+        return (
+            session.query(EndUser)
+            .where(EndUser.tenant_id == tenant_id)
+            .where(EndUser.session_id == mcp_server_id)
+            .where(EndUser.type == "mcp")
+            .first()
+        )
+
+    def _create_end_user(
+        self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
+    ) -> EndUser:
+        """Create end user in existing session"""
+        end_user = EndUser(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            type="mcp",
+            name=client_name,
+            session_id=mcp_server_id,
+        )
+        session.add(end_user)
+        session.flush()  # Use flush instead of commit to keep transaction open
+        session.refresh(end_user)
+        return end_user
+
+    def _handle_mcp_request(
+        self,
+        app: App,
+        mcp_server: AppMCPServer,
+        mcp_request: mcp_types.ClientRequest,
+        user_input_form: list[VariableEntity],
+        session: Session,
+        request_id: Union[int, str],
+    ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
+        """Handle MCP request and return response"""
+        end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
+
+        if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
+            client_info = mcp_request.root.params.clientInfo
+            client_name = f"{client_info.name}@{client_info.version}"
+            # Commit the session before creating end user to avoid transaction conflicts
+            session.commit()
+            with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
+                end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
+
+        return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)

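The flush-then-refresh in _create_end_user, combined with the commit in _handle_mcp_request, keeps the caller-owned session clean: the lookup runs in the caller's session, while the insert happens in a dedicated short-lived session so a failed write cannot poison the caller's unit of work. The same pattern in isolation (generic engine and illustrative field values, not the full dify model schema):

    from sqlalchemy.orm import Session

    def get_or_create_mcp_end_user(engine, tenant_id: str, app_id: str, mcp_server_id: str):
        # Read in one session...
        with Session(engine) as lookup:
            existing = (
                lookup.query(EndUser)
                .where(EndUser.tenant_id == tenant_id)
                .where(EndUser.session_id == mcp_server_id)
                .where(EndUser.type == "mcp")
                .first()
            )
        if existing:
            return existing
        # ...and write in a dedicated transaction of its own.
        with Session(engine, expire_on_commit=False) as writer, writer.begin():
            end_user = EndUser(
                tenant_id=tenant_id, app_id=app_id, type="mcp", name="client@0.0.0", session_id=mcp_server_id
            )
            writer.add(end_user)
            writer.flush()
            writer.refresh(end_user)
        return end_user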
@@ -1,7 +0,0 @@
from flask import Blueprint

# Create trigger blueprint
bp = Blueprint("trigger", __name__, url_prefix="/triggers")

# Import routes after blueprint creation to avoid circular imports
from . import trigger, webhook
@@ -1,41 +0,0 @@
import logging
import re

from flask import jsonify, request
from werkzeug.exceptions import NotFound

from controllers.trigger import bp
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger_service import TriggerService

logger = logging.getLogger(__name__)

UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
UUID_MATCHER = re.compile(UUID_PATTERN)


@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_endpoint(endpoint_id: str):
    """
    Handle endpoint trigger calls.
    """
    # endpoint_id must be UUID
    if not UUID_MATCHER.match(endpoint_id):
        raise NotFound("Invalid endpoint ID")
    handling_chain = [
        TriggerService.process_endpoint,
        TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
    ]
    try:
        for handler in handling_chain:
            response = handler(endpoint_id, request)
            if response:
                break
        if not response:
            raise NotFound("Endpoint not found")
        return response
    except ValueError as e:
        raise NotFound(str(e))
    except Exception as e:
        logger.exception("Webhook processing failed for %s", endpoint_id)
        return jsonify({"error": "Internal server error", "message": str(e)}), 500
@@ -1,46 +0,0 @@
import logging

from flask import jsonify
from werkzeug.exceptions import NotFound, RequestEntityTooLarge

from controllers.trigger import bp
from services.webhook_service import WebhookService

logger = logging.getLogger(__name__)


@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook(webhook_id: str):
    """
    Handle webhook trigger calls.

    This endpoint receives webhook calls and processes them according to the
    configured webhook trigger settings.
    """
    try:
        # Get webhook trigger, workflow, and node configuration
        webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)

        # Extract request data
        webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

        # Validate request against node configuration
        validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
        if not validation_result["valid"]:
            return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400

        # Process webhook call (send to Celery)
        WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)

        # Return configured response
        response_data, status_code = WebhookService.generate_webhook_response(node_config)
        return jsonify(response_data), status_code

    except ValueError as e:
        raise NotFound(str(e))
    except RequestEntityTooLarge:
        raise
    except Exception as e:
        logger.exception("Webhook processing failed for %s", webhook_id)
        return jsonify({"error": "Internal server error", "message": str(e)}), 500
@@ -261,6 +261,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
             questions = MessageService.get_suggested_questions_after_answer(
                 app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
             )
+            # questions is a list of strings, not a list of Message objects,
+            # so we can directly return it
         except MessageNotExistsError:
             raise NotFound("Message not found")
         except ConversationNotExistsError:

@@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         if not app_record:
             raise ValueError("App not found")

-        if self.application_generate_entity.single_iteration_run:
-            # if only single iteration run is requested
-            graph_runtime_state = GraphRuntimeState(
-                variable_pool=VariablePool.empty(),
-                start_at=time.time(),
-            )
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+            # Handle single iteration or single loop run
+            graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
                 workflow=self._workflow,
-                node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
-                graph_runtime_state=graph_runtime_state,
-            )
-        elif self.application_generate_entity.single_loop_run:
-            # if only single loop run is requested
-            graph_runtime_state = GraphRuntimeState(
-                variable_pool=VariablePool.empty(),
-                start_at=time.time(),
-            )
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
-                workflow=self._workflow,
-                node_id=self.application_generate_entity.single_loop_run.node_id,
-                user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
-                graph_runtime_state=graph_runtime_state,
+                single_iteration_run=self.application_generate_entity.single_iteration_run,
+                single_loop_run=self.application_generate_entity.single_loop_run,
             )
         else:
             inputs = self.application_generate_entity.inputs

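The consolidated helper _prepare_single_node_execution is referenced by the new call sites but its body is not part of this comparison. A plausible shape can be inferred from the two branches it replaces; treat this as a sketch (the name and return tuple come from the call sites, the body from the deleted lines), not the actual implementation:

    def _prepare_single_node_execution(self, *, workflow, single_iteration_run=None, single_loop_run=None):
        # Build the shared runtime state once instead of per branch
        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool.empty(),
            start_at=time.time(),
        )
        run = single_iteration_run or single_loop_run
        if single_iteration_run:
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=run.node_id,
                user_inputs=dict(run.inputs),
                graph_runtime_state=graph_runtime_state,
            )
        else:
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=workflow,
                node_id=run.node_id,
                user_inputs=dict(run.inputs),
                graph_runtime_state=graph_runtime_state,
            )
        return graph, variable_pool, graph_runtime_state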
@@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
             invoke_from=InvokeFrom.DEBUGGER,
             call_depth=0,
             workflow_execution_id=str(uuid.uuid4()),
+            single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
+                node_id=node_id, inputs=args["inputs"]
+            ),
         )
         contexts.plugin_tool_providers.set({})
         contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
             workflow_node_execution_repository=workflow_node_execution_repository,
             streaming=streaming,
             variable_loader=var_loader,
+            context=contextvars.copy_context(),
         )

     def single_loop_generate(
@@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
             workflow_node_execution_repository=workflow_node_execution_repository,
             streaming=streaming,
             variable_loader=var_loader,
+            context=contextvars.copy_context(),
         )

     def _generate_worker(

@@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
         db.session.close()

-        # if only single iteration run is requested
-        if self.application_generate_entity.single_iteration_run:
-            graph_runtime_state = GraphRuntimeState(
-                variable_pool=VariablePool.empty(),
-                start_at=time.time(),
-            )
-            # if only single iteration run is requested
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+            # Handle single iteration or single loop run
+            graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
                 workflow=workflow,
-                node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
-                graph_runtime_state=graph_runtime_state,
-            )
-        elif self.application_generate_entity.single_loop_run:
-            graph_runtime_state = GraphRuntimeState(
-                variable_pool=VariablePool.empty(),
-                start_at=time.time(),
-            )
-            # if only single loop run is requested
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
-                workflow=workflow,
-                node_id=self.application_generate_entity.single_loop_run.node_id,
-                user_inputs=self.application_generate_entity.single_loop_run.inputs,
-                graph_runtime_state=graph_runtime_state,
+                single_iteration_run=self.application_generate_entity.single_iteration_run,
+                single_loop_run=self.application_generate_entity.single_loop_run,
             )
         else:
             inputs = self.application_generate_entity.inputs

@@ -3,7 +3,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Union, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -53,8 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[True],
         call_depth: int,
-        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
-        root_node_id: Optional[str] = None,
     ) -> Generator[Mapping | str, None, None]: ...

     @overload
@@ -68,8 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[False],
         call_depth: int,
-        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
-        root_node_id: Optional[str] = None,
     ) -> Mapping[str, Any]: ...

     @overload
@@ -83,8 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool,
         call_depth: int,
-        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
-        root_node_id: Optional[str] = None,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

     def generate(
@@ -97,8 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool = True,
         call_depth: int = 0,
-        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
-        root_node_id: Optional[str] = None,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
         files: Sequence[Mapping[str, Any]] = args.get("files") or []

@@ -127,26 +119,24 @@ class WorkflowAppGenerator(BaseAppGenerator):
             app_id=app_model.id,
             user_id=user.id if isinstance(user, Account) else user.session_id,
         )

         inputs: Mapping[str, Any] = args["inputs"]

         extras = {
             **extract_external_trace_id_from_args(args),
         }
         workflow_run_id = str(uuid.uuid4())
-        if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
-            # start node get inputs
-            inputs = self._prepare_user_inputs(
-                user_inputs=inputs,
-                variables=app_config.variables,
-                tenant_id=app_model.tenant_id,
-                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
-            )
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             file_upload_config=file_extra_config,
-            inputs=inputs,
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs,
+                variables=app_config.variables,
+                tenant_id=app_model.tenant_id,
+                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
+            ),
             files=list(system_files),
             user_id=user.id,
             stream=streaming,
@@ -165,10 +155,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create session factory
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         # Create workflow execution(aka workflow run) repository
-        if triggered_from is not None:
-            # Use explicitly provided triggered_from (for async triggers)
-            workflow_triggered_from = triggered_from
-        elif invoke_from == InvokeFrom.DEBUGGER:
+        if invoke_from == InvokeFrom.DEBUGGER:
             workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
         else:
             workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
@@ -195,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
             streaming=streaming,
-            root_node_id=root_node_id,
         )

     def _generate(
@@ -210,7 +196,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         streaming: bool = True,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
-        root_node_id: Optional[str] = None,
     ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
         Generate App response.
@@ -246,7 +231,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 "queue_manager": queue_manager,
                 "context": context,
                 "variable_loader": variable_loader,
-                "root_node_id": root_node_id,
             },
         )

@@ -440,16 +424,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
         queue_manager: AppQueueManager,
         context: contextvars.Context,
         variable_loader: VariableLoader,
-        root_node_id: Optional[str] = None,
     ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
         :param application_generate_entity: application generate entity
         :param queue_manager: queue manager
         :param workflow_thread_pool_id: workflow thread pool id
         :return:
         """

         with preserve_flask_contexts(flask_app, context_vars=context):
             with Session(db.engine, expire_on_commit=False) as session:
                 workflow = session.scalar(
@@ -482,7 +465,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 variable_loader=variable_loader,
                 workflow=workflow,
                 system_user_id=system_user_id,
-                root_node_id=root_node_id,
             )

             try:

@ -34,7 +34,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
root_node_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -44,7 +43,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._root_node_id = root_node_id
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
@ -53,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -107,7 +87,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
root_node_id=self._root_node_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
|
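Aside on the hunk above: the two near-duplicate debugging branches (single iteration, single loop) collapse into a single `_prepare_single_node_execution` call that receives both entities and picks the non-empty one. A minimal sketch of that dispatch shape, using hypothetical stand-in types rather than Dify's real entity classes:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SingleNodeRun:
    # Stand-in for SingleIterationRunEntity / SingleLoopRunEntity.
    node_id: str
    inputs: dict[str, Any]


def dispatch(single_iteration_run: Optional[SingleNodeRun], single_loop_run: Optional[SingleNodeRun]) -> str:
    # One branch now covers both single-node modes; the helper it calls
    # decides internally which of the two entities it was handed.
    if single_iteration_run or single_loop_run:
        entity = single_iteration_run or single_loop_run
        assert entity is not None
        return f"single-node execution of {entity.node_id}"
    return "full graph execution"
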
@@ -1,5 +1,6 @@
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -79,7 +80,6 @@ class WorkflowBasedAppRunner:
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
root_node_id: Optional[str] = None,
) -> Graph:
"""
Init graph
@@ -113,22 +113,88 @@ class WorkflowBasedAppRunner:
)

# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

if not graph:
raise ValueError("graph not found in workflow")

return graph

def _get_graph_and_variable_pool_of_single_iteration(
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).

Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise

Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)

Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)

# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")

# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state

def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str,  # 'iteration_id' or 'loop_id'
node_type_label: str = "node",  # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
Get graph and variable pool for single node execution (iteration or loop).

Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')

Returns:
A tuple containing (graph, variable_pool)
"""
# fetch workflow graph
graph_config = workflow.graph_dict
@@ -146,23 +212,27 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")

# filter nodes only in iteration
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]

graph_config["nodes"] = node_configs

node_ids = [node.get("id") for node in node_configs]

# filter edges only in iteration
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]

graph_config["edges"] = edge_configs
@@ -191,30 +261,26 @@ class WorkflowBasedAppRunner:
raise ValueError("graph not found in workflow")

# fetch node config from node id
iteration_node_config = None
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
iteration_node_config = node
target_node_config = node
break

if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")

# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool

try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config
graph_config=workflow.graph_dict, config=target_node_config
)
except NotImplementedError:
variable_mapping = {}
@@ -235,120 +301,44 @@ class WorkflowBasedAppRunner:

return graph, variable_pool

def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)

def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")

graph_config = cast(dict[str, Any], graph_config)

if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")

if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")

if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")

# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]

graph_config["nodes"] = node_configs

node_ids = [node.get("id") for node in node_configs]

# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]

graph_config["edges"] = edge_configs

# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)

node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)

# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)

if not graph:
raise ValueError("graph not found in workflow")

# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break

if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")

# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)

try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)

WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)

return graph, variable_pool

def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event

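The key invariant in the new `_prepare_single_node_execution` above is that the variable pool handed to the graph builder is the very object stored inside the returned `GraphRuntimeState`, so every node created later shares one state. A minimal sketch of that invariant with simplified stand-in classes (not the real Dify types):

import time
from dataclasses import dataclass, field
from typing import Any


@dataclass
class VariablePool:
    values: dict[str, Any] = field(default_factory=dict)


@dataclass
class GraphRuntimeState:
    variable_pool: VariablePool
    start_at: float


def prepare_single_node_execution() -> tuple[VariablePool, GraphRuntimeState]:
    # Build the runtime state first, then reuse ITS pool everywhere;
    # never construct a second VariablePool for the same run.
    state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.time())
    return state.variable_pool, state


pool, state = prepare_single_node_execution()
assert pool is state.variable_pool  # the identity the refactor enforces
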
@@ -1,388 +0,0 @@
import re
import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError

from flask import request
from requests import get
from yaml import YAMLError, safe_load  # type: ignore

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError


class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}

# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")

if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")

server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
server_url = matched_servers[0] if matched_servers else server_url

# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)

# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)

# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ

parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue

if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root

# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:  # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]  # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)

# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ

parameters.append(tool)

# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."

# check if there is a operation id, use $path_$method as operation id if not
if "operationId" not in interface["operation"]:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = interface["path"]
if interface["path"].startswith("/"):
path = interface["path"][1:]
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = str(uuid.uuid4())

interface["operation"]["operationId"] = f"{path}_{interface['method']}"

bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)

return bundles

@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: str | None = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE

if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]

if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
else:
return None

@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle

:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}

openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
"""
parse swagger to openapi

:param swagger: the swagger dict
:return: the openapi dict
"""
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})

servers = swagger.get("servers", [])

if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")

openapi = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}

# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")

# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}  # pyright: ignore[reportIndexIssue]
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")

if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."

openapi["paths"][path][method] = {  # pyright: ignore[reportIndexIssue]
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}

if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]  # pyright: ignore[reportIndexIssue]

# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition  # pyright: ignore[reportIndexIssue, reportArgumentType]

return openapi

@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle

:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}

try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")

if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")

# get openapi yaml
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)

if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")

return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)

@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle

:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}

content = content.strip()
loaded_content = None
json_error = None
yaml_error = None

try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e

if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)

swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None

try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e

# openai parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e

# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e

raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)
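For the record, the deleted parser's operationId fallback (shown in the hunk above) is small enough to verify in isolation; this self-contained reproduction follows the same strip-then-sanitize-then-UUID steps:

import re
import uuid


def fallback_operation_id(path: str, method: str) -> str:
    # Drop one leading slash, then remove anything outside [a-zA-Z0-9_-]
    # so the id satisfies OpenAPI's ^[a-zA-Z0-9_-]{1,64}$ constraint.
    if path.startswith("/"):
        path = path[1:]
    path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
    if not path:
        path = str(uuid.uuid4())
    return f"{path}_{method}"


assert fallback_operation_id("/pets/{petId}", "get") == "petspetId_get"
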
@@ -1,17 +0,0 @@
import re


def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.

Args:
text (str): The input text to process.

Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)
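The deleted `remove_leading_symbols` helper is pure, so its behaviour is easy to pin down with a quick check (same pattern as above):

import re


def remove_leading_symbols(text: str) -> str:
    # Strip a leading run of ASCII punctuation plus the general/CJK
    # punctuation blocks U+2000-206F, U+2E00-2E7F and U+3000-303F.
    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
    return re.sub(pattern, "", text)


assert remove_leading_symbols("...answer") == "answer"
assert remove_leading_symbols("answer") == "answer"  # untouched without a leading symbol
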
@@ -1,9 +0,0 @@
import uuid


def is_valid_uuid(uuid_str: str) -> bool:
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False
@@ -1,43 +0,0 @@
from collections.abc import Mapping, Sequence
from typing import Any

from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration


class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)

@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)

if not start_node:
return []

return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]

@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
check is synced

raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]

if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")

for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")
@@ -1,35 +0,0 @@
import logging
from pathlib import Path
from typing import Any

import yaml  # type: ignore
from yaml import YAMLError

logger = logging.getLogger(__name__)


def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
"""
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
if True, return default_value if error occurs and the error will be logged in debug level
if False, raise error if error occurs
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise FileNotFoundError(f"File not found: {file_path}")

with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
@@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
"""
Get custom provider record.
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
Provider.provider_name.in_(self._get_provider_names()),
)

return session.execute(stmt).scalar_one_or_none()
@@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.credential_name == credential_name,
)
if exclude_id:
@@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
@@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
),
)

@@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
@@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"

def _get_provider_names(self):
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names

def create_provider_credential(self, credentials: dict, credential_name: str | None):
"""
Add custom provider credentials.
@@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)

# Get the credential record to update
@@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
# Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source,
)
@@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)

# Get the credential record to update
@@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
# Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
)
@@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name,
@@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):

lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
)
@@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""

model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
@@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
return

def _switch(s: Session):
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
)
preferred_model_provider = s.execute(stmt).scalars().first()

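The repeated inline lookups above are folded into `_get_provider_names()`. Its name expansion can be restated standalone; this sketch uses a plain string split as a hypothetical stand-in for the real ModelProviderID class:

def get_provider_names(stored_provider: str) -> list[str]:
    # A provider may be stored either as a bare name ("openai") or as a
    # fully qualified plugin id ("langgenius/openai/openai"); queries
    # must match both spellings.
    names = [stored_provider]
    parts = stored_provider.split("/")
    if len(parts) == 3 and parts[0] == "langgenius":
        names.append(parts[2])
    return names


assert get_provider_names("langgenius/openai/openai") == ["langgenius/openai/openai", "openai"]
assert get_provider_names("openai") == ["openai"]
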
@@ -207,7 +207,6 @@ class ProviderConfig(BasicProviderConfig):
required: bool = False
default: Union[int, str, float, bool] | None = None
options: list[Option] | None = None
multiple: bool | None = False
label: I18nObject | None = None
help: I18nObject | None = None
url: str | None = None

@@ -3,7 +3,7 @@ import re
from collections.abc import Sequence
from typing import Any

from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.tool_entities import CredentialType

logger = logging.getLogger(__name__)


@@ -1,128 +0,0 @@
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol

from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter


class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""

def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...

def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...

def delete(self) -> None:
"""Delete cached provider configuration"""
...


class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache

def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache

def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)

def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id

return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)

# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential

for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted

return data

def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask credentials

return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)

# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential

for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])

return data

def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
return self.mask_credentials(data)

def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id

return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials

data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential

for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue

data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])

self.provider_config_cache.set(data)
return data


def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
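The deleted encrypter's masking rule is worth restating: secrets longer than six characters keep two characters at each end, anything shorter is fully starred. A minimal reproduction:

def mask_secret(value: str) -> str:
    # Keep the first and last two characters only when there is enough
    # material left to hide; otherwise mask the whole value.
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)


assert mask_secret("sk-abcdef123") == "sk********23"
assert mask_secret("abc") == "***"
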
@@ -8,7 +8,7 @@ from collections import deque
from collections.abc import Sequence
from datetime import datetime

import requests
import httpx
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
@@ -65,13 +65,13 @@ class TraceClient:

def api_check(self):
try:
response = requests.head(self.endpoint, timeout=5)
response = httpx.head(self.endpoint, timeout=5)
if response.status_code == 405:
return True
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except requests.RequestException as e:
except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")

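The requests-to-httpx swap above is mechanical for this probe: `httpx.head` takes the same arguments here, and the only behavioural difference is the exception hierarchy (`httpx.RequestError` instead of `requests.RequestException`). A sketch of the probe in isolation, assuming any reachable collector URL:

import httpx


def api_check(endpoint: str) -> bool:
    # A HEAD answered with 405 (method not allowed) still proves the
    # endpoint is alive, which is all this health probe needs.
    try:
        response = httpx.head(endpoint, timeout=5)
        return response.status_code == 405
    except httpx.RequestError as e:
        raise ValueError(f"AliyunTrace API check failed: {e}")
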
||||
@ -417,7 +417,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
print("Weave login successful")
|
||||
logger.info("Weave login successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("Weave API check failed: %s", str(e))
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
@ -13,7 +13,6 @@ from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(StrEnum):
|
||||
@@ -28,56 +27,54 @@ class PluginResourceRequirements(BaseModel):

    class Permission(BaseModel):
        class Tool(BaseModel):
            enabled: Optional[bool] = Field(default=False)
            enabled: bool | None = Field(default=False)

        class Model(BaseModel):
            enabled: Optional[bool] = Field(default=False)
            llm: Optional[bool] = Field(default=False)
            text_embedding: Optional[bool] = Field(default=False)
            rerank: Optional[bool] = Field(default=False)
            tts: Optional[bool] = Field(default=False)
            speech2text: Optional[bool] = Field(default=False)
            moderation: Optional[bool] = Field(default=False)
            enabled: bool | None = Field(default=False)
            llm: bool | None = Field(default=False)
            text_embedding: bool | None = Field(default=False)
            rerank: bool | None = Field(default=False)
            tts: bool | None = Field(default=False)
            speech2text: bool | None = Field(default=False)
            moderation: bool | None = Field(default=False)

        class Node(BaseModel):
            enabled: Optional[bool] = Field(default=False)
            enabled: bool | None = Field(default=False)

        class Endpoint(BaseModel):
            enabled: Optional[bool] = Field(default=False)
            enabled: bool | None = Field(default=False)

        class Storage(BaseModel):
            enabled: Optional[bool] = Field(default=False)
            enabled: bool | None = Field(default=False)
            size: int = Field(ge=1024, le=1073741824, default=1048576)

        tool: Optional[Tool] = Field(default=None)
        model: Optional[Model] = Field(default=None)
        node: Optional[Node] = Field(default=None)
        endpoint: Optional[Endpoint] = Field(default=None)
        storage: Optional[Storage] = Field(default=None)
        tool: Tool | None = Field(default=None)
        model: Model | None = Field(default=None)
        node: Node | None = Field(default=None)
        endpoint: Endpoint | None = Field(default=None)
        storage: Storage | None = Field(default=None)

    permission: Optional[Permission] = Field(default=None)
    permission: Permission | None = Field(default=None)


class PluginCategory(StrEnum):
    Tool = auto()
    Model = auto()
    Extension = auto()
    AgentStrategy = auto()
    Datasource = auto()
    Trigger = auto()
    AgentStrategy = "agent-strategy"
    Datasource = "datasource"
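Context for the PluginCategory change above: on a StrEnum, auto() yields the lower-cased member name, so AgentStrategy = auto() would have produced the value "agentstrategy"; the explicit string preserves the hyphenated wire format. A standalone sketch, not part of this changeset:

from enum import StrEnum, auto

class Demo(StrEnum):
    AgentStrategy = auto()                     # value becomes "agentstrategy"
    AgentStrategyExplicit = "agent-strategy"   # explicit value keeps the hyphen

assert Demo.AgentStrategy.value == "agentstrategy"
assert Demo.AgentStrategyExplicit.value == "agent-strategy"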


class PluginDeclaration(BaseModel):
    class Plugins(BaseModel):
        tools: Optional[list[str]] = Field(default_factory=list[str])
        models: Optional[list[str]] = Field(default_factory=list[str])
        endpoints: Optional[list[str]] = Field(default_factory=list[str])
        triggers: Optional[list[str]] = Field(default_factory=list[str])
        tools: list[str] | None = Field(default_factory=list[str])
        models: list[str] | None = Field(default_factory=list[str])
        endpoints: list[str] | None = Field(default_factory=list[str])
        datasources: list[str] | None = Field(default_factory=list[str])

    class Meta(BaseModel):
        minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
        version: Optional[str] = Field(default=None)
        minimum_dify_version: str | None = Field(default=None)
        version: str | None = Field(default=None)

        @field_validator("minimum_dify_version")
        @classmethod
@@ -90,26 +87,25 @@ class PluginDeclaration(BaseModel):
            except InvalidVersion as e:
                raise ValueError(f"Invalid version format: {v}") from e

    version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
    author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
    version: str = Field(...)
    author: str | None = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
    name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
    description: I18nObject
    icon: str
    icon_dark: Optional[str] = Field(default=None)
    icon_dark: str | None = Field(default=None)
    label: I18nObject
    category: PluginCategory
    created_at: datetime.datetime
    resource: PluginResourceRequirements
    plugins: Plugins
    tags: list[str] = Field(default_factory=list)
    repo: Optional[str] = Field(default=None)
    repo: str | None = Field(default=None)
    verified: bool = Field(default=False)
    tool: Optional[ToolProviderEntity] = None
    model: Optional[ProviderEntity] = None
    endpoint: Optional[EndpointProviderDeclaration] = None
    agent_strategy: Optional[AgentStrategyProviderEntity] = None
    tool: ToolProviderEntity | None = None
    model: ProviderEntity | None = None
    endpoint: EndpointProviderDeclaration | None = None
    agent_strategy: AgentStrategyProviderEntity | None = None
    datasource: DatasourceProviderEntity | None = None
    trigger: Optional[TriggerProviderEntity] = None
    meta: Meta

    @field_validator("version")
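Aside, not part of this changeset: the version pattern dropped from the `version` field above accepts two to four dot-separated numeric groups plus an optional dash suffix. A quick standalone check of what that regex matches:

import re

VERSION_PATTERN = r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$"

assert re.match(VERSION_PATTERN, "0.6.8")
assert re.match(VERSION_PATTERN, "1.2.3.4-beta1")
assert not re.match(VERSION_PATTERN, "1")          # at least one dot required
assert not re.match(VERSION_PATTERN, "1.2.3.4.5")  # at most four numeric groups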
@@ -123,7 +119,7 @@ class PluginDeclaration(BaseModel):

    @model_validator(mode="before")
    @classmethod
    def validate_category(cls, values: dict) -> dict:
    def validate_category(cls, values: dict):
        # auto detect category
        if values.get("tool"):
            values["category"] = PluginCategory.Tool
@@ -133,8 +129,6 @@ class PluginDeclaration(BaseModel):
            values["category"] = PluginCategory.Datasource
        elif values.get("agent_strategy"):
            values["category"] = PluginCategory.AgentStrategy
        elif values.get("trigger"):
            values["category"] = PluginCategory.Trigger
        else:
            values["category"] = PluginCategory.Extension
        return values
@@ -196,9 +190,9 @@ class PluginDependency(BaseModel):

    type: Type
    value: Github | Marketplace | Package
    current_identifier: Optional[str] = None
    current_identifier: str | None = None


class MissingPluginDependency(BaseModel):
    plugin_unique_identifier: str
    current_identifier: Optional[str] = None
    current_identifier: str | None = None

@@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum, auto
from enum import StrEnum
from typing import Any, Generic, TypeVar

from pydantic import BaseModel, ConfigDict, Field
@@ -14,7 +14,6 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity

T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))

@@ -206,49 +205,3 @@ class PluginListResponse(BaseModel):

class PluginDynamicSelectOptionsResponse(BaseModel):
    options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")


class PluginTriggerProviderEntity(BaseModel):
    provider: str
    plugin_unique_identifier: str
    plugin_id: str
    declaration: TriggerProviderEntity


class CredentialType(StrEnum):
    API_KEY = "api-key"
    OAUTH2 = auto()
    UNAUTHORIZED = auto()

    def get_name(self):
        if self == CredentialType.API_KEY:
            return "API KEY"
        elif self == CredentialType.OAUTH2:
            return "AUTH"
        elif self == CredentialType.UNAUTHORIZED:
            return "UNAUTHORIZED"
        else:
            return self.value.replace("-", " ").upper()

    def is_editable(self):
        return self == CredentialType.API_KEY

    def is_validate_allowed(self):
        return self == CredentialType.API_KEY

    @classmethod
    def values(cls):
        return [item.value for item in cls]

    @classmethod
    def of(cls, credential_type: str) -> "CredentialType":
        type_name = credential_type.lower()
        if type_name in {"api-key", "api_key"}:
            return cls.API_KEY
        elif type_name in {"oauth2", "oauth"}:
            return cls.OAUTH2
        elif type_name == "unauthorized":
            return cls.UNAUTHORIZED
        else:
            raise ValueError(f"Invalid credential type: {credential_type}")

@@ -1,7 +1,5 @@
from collections.abc import Mapping
from typing import Any, Literal

from flask import Response
from pydantic import BaseModel, ConfigDict, Field, field_validator

from core.entities.provider_entities import BasicProviderConfig
@@ -239,33 +237,3 @@ class RequestFetchAppInfo(BaseModel):
    """

    app_id: str


class Event(BaseModel):
    variables: Mapping[str, Any]


class TriggerInvokeResponse(BaseModel):
    event: Event


class PluginTriggerDispatchResponse(BaseModel):
    triggers: list[str]
    raw_http_response: str


class TriggerSubscriptionResponse(BaseModel):
    subscription: dict[str, Any]


class TriggerValidateProviderCredentialsResponse(BaseModel):
    result: bool


class TriggerDispatchResponse:
    triggers: list[str]
    response: Response

    def __init__(self, triggers: list[str], response: Response):
        self.triggers = triggers
        self.response = response

@@ -15,7 +15,6 @@ class DynamicSelectClient(BasePluginClient):
        provider: str,
        action: str,
        credentials: Mapping[str, Any],
        credential_type: str,
        parameter: str,
    ) -> PluginDynamicSelectOptionsResponse:
        """
@@ -30,7 +29,6 @@ class DynamicSelectClient(BasePluginClient):
                "data": {
                    "provider": GenericProviderID(provider).provider_name,
                    "credentials": credentials,
                    "credential_type": credential_type,
                    "provider_action": action,
                    "parameter": parameter,
                },
@@ -4,14 +4,13 @@ from typing import Any
from pydantic import BaseModel

from core.plugin.entities.plugin_daemon import (
    CredentialType,
    PluginBasicBooleanResponse,
    PluginToolProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.schemas.resolver import resolve_dify_schema_refs
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
from models.provider_ids import GenericProviderID, ToolProviderID

@@ -1,301 +0,0 @@
import binascii
from collections.abc import Mapping
from typing import Any

from flask import Request

from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import (
    PluginTriggerDispatchResponse,
    TriggerDispatchResponse,
    TriggerInvokeResponse,
    TriggerSubscriptionResponse,
    TriggerValidateProviderCredentialsResponse,
)
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.http_parser import deserialize_response, serialize_request
from core.trigger.entities.entities import Subscription
from models.provider_ids import GenericProviderID, TriggerProviderID


class PluginTriggerManager(BasePluginClient):
    def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
        """
        Fetch trigger providers for the given tenant.
        """

        def transformer(json_response: dict[str, Any]) -> dict:
            for provider in json_response.get("data", []):
                declaration = provider.get("declaration", {}) or {}
                provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
                for trigger in declaration.get("triggers", []):
                    trigger["identity"]["provider"] = provider_id

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/triggers",
            list[PluginTriggerProviderEntity],
            params={"page": 1, "page_size": 256},
            transformer=transformer,
        )

        for provider in response:
            provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

            # override the provider name for each trigger to plugin_id/provider_name
            for trigger in provider.declaration.triggers:
                trigger.identity.provider = provider.declaration.identity.name

        return response

    def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
        """
        Fetch trigger provider for the given tenant and plugin.
        """

        def transformer(json_response: dict[str, Any]) -> dict:
            data = json_response.get("data")
            if data:
                for trigger in data.get("declaration", {}).get("triggers", []):
                    trigger["identity"]["provider"] = str(provider_id)

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/trigger",
            PluginTriggerProviderEntity,
            params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
            transformer=transformer,
        )

        response.declaration.identity.name = str(provider_id)

        # override the provider name for each trigger to plugin_id/provider_name
        for trigger in response.declaration.triggers:
            trigger.identity.provider = str(provider_id)

        return response

    def invoke_trigger(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        trigger: str,
        credentials: Mapping[str, str],
        credential_type: CredentialType,
        request: Request,
        parameters: Mapping[str, Any],
    ) -> TriggerInvokeResponse:
        """
        Invoke a trigger with the given parameters.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/invoke",
            TriggerInvokeResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "trigger": trigger,
                    "credentials": credentials,
                    "credential_type": credential_type,
                    "raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
                    "parameters": parameters,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return TriggerInvokeResponse(event=resp.event)

        raise ValueError("No response received from plugin daemon for invoke trigger")

    def validate_provider_credentials(
        self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
    ) -> bool:
        """
        Validate the credentials of the trigger provider.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
            TriggerValidateProviderCredentialsResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp.result

        raise ValueError("No response received from plugin daemon for validate provider credentials")

    def dispatch_event(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        subscription: Mapping[str, Any],
        request: Request,
    ) -> TriggerDispatchResponse:
        """
        Dispatch an event to triggers.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
            PluginTriggerDispatchResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "subscription": subscription,
                    "raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return TriggerDispatchResponse(
                triggers=resp.triggers,
                response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())),
            )

        raise ValueError("No response received from plugin daemon for dispatch event")

    def subscribe(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        credentials: Mapping[str, str],
        endpoint: str,
        parameters: Mapping[str, Any],
    ) -> TriggerSubscriptionResponse:
        """
        Subscribe to a trigger.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/subscribe",
            TriggerSubscriptionResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "credentials": credentials,
                    "endpoint": endpoint,
                    "parameters": parameters,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp

        raise ValueError("No response received from plugin daemon for subscribe")

    def unsubscribe(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        subscription: Subscription,
        credentials: Mapping[str, str],
    ) -> TriggerSubscriptionResponse:
        """
        Unsubscribe from a trigger.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
            TriggerSubscriptionResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "subscription": subscription.model_dump(),
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp

        raise ValueError("No response received from plugin daemon for unsubscribe")

    def refresh(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        subscription: Subscription,
        credentials: Mapping[str, str],
    ) -> TriggerSubscriptionResponse:
        """
        Refresh a trigger subscription.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/refresh",
            TriggerSubscriptionResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "subscription": subscription.model_dump(),
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp

        raise ValueError("No response received from plugin daemon for refresh")
@@ -1,159 +0,0 @@
from io import BytesIO

from flask import Request, Response
from werkzeug.datastructures import Headers


def serialize_request(request: Request) -> bytes:
    method = request.method
    path = request.full_path.rstrip("?")
    raw = f"{method} {path} HTTP/1.1\r\n".encode()

    for name, value in request.headers.items():
        raw += f"{name}: {value}\r\n".encode()

    raw += b"\r\n"

    body = request.get_data(as_text=False)
    if body:
        raw += body

    return raw


def deserialize_request(raw_data: bytes) -> Request:
    header_end = raw_data.find(b"\r\n\r\n")
    if header_end == -1:
        header_end = raw_data.find(b"\n\n")
        if header_end == -1:
            header_data = raw_data
            body = b""
        else:
            header_data = raw_data[:header_end]
            body = raw_data[header_end + 2 :]
    else:
        header_data = raw_data[:header_end]
        body = raw_data[header_end + 4 :]

    lines = header_data.split(b"\r\n")
    if len(lines) == 1 and b"\n" in lines[0]:
        lines = header_data.split(b"\n")

    if not lines or not lines[0]:
        raise ValueError("Empty HTTP request")

    request_line = lines[0].decode("utf-8", errors="ignore")
    parts = request_line.split(" ", 2)
    if len(parts) < 2:
        raise ValueError(f"Invalid request line: {request_line}")

    method = parts[0]
    full_path = parts[1]
    protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"

    if "?" in full_path:
        path, query_string = full_path.split("?", 1)
    else:
        path = full_path
        query_string = ""

    headers = Headers()
    for line in lines[1:]:
        if not line:
            continue
        line_str = line.decode("utf-8", errors="ignore")
        if ":" not in line_str:
            continue
        name, value = line_str.split(":", 1)
        headers.add(name, value.strip())

    host = headers.get("Host", "localhost")
    if ":" in host:
        server_name, server_port = host.rsplit(":", 1)
    else:
        server_name = host
        server_port = "80"

    environ = {
        "REQUEST_METHOD": method,
        "PATH_INFO": path,
        "QUERY_STRING": query_string,
        "SERVER_NAME": server_name,
        "SERVER_PORT": server_port,
        "SERVER_PROTOCOL": protocol,
        "wsgi.input": BytesIO(body),
        "wsgi.url_scheme": "http",
    }

    if "Content-Type" in headers:
        environ["CONTENT_TYPE"] = headers.get("Content-Type")

    if "Content-Length" in headers:
        environ["CONTENT_LENGTH"] = headers.get("Content-Length")
    elif body:
        environ["CONTENT_LENGTH"] = str(len(body))

    for name, value in headers.items():
        if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
            continue
        env_name = f"HTTP_{name.upper().replace('-', '_')}"
        environ[env_name] = value

    return Request(environ)


def serialize_response(response: Response) -> bytes:
    raw = f"HTTP/1.1 {response.status}\r\n".encode()

    for name, value in response.headers.items():
        raw += f"{name}: {value}\r\n".encode()

    raw += b"\r\n"

    body = response.get_data(as_text=False)
    if body:
        raw += body

    return raw


def deserialize_response(raw_data: bytes) -> Response:
    header_end = raw_data.find(b"\r\n\r\n")
    if header_end == -1:
        header_end = raw_data.find(b"\n\n")
        if header_end == -1:
            header_data = raw_data
            body = b""
        else:
            header_data = raw_data[:header_end]
            body = raw_data[header_end + 2 :]
    else:
        header_data = raw_data[:header_end]
        body = raw_data[header_end + 4 :]

    lines = header_data.split(b"\r\n")
    if len(lines) == 1 and b"\n" in lines[0]:
        lines = header_data.split(b"\n")

    if not lines or not lines[0]:
        raise ValueError("Empty HTTP response")

    status_line = lines[0].decode("utf-8", errors="ignore")
    parts = status_line.split(" ", 2)
    if len(parts) < 2:
        raise ValueError(f"Invalid status line: {status_line}")

    status_code = int(parts[1])

    response = Response(response=body, status=status_code)

    for line in lines[1:]:
        if not line:
            continue
        line_str = line.decode("utf-8", errors="ignore")
        if ":" not in line_str:
            continue
        name, value = line_str.split(":", 1)
        response.headers[name] = value.strip()

    return response
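The helpers removed above round-trip Flask request/response objects through raw HTTP/1.1 bytes so they can be hex-encoded across the plugin daemon boundary. A minimal sketch of that round trip (standalone and illustrative only, assuming the functions above are importable):

from flask import Flask, Response

app = Flask(__name__)

with app.test_request_context("/hook?x=1", method="POST", data=b"{}", headers={"X-Demo": "1"}):
    from flask import request

    raw = serialize_request(request)      # b"POST /hook?x=1 HTTP/1.1\r\n..."
    restored = deserialize_request(raw)   # Request rebuilt from a minimal WSGI environ
    assert restored.method == "POST"
    assert restored.headers["X-Demo"] == "1"

resp = Response("ok", status=200)
assert deserialize_response(serialize_response(resp)).status_code == 200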
@@ -513,6 +513,21 @@ class ProviderManager:

        return provider_name_to_provider_load_balancing_model_configs_dict

    @staticmethod
    def _get_provider_names(provider_name: str) -> list[str]:
        """
        provider_name: `openai` or `langgenius/openai/openai`
        return: [`openai`, `langgenius/openai/openai`]
        """
        provider_names = [provider_name]
        model_provider_id = ModelProviderID(provider_name)
        if model_provider_id.is_langgenius():
            if "/" in provider_name:
                provider_names.append(model_provider_id.provider_name)
            else:
                provider_names.append(str(model_provider_id))
        return provider_names

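Aside, not part of this changeset: per the docstring above, the helper expands a provider name into both its short and fully-qualified spellings so either form matches in credential lookups. A sanity-check sketch (the list ordering and ModelProviderID.is_langgenius() behavior are inferred from this hunk, not verified against the repo):

# short name gains the fully-qualified form, and vice versa
assert ProviderManager._get_provider_names("openai") == ["openai", "langgenius/openai/openai"]
assert ProviderManager._get_provider_names("langgenius/openai/openai") == ["langgenius/openai/openai", "openai"]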
    @staticmethod
    def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
        """
@@ -525,7 +540,10 @@ class ProviderManager:
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(ProviderCredential)
                .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
                .where(
                    ProviderCredential.tenant_id == tenant_id,
                    ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
                )
                .order_by(ProviderCredential.created_at.desc())
            )

@@ -554,7 +572,7 @@ class ProviderManager:
            select(ProviderModelCredential)
            .where(
                ProviderModelCredential.tenant_id == tenant_id,
                ProviderModelCredential.provider_name == provider_name,
                ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
                ProviderModelCredential.model_name == model_name,
                ProviderModelCredential.model_type == model_type,
            )

@@ -1,4 +1,5 @@
import json
import logging
import time
import uuid
from typing import Any
@@ -9,11 +10,24 @@ from pymochow import MochowClient  # type: ignore
from pymochow.auth.bce_credentials import BceCredentials  # type: ignore
from pymochow.configuration import Configuration  # type: ignore
from pymochow.exception import ServerError  # type: ignore
from pymochow.model.database import Database
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState  # type: ignore
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex  # type: ignore
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row  # type: ignore
from pymochow.model.schema import (
    Field,
    FilteringIndex,
    HNSWParams,
    InvertedIndex,
    InvertedIndexAnalyzer,
    InvertedIndexFieldAttribute,
    InvertedIndexParams,
    InvertedIndexParseMode,
    Schema,
    VectorIndex,
)  # type: ignore
from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row  # type: ignore

from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -22,6 +36,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


class BaiduConfig(BaseModel):
    endpoint: str
@@ -30,9 +46,11 @@ class BaiduConfig(BaseModel):
    api_key: str
    database: str
    index_type: str = "HNSW"
    metric_type: str = "L2"
    metric_type: str = "IP"
    shard: int = 1
    replicas: int = 3
    inverted_index_analyzer: str = "DEFAULT_ANALYZER"
    inverted_index_parser_mode: str = "COARSE_MODE"

    @model_validator(mode="before")
    @classmethod
@@ -49,13 +67,9 @@ class BaiduConfig(BaseModel):


class BaiduVector(BaseVector):
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"
    field_app_id: str = "app_id"
    field_annotation_id: str = "annotation_id"
    index_vector: str = "vector_idx"
    vector_index: str = "vector_idx"
    filtering_index: str = "filtering_idx"
    inverted_index: str = "content_inverted_idx"

    def __init__(self, collection_name: str, config: BaiduConfig):
        super().__init__(collection_name)
@@ -74,8 +88,6 @@ class BaiduVector(BaseVector):
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
        total_count = len(documents)
        batch_size = 1000

|
||||
for start in range(0, total_count, batch_size):
|
||||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
for i in range(start, end, 1):
|
||||
metadata = documents[i].metadata
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
id=metadata.get("doc_id", str(uuid.uuid4())),
|
||||
page_content=documents[i].page_content,
|
||||
metadata=metadata,
|
||||
vector=embeddings[i],
|
||||
text=texts[i],
|
||||
metadata=json.dumps(metadatas[i]),
|
||||
app_id=metadatas[i].get("app_id", ""),
|
||||
annotation_id=metadatas[i].get("annotation_id", ""),
|
||||
)
|
||||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
# rebuild vector index after upsert finished
|
||||
table.rebuild_index(self.index_vector)
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.index_vector)
|
||||
index = table.describe_index(self.vector_index)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
|
||||
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
|
||||
if res and res.code == 0:
|
||||
return True
|
||||
return False
|
||||
@@ -115,53 +129,73 @@ class BaiduVector(BaseVector):
        if not ids:
            return
        quoted_ids = [f"'{id}'" for id in ids]
        self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
        self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")

    def delete_by_metadata_field(self, key: str, value: str):
        self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
        # Escape double quotes in value to prevent injection
        escaped_value = value.replace('"', '\\"')
        self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')

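Aside, not part of this changeset: the escaping above keeps a double quote inside the value from terminating the Mochow filter string literal. With illustrative values:

key, value = "document_id", 'doc-"42"'
escaped_value = value.replace('"', '\\"')
print(f'metadata["{key}"] = "{escaped_value}"')  # metadata["document_id"] = "doc-\"42\""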
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
        document_ids_filter = kwargs.get("document_ids_filter")
        filter = ""
        if document_ids_filter:
            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
            anns = AnnSearch(
                vector_field=self.field_vector,
                vector_floats=query_vector,
                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
                filter=f"document_id IN ({document_ids})",
            )
        else:
            anns = AnnSearch(
                vector_field=self.field_vector,
                vector_floats=query_vector,
                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
            )
            filter = f'metadata["document_id"] IN({document_ids})'
        anns = AnnSearch(
            vector_field=VDBField.VECTOR,
            vector_floats=query_vector,
            params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)),
            filter=filter,
        )
        res = self._db.table(self._collection_name).search(
            anns=anns,
            projections=[self.field_id, self.field_text, self.field_metadata],
            retrieve_vector=True,
            projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY],
            retrieve_vector=False,
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # baidu vector database doesn't support bm25 search on current version
        return []
        # document ids filter
        document_ids_filter = kwargs.get("document_ids_filter")
        filter = ""
        if document_ids_filter:
            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
            filter = f'metadata["document_id"] IN({document_ids})'

        request = BM25SearchRequest(
            index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter
        )
        res = self._db.table(self._collection_name).bm25_search(
            request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY]
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)

    def _get_search_res(self, res, score_threshold) -> list[Document]:
        docs = []
        for row in res.rows:
            row_data = row.get("row", {})
            meta = row_data.get(self.field_metadata)
            if meta is not None:
                meta = json.loads(meta)
            score = row.get("score", 0.0)
            meta = row_data.get(VDBField.METADATA_KEY, {})

            # Handle both JSON string and dict formats for backward compatibility
            if isinstance(meta, str):
                try:
                    import json

                    meta = json.loads(meta)
                except (json.JSONDecodeError, TypeError):
                    meta = {}
            elif not isinstance(meta, dict):
                meta = {}

            if score >= score_threshold:
                meta["score"] = score
                doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
                doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta)
                docs.append(doc)

        return docs

    def delete(self):
@@ -178,7 +212,7 @@ class BaiduVector(BaseVector):
        client = MochowClient(config)
        return client

    def _init_database(self):
    def _init_database(self) -> Database:
        exists = False
        for db in self._client.list_databases():
            if db.database_name == self._client_config.database:
@@ -192,10 +226,10 @@ class BaiduVector(BaseVector):
                self._client.create_database(database_name=self._client_config.database)
            except ServerError as e:
                if e.code == ServerErrCode.DB_ALREADY_EXIST:
                    pass
                    return self._client.database(self._client_config.database)
                else:
                    raise
        return
        return self._client.database(self._client_config.database)

    def _table_existed(self) -> bool:
        tables = self._db.list_table()
@@ -232,7 +266,7 @@ class BaiduVector(BaseVector):
        fields = []
        fields.append(
            Field(
                self.field_id,
                VDBField.PRIMARY_KEY,
                FieldType.STRING,
                primary_key=True,
                partition_key=True,
@@ -240,24 +274,57 @@ class BaiduVector(BaseVector):
                not_null=True,
            )
        )
        fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
        fields.append(Field(self.field_app_id, FieldType.STRING))
        fields.append(Field(self.field_annotation_id, FieldType.STRING))
        fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
        fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
        fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False))
        fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False))
        fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))

        # Construct vector index params
        indexes = []
        indexes.append(
            VectorIndex(
                index_name="vector_idx",
                index_name=self.vector_index,
                index_type=index_type,
                field="vector",
                field=VDBField.VECTOR,
                metric_type=metric_type,
                params=HNSWParams(m=16, efconstruction=200),
            )
        )

        # Filtering index
        indexes.append(
            FilteringIndex(
                index_name=self.filtering_index,
                fields=[VDBField.METADATA_KEY],
            )
        )

        # Get analyzer and parse_mode from config
        analyzer = getattr(
            InvertedIndexAnalyzer,
            self._client_config.inverted_index_analyzer,
            InvertedIndexAnalyzer.DEFAULT_ANALYZER,
        )

        parse_mode = getattr(
            InvertedIndexParseMode,
            self._client_config.inverted_index_parser_mode,
            InvertedIndexParseMode.COARSE_MODE,
        )

        # Inverted index
        indexes.append(
            InvertedIndex(
                index_name=self.inverted_index,
                fields=[VDBField.CONTENT_KEY],
                params=InvertedIndexParams(
                    analyzer=analyzer,
                    parse_mode=parse_mode,
                    case_sensitive=True,
                ),
                field_attributes=[InvertedIndexFieldAttribute.ANALYZED],
            )
        )

        # Create table
        self._db.create_table(
            table_name=self._collection_name,
@@ -268,11 +335,15 @@ class BaiduVector(BaseVector):
        )

        # Wait for table created
        timeout = 300  # 5 minutes timeout
        start_time = time.time()
        while True:
            time.sleep(1)
            table = self._db.describe_table(self._collection_name)
            if table.state == TableState.NORMAL:
                break
            if time.time() - start_time > timeout:
                raise TimeoutError(f"Table creation timeout after {timeout} seconds")
        redis_client.set(table_exist_cache_key, 1, ex=3600)

@@ -296,5 +367,7 @@ class BaiduVectorFactory(AbstractVectorFactory):
                database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
                shard=dify_config.BAIDU_VECTOR_DB_SHARD,
                replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
                inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
                inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
            ),
        )

@@ -4,7 +4,7 @@ import math
from typing import Any

from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance  # type: ignore
from pyobvector import VECTOR, ObVecClient, l2_distance  # type: ignore
from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT

@@ -117,22 +117,39 @@ class OceanBaseVector(BaseVector):
            columns=cols,
            vidxs=vidx_params,
        )
        try:
            if self._hybrid_search_enabled:
                self._client.create_fts_idx_with_fts_index_param(
                    table_name=self._collection_name,
                    fts_idx_param=FtsIndexParam(
                        index_name="fulltext_index_for_col_text",
                        field_names=["text"],
                        parser_type=FtsParser.IK,
                    ),
        logger.debug("DEBUG: Table '%s' created successfully", self._collection_name)

        if self._hybrid_search_enabled:
            # Get parser from config or use default ik parser
            parser_name = dify_config.OCEANBASE_FULLTEXT_PARSER or "ik"

            allowed_parsers = ["ngram", "beng", "space", "ngram2", "ik", "japanese_ftparser", "thai_ftparser"]
            if parser_name not in allowed_parsers:
                raise ValueError(
                    f"Invalid OceanBase full-text parser: {parser_name}. "
                    f"Allowed values are: {', '.join(allowed_parsers)}"
                )
        except Exception as e:
            raise Exception(
                "Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
                + "to support fulltext index and vector index in the same table",
                e,
            logger.debug("Hybrid search is enabled, parser_name='%s'", parser_name)
            logger.debug(
                "About to create fulltext index for collection '%s' using parser '%s'",
                self._collection_name,
                parser_name,
            )
            try:
                sql_command = f"""ALTER TABLE {self._collection_name}
                ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER {parser_name}"""
                logger.debug("DEBUG: Executing SQL: %s", sql_command)
                self._client.perform_raw_text_sql(sql_command)
                logger.debug("DEBUG: Fulltext index created successfully for '%s'", self._collection_name)
            except Exception as e:
                logger.exception("Exception occurred while creating fulltext index")
                raise Exception(
                    "Failed to add fulltext index to the target table, your OceanBase version must be "
                    "4.3.5.1 or above to support fulltext index and vector index in the same table"
                ) from e
        else:
            logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)

        self._client.refresh_metadata([self._collection_name])
        redis_client.set(collection_exist_cache_key, 1, ex=3600)

@@ -229,7 +246,7 @@ class OceanBaseVector(BaseVector):
        try:
            metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            print(f"Invalid JSON metadata: {metadata_str}")
            logger.warning("Invalid JSON metadata: %s", metadata_str)
            metadata = {}
        metadata["score"] = score
        docs.append(Document(page_content=_text, metadata=metadata))

@@ -1,5 +1,6 @@
import array
import json
import logging
import re
import uuid
from typing import Any
@@ -19,6 +20,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)

oracledb.defaults.fetch_lobs = False


@@ -180,8 +183,8 @@ class OracleVector(BaseVector):
                    value,
                )
            conn.commit()
        except Exception as e:
            print(e)
        except Exception:
            logger.exception("Failed to insert record %s into %s", value[0], self.table_name)
        conn.close()
        return pks

@@ -1,4 +1,5 @@
import json
import logging
import uuid
from typing import Any

@@ -23,6 +24,8 @@ from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)

Base = declarative_base()  # type: Any


@@ -187,8 +190,8 @@ class RelytVector(BaseVector):
            delete_condition = chunks_table.c.id.in_(ids)
            conn.execute(chunks_table.delete().where(delete_condition))
            return True
        except Exception as e:
            print("Delete operation failed:", str(e))
        except Exception:
            logger.exception("Delete operation failed for collection %s", self._collection_name)
            return False

    def delete_by_metadata_field(self, key: str, value: str):

@@ -164,8 +164,8 @@ class TiDBVector(BaseVector):
                delete_condition = table.c.id.in_(ids)
                conn.execute(table.delete().where(delete_condition))
                return True
        except Exception as e:
            print("Delete operation failed:", str(e))
        except Exception:
            logger.exception("Delete operation failed for collection %s", self._collection_name)
            return False

    def get_ids_by_metadata_field(self, key: str, value: str):

@@ -417,12 +417,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)

        if db_model is not None:
            offload_data = db_model.offload_data

        else:
            db_model = self._to_db_model(domain_model)
            offload_data = []
            offload_data = db_model.offload_data

        offload_data = db_model.offload_data
        if domain_model.inputs is not None:
            result = self._truncate_and_upload(
                domain_model.inputs,

@@ -1,4 +1,5 @@
import json
import logging
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
@@ -8,6 +9,8 @@ from typing import Any, ClassVar, Optional
class SchemaRegistry:
    """Schema registry manages JSON schemas with version support"""

    logger: ClassVar[logging.Logger] = logging.getLogger(__name__)

    _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
    _lock: ClassVar[threading.Lock] = threading.Lock()

@@ -83,7 +86,7 @@ class SchemaRegistry:
            self.metadata[uri] = metadata

        except (OSError, json.JSONDecodeError) as e:
            print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
            self.logger.warning("Failed to load schema %s/%s: %s", version, schema_name, e)

    def get_schema(self, uri: str) -> Any | None:
        """Retrieves a schema by URI with version support"""

@@ -4,8 +4,7 @@ from openai import BaseModel
from pydantic import Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom


class ToolRuntime(BaseModel):

@@ -4,11 +4,11 @@ from typing import Any

from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import (
    CredentialType,
    OAuthSchema,
    ToolEntity,
    ToolProviderEntity,

@@ -396,6 +396,10 @@ class ApiTool(Tool):
        # assemble invoke message based on response type
        if parsed_response.is_json and isinstance(parsed_response.content, dict):
            yield self.create_json_message(parsed_response.content)

            # FIXES: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
            # We must never break the original flows
            yield self.create_text_message(response.text)
        else:
            # Convert to string if needed and create text message
            text_response = (

@@ -5,10 +5,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator

from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import CredentialType, ToolProviderType


class ToolApiEntity(BaseModel):

@@ -477,4 +477,37 @@ class ToolSelector(BaseModel):
        tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")

    def to_plugin_parameter(self) -> dict[str, Any]:
        return self.model_dump()
        return self.model_dump()


class CredentialType(StrEnum):
    API_KEY = "api-key"
    OAUTH2 = auto()

    def get_name(self):
        if self == CredentialType.API_KEY:
            return "API KEY"
        elif self == CredentialType.OAUTH2:
            return "AUTH"
        else:
            return self.value.replace("-", " ").upper()

    def is_editable(self):
        return self == CredentialType.API_KEY

    def is_validate_allowed(self):
        return self == CredentialType.API_KEY

    @classmethod
    def values(cls):
        return [item.value for item in cls]

    @classmethod
    def of(cls, credential_type: str) -> "CredentialType":
        type_name = credential_type.lower()
        if type_name in {"api-key", "api_key"}:
            return cls.API_KEY
        elif type_name in {"oauth2", "oauth"}:
            return cls.OAUTH2
        else:
            raise ValueError(f"Invalid credential type: {credential_type}")

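Aside, not part of this changeset: expected behavior of the relocated enum, assuming Python's StrEnum semantics where auto() lower-cases the member name:

assert CredentialType.OAUTH2.value == "oauth2"            # auto() on a StrEnum
assert CredentialType.of("api_key") is CredentialType.API_KEY
assert CredentialType.of("oauth") is CredentialType.OAUTH2
assert CredentialType.API_KEY.get_name() == "API KEY"
assert not CredentialType.OAUTH2.is_editable()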
@@ -21,7 +21,6 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
@@ -35,6 +34,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    CredentialType,
    ToolInvokeFrom,
    ToolParameter,
    ToolProviderType,

@@ -1,24 +1,137 @@
# Import generic components from provider_encryption module
from core.helper.provider_encryption import (
    ProviderConfigCache,
    ProviderConfigEncrypter,
    create_provider_encrypter,
)
import contextlib
from copy import deepcopy
from typing import Any, Protocol

# Re-export for backward compatibility
__all__ = [
    "ProviderConfigCache",
    "ProviderConfigEncrypter",
    "create_provider_encrypter",
    "create_tool_provider_encrypter",
]

# Tool-specific imports
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController


def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
class ProviderConfigCache(Protocol):
    """
    Interface for provider configuration cache operations
    """

    def get(self) -> dict | None:
        """Get cached provider configuration"""
        ...

    def set(self, config: dict[str, Any]):
        """Cache provider configuration"""
        ...

    def delete(self):
        """Delete cached provider configuration"""
        ...


class ProviderConfigEncrypter:
    tenant_id: str
    config: list[BasicProviderConfig]
    provider_config_cache: ProviderConfigCache

    def __init__(
        self,
        tenant_id: str,
        config: list[BasicProviderConfig],
        provider_config_cache: ProviderConfigCache,
    ):
        self.tenant_id = tenant_id
        self.config = config
        self.provider_config_cache = provider_config_cache

    def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
        """
        deep copy data
        """
        return deepcopy(data)

    def encrypt(self, data: dict[str, str]) -> dict[str, str]:
        """
        encrypt tool credentials with tenant id

        return a deep copy of credentials with encrypted values
        """
        data = self._deep_copy(data)

        # get fields that need to be encrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
                    data[field_name] = encrypted

        return data

    def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        mask tool credentials

        return a deep copy of credentials with masked values
        """
        data = self._deep_copy(data)

        # get fields that need to be masked
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    if len(data[field_name]) > 6:
                        data[field_name] = (
                            data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
                        )
                    else:
                        data[field_name] = "*" * len(data[field_name])

        return data

    def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
        """
        decrypt tool credentials with tenant id

        return a deep copy of credentials with decrypted values
        """
        cached_credentials = self.provider_config_cache.get()
        if cached_credentials:
            return cached_credentials

        data = self._deep_copy(data)
        # get fields that need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    with contextlib.suppress(Exception):
                        # if the value is None or empty string, skip decrypt
                        if not data[field_name]:
                            continue

                        data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])

        self.provider_config_cache.set(data)
        return data


def create_provider_encrypter(
    tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache


def create_tool_provider_encrypter(
    tenant_id: str, controller: ToolProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = SingletonProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_type=controller.provider_type.value,
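Aside, not part of this changeset: a minimal usage sketch of the encrypter above, with a hypothetical in-memory stand-in for the cache (InMemoryCache is illustrative, not a repo class, and the BasicProviderConfig constructor arguments are assumed from their use above):

class InMemoryCache:
    # hypothetical ProviderConfigCache implementation, for illustration only
    def __init__(self):
        self._data = None

    def get(self):
        return self._data

    def set(self, config):
        self._data = config

    def delete(self):
        self._data = None


enc, cache = create_provider_encrypter(
    tenant_id="tenant-1",
    config=[BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key")],
    cache=InMemoryCache(),
)
stored = enc.encrypt({"api_key": "sk-demo-123456"})                # value encrypted per tenant
masked = enc.mask_tool_credentials({"api_key": "sk-demo-123456"})  # {"api_key": "sk**********56"}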
@@ -1 +0,0 @@
# Core trigger module initialization
@@ -1,76 +0,0 @@
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel, Field

from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.common_entities import I18nObject
from core.trigger.entities.entities import (
    SubscriptionSchema,
    TriggerCreationMethod,
    TriggerDescription,
    TriggerIdentity,
    TriggerParameter,
)


class TriggerProviderSubscriptionApiEntity(BaseModel):
    id: str = Field(description="The unique id of the subscription")
    name: str = Field(description="The name of the subscription")
    provider: str = Field(description="The provider id of the subscription")
    credential_type: CredentialType = Field(description="The type of the credential")
    credentials: dict = Field(description="The credentials of the subscription")
    endpoint: str = Field(description="The endpoint of the subscription")
    parameters: dict = Field(description="The parameters of the subscription")
    properties: dict = Field(description="The properties of the subscription")
    workflows_in_use: int = Field(description="The number of workflows using this subscription")


class TriggerApiEntity(BaseModel):
    name: str = Field(description="The name of the trigger")
    identity: TriggerIdentity = Field(description="The identity of the trigger")
    description: TriggerDescription = Field(description="The description of the trigger")
    parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
    output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")


class TriggerProviderApiEntity(BaseModel):
    author: str = Field(..., description="The author of the trigger provider")
    name: str = Field(..., description="The name of the trigger provider")
    label: I18nObject = Field(..., description="The label of the trigger provider")
    description: I18nObject = Field(..., description="The description of the trigger provider")
    icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
    icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
    tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")

    plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
    plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")

    supported_creation_methods: list[TriggerCreationMethod] = Field(
        default_factory=list,
        description="Supported creation methods for the trigger provider. Possible values: 'OAUTH', 'APIKEY', 'MANUAL'."
    )

    credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
    oauth_client_schema: list[ProviderConfig] = Field(
        default_factory=list, description="The schema of the OAuth client"
    )
    subscription_schema: Optional[SubscriptionSchema] = Field(
        description="The subscription schema of the trigger provider"
    )
    triggers: list[TriggerApiEntity] = Field(description="The triggers of the trigger provider")


class SubscriptionBuilderApiEntity(BaseModel):
    id: str = Field(description="The id of the subscription builder")
    name: str = Field(description="The name of the subscription builder")
    provider: str = Field(description="The provider id of the subscription builder")
    endpoint: str = Field(description="The endpoint id of the subscription builder")
    parameters: Mapping[str, Any] = Field(description="The parameters of the subscription builder")
    properties: Mapping[str, Any] = Field(description="The properties of the subscription builder")
    credentials: Mapping[str, str] = Field(description="The credentials of the subscription builder")
    credential_type: CredentialType = Field(description="The credential type of the subscription builder")


__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]
@ -1,309 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import PluginParameterAutoGenerate, PluginParameterOption, PluginParameterTemplate
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class TriggerParameterType(StrEnum):
|
||||
"""The type of the parameter"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
APP_SELECTOR = "app-selector"
|
||||
OBJECT = "object"
|
||||
ARRAY = "array"
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
|
||||
class TriggerParameter(BaseModel):
|
||||
"""
|
||||
The parameter of the trigger
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
type: TriggerParameterType = Field(..., description="The type of the parameter")
|
||||
auto_generate: Optional[PluginParameterAutoGenerate] = Field(
|
||||
default=None, description="The auto generate of the parameter"
|
||||
)
|
||||
template: Optional[PluginParameterTemplate] = Field(default=None, description="The template of the parameter")
|
||||
scope: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
multiple: bool | None = Field(
|
||||
default=False,
|
||||
description="Whether the parameter is multiple select, only valid for select or dynamic-select type",
|
||||
)
|
||||
default: Union[int, float, str, list, None] = None
|
||||
min: Union[float, int, None] = None
|
||||
max: Union[float, int, None] = None
|
||||
precision: Optional[int] = None
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
description: Optional[I18nObject] = None
|
||||
|
||||
|
||||
class TriggerProviderIdentity(BaseModel):
|
||||
"""
|
||||
The identity of the trigger provider
|
||||
"""
|
||||
|
||||
author: str = Field(..., description="The author of the trigger provider")
|
||||
name: str = Field(..., description="The name of the trigger provider")
|
||||
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
|
||||
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||
|
||||
|
||||
class TriggerIdentity(BaseModel):
|
||||
"""
|
||||
The identity of the trigger
|
||||
"""
|
||||
|
||||
author: str = Field(..., description="The author of the trigger")
|
||||
name: str = Field(..., description="The name of the trigger")
|
||||
label: I18nObject = Field(..., description="The label of the trigger")
|
||||
provider: Optional[str] = Field(default=None, description="The provider of the trigger")
|
||||
|
||||
|
||||
class TriggerDescription(BaseModel):
|
||||
"""
|
||||
The description of the trigger
|
||||
"""
|
||||
|
||||
human: I18nObject = Field(..., description="Human readable description")
|
||||
llm: I18nObject = Field(..., description="LLM readable description")
|
||||
|
||||
|
||||
class TriggerEntity(BaseModel):
|
||||
"""
|
||||
The configuration of a trigger
|
||||
"""
|
||||
|
||||
identity: TriggerIdentity = Field(..., description="The identity of the trigger")
|
||||
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
|
||||
description: TriggerDescription = Field(..., description="The description of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(
|
||||
default=None, description="The output schema that this trigger produces"
|
||||
)
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionSchema(BaseModel):
|
||||
"""
|
||||
The subscription schema of the trigger provider
|
||||
"""
|
||||
|
||||
parameters_schema: list[TriggerParameter] | None = Field(
|
||||
default_factory=list,
|
||||
description="The parameters schema required to create a subscription",
|
||||
)
|
||||
|
||||
properties_schema: list[ProviderConfig] | None = Field(
|
||||
default_factory=list,
|
||||
description="The configuration schema stored in the subscription entity",
|
||||
)
|
||||
|
||||
def get_default_parameters(self) -> Mapping[str, Any]:
|
||||
"""Get the default parameters from the parameters schema"""
|
||||
if not self.parameters_schema:
|
||||
return {}
|
||||
return {param.name: param.default for param in self.parameters_schema if param.default}
|
||||
|
||||
def get_default_properties(self) -> Mapping[str, Any]:
|
||||
"""Get the default properties from the properties schema"""
|
||||
if not self.properties_schema:
|
||||
return {}
|
||||
return {prop.name: prop.default for prop in self.properties_schema if prop.default}
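

# Illustrative sketch of the default extraction above (hypothetical parameter).
# Note that falsy defaults (0, False, "") are skipped by the `if param.default` check.
_example_schema = SubscriptionSchema(
    parameters_schema=[
        TriggerParameter(
            name="branch",
            label=I18nObject(en_US="Branch"),
            type=TriggerParameterType.STRING,
            default="main",
        )
    ],
    properties_schema=[],
)
assert _example_schema.get_default_parameters() == {"branch": "main"}
assert _example_schema.get_default_properties() == {}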


class TriggerProviderEntity(BaseModel):
    """
    The configuration of a trigger provider
    """

    identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider")
    credentials_schema: list[ProviderConfig] = Field(
        default_factory=list,
        description="The credentials schema of the trigger provider",
    )
    oauth_schema: Optional[OAuthSchema] = Field(
        default=None,
        description="The OAuth schema of the trigger provider if OAuth is supported",
    )
    subscription_schema: SubscriptionSchema = Field(
        description="The subscription schema for trigger (webhook, polling, etc.) subscription parameters",
    )
    triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")


class Subscription(BaseModel):
    """
    Result of a successful trigger subscription operation.

    Contains all information needed to manage the subscription lifecycle.
    """

    expires_at: int = Field(
        ..., description="The timestamp when the subscription will expire, used to refresh the subscription"
    )

    endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
    properties: Mapping[str, Any] = Field(
        ..., description="Subscription data containing all properties and provider-specific information"
    )


class Unsubscription(BaseModel):
    """
    Result of a trigger unsubscription operation.

    Provides detailed information about the unsubscription attempt,
    including success status and error details if failed.
    """

    success: bool = Field(..., description="Whether the unsubscription was successful")

    message: Optional[str] = Field(
        None,
        description="Human-readable message about the operation result. "
        "Success message for successful operations, "
        "detailed error information for failures.",
    )


class RequestLog(BaseModel):
    id: str = Field(..., description="The id of the request log")
    endpoint: str = Field(..., description="The endpoint of the request log")
    request: dict = Field(..., description="The request of the request log")
    response: dict = Field(..., description="The response of the request log")
    created_at: datetime = Field(..., description="The creation time of the request log")


class SubscriptionBuilder(BaseModel):
    id: str = Field(..., description="The id of the subscription builder")
    name: str | None = Field(default=None, description="The name of the subscription builder")
    tenant_id: str = Field(..., description="The tenant id of the subscription builder")
    user_id: str = Field(..., description="The user id of the subscription builder")
    provider_id: str = Field(..., description="The provider id of the subscription builder")
    endpoint_id: str = Field(..., description="The endpoint id of the subscription builder")
    parameters: Mapping[str, Any] = Field(..., description="The parameters of the subscription builder")
    properties: Mapping[str, Any] = Field(..., description="The properties of the subscription builder")
    credentials: Mapping[str, str] = Field(..., description="The credentials of the subscription builder")
    credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
    credential_expires_at: int | None = Field(
        default=None, description="The expiry timestamp of the subscription builder's credential"
    )
    expires_at: int = Field(..., description="The expiry timestamp of the subscription builder")

    def to_subscription(self) -> Subscription:
        return Subscription(
            expires_at=self.expires_at,
            endpoint=self.endpoint_id,
            properties=self.properties,
        )


class SubscriptionBuilderUpdater(BaseModel):
    name: str | None = Field(default=None, description="The name of the subscription builder")
    parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
    properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
    credentials: Mapping[str, str] | None = Field(
        default=None, description="The credentials of the subscription builder"
    )
    credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
    credential_expires_at: int | None = Field(
        default=None, description="The expiry timestamp of the subscription builder's credential"
    )
    expires_at: int | None = Field(default=None, description="The expiry timestamp of the subscription builder")

    def update(self, subscription_builder: SubscriptionBuilder) -> None:
        if self.name:
            subscription_builder.name = self.name
        if self.parameters:
            subscription_builder.parameters = self.parameters
        if self.properties:
            subscription_builder.properties = self.properties
        if self.credentials:
            subscription_builder.credentials = self.credentials
        if self.credential_type:
            subscription_builder.credential_type = self.credential_type
        if self.credential_expires_at:
            subscription_builder.credential_expires_at = self.credential_expires_at
        if self.expires_at:
            subscription_builder.expires_at = self.expires_at
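

# Sketch of the partial-update semantics above (hypothetical builder): only truthy
# fields overwrite, so an empty mapping such as `parameters={}` would NOT clear a field.
_updater = SubscriptionBuilderUpdater(name="renamed", parameters={"events": ["push"]})
_updater.update(existing_builder)  # `existing_builder` is some SubscriptionBuilder instance
# existing_builder.name == "renamed"; its properties and credentials remain untouched.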


class TriggerEventData(BaseModel):
    """Event data dispatched to trigger sessions."""

    subscription_id: str
    triggers: list[str]
    request_id: str
    timestamp: float

    model_config = ConfigDict(arbitrary_types_allowed=True)


class TriggerInputs(BaseModel):
    """Standard inputs for trigger nodes."""

    request_id: str
    trigger_name: str
    subscription_id: str

    @classmethod
    def from_trigger_entity(cls, request_id: str, subscription_id: str, trigger: TriggerEntity) -> "TriggerInputs":
        """Create from trigger entity (for production)."""
        return cls(request_id=request_id, trigger_name=trigger.identity.name, subscription_id=subscription_id)

    def to_workflow_args(self) -> dict[str, Any]:
        """Convert to workflow arguments format."""
        return {"inputs": self.model_dump(), "files": []}

    def to_dict(self) -> dict[str, Any]:
        """Convert to dict (alias for model_dump)."""
        return self.model_dump()
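

# Illustrative flow (hypothetical ids): wrap an incoming event into workflow args.
# `some_trigger` stands for any TriggerEntity.
_inputs = TriggerInputs.from_trigger_entity(request_id="req-1", subscription_id="sub-1", trigger=some_trigger)
_args = _inputs.to_workflow_args()
# {"inputs": {"request_id": "req-1", "trigger_name": <the trigger's name>, "subscription_id": "sub-1"}, "files": []}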


class TriggerCreationMethod(StrEnum):
    OAUTH = "OAUTH"
    APIKEY = "APIKEY"
    MANUAL = "MANUAL"


# Export all entities
__all__ = [
    "OAuthSchema",
    "RequestLog",
    "Subscription",
    "SubscriptionBuilder",
    "TriggerCreationMethod",
    "TriggerDescription",
    "TriggerEntity",
    "TriggerEventData",
    "TriggerIdentity",
    "TriggerInputs",
    "TriggerParameter",
    "TriggerParameterType",
    "TriggerProviderEntity",
    "TriggerProviderIdentity",
    "Unsubscription",
]
@ -1,2 +0,0 @@
class TriggerProviderCredentialValidationError(ValueError):
    pass
@ -1,358 +0,0 @@
"""
Trigger Provider Controller for managing trigger providers
"""

import logging
from collections.abc import Mapping
from typing import Any, Optional

from flask import Request

from core.entities.provider_entities import BasicProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import (
    TriggerDispatchResponse,
    TriggerInvokeResponse,
)
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProviderApiEntity
from core.trigger.entities.entities import (
    ProviderConfig,
    Subscription,
    SubscriptionSchema,
    TriggerCreationMethod,
    TriggerEntity,
    TriggerProviderEntity,
    TriggerProviderIdentity,
    Unsubscription,
)
from core.trigger.errors import TriggerProviderCredentialValidationError
from models.provider_ids import TriggerProviderID
from services.plugin.plugin_service import PluginService

logger = logging.getLogger(__name__)


class PluginTriggerProviderController:
    """
    Controller for plugin trigger providers
    """

    def __init__(
        self,
        entity: TriggerProviderEntity,
        plugin_id: str,
        plugin_unique_identifier: str,
        provider_id: TriggerProviderID,
        tenant_id: str,
    ):
        """
        Initialize plugin trigger provider controller

        :param entity: Trigger provider entity
        :param plugin_id: Plugin ID
        :param plugin_unique_identifier: Plugin unique identifier
        :param provider_id: Provider ID
        :param tenant_id: Tenant ID
        """
        self.entity = entity
        self.tenant_id = tenant_id
        self.plugin_id = plugin_id
        self.provider_id = provider_id
        self.plugin_unique_identifier = plugin_unique_identifier

    def get_provider_id(self) -> TriggerProviderID:
        """
        Get provider ID
        """
        return self.provider_id

    def to_api_entity(self) -> TriggerProviderApiEntity:
        """
        Convert to API entity
        """
        icon = (
            PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon)
            if self.entity.identity.icon
            else None
        )
        icon_dark = (
            PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon_dark)
            if self.entity.identity.icon_dark
            else None
        )
        supported_creation_methods = []
        if self.entity.oauth_schema:
            supported_creation_methods.append(TriggerCreationMethod.OAUTH)
        if self.entity.credentials_schema:
            supported_creation_methods.append(TriggerCreationMethod.APIKEY)
        if self.entity.subscription_schema:
            supported_creation_methods.append(TriggerCreationMethod.MANUAL)
        return TriggerProviderApiEntity(
            author=self.entity.identity.author,
            name=self.entity.identity.name,
            label=self.entity.identity.label,
            description=self.entity.identity.description,
            icon=icon,
            icon_dark=icon_dark,
            tags=self.entity.identity.tags,
            plugin_id=self.plugin_id,
            plugin_unique_identifier=self.plugin_unique_identifier,
            credentials_schema=self.entity.credentials_schema,
            oauth_client_schema=self.entity.oauth_schema.client_schema if self.entity.oauth_schema else [],
            subscription_schema=self.entity.subscription_schema,
            supported_creation_methods=supported_creation_methods,
            triggers=[
                TriggerApiEntity(
                    name=trigger.identity.name,
                    identity=trigger.identity,
                    description=trigger.description,
                    parameters=trigger.parameters,
                    output_schema=trigger.output_schema,
                )
                for trigger in self.entity.triggers
            ],
        )

    @property
    def identity(self) -> TriggerProviderIdentity:
        """Get provider identity"""
        return self.entity.identity

    def get_triggers(self) -> list[TriggerEntity]:
        """
        Get all triggers for this provider

        :return: List of trigger entities
        """
        return self.entity.triggers

    def get_trigger(self, trigger_name: str) -> Optional[TriggerEntity]:
        """
        Get a specific trigger by name

        :param trigger_name: Trigger name
        :return: Trigger entity or None
        """
        for trigger in self.entity.triggers:
            if trigger.identity.name == trigger_name:
                return trigger
        return None

    def get_subscription_schema(self) -> SubscriptionSchema:
        """
        Get subscription schema for this provider

        :return: Subscription schema
        """
        return self.entity.subscription_schema

    def validate_credentials(self, user_id: str, credentials: Mapping[str, str]) -> None:
        """
        Validate credentials against schema

        :param user_id: User ID
        :param credentials: Credentials to validate
        :raises TriggerProviderCredentialValidationError: if the credentials are invalid
        """
        # First validate against schema
        for config in self.entity.credentials_schema:
            if config.required and config.name not in credentials:
                raise TriggerProviderCredentialValidationError(f"Missing required credential field: {config.name}")

        # Then validate with the plugin daemon
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()
        response = manager.validate_provider_credentials(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=str(provider_id),
            credentials=credentials,
        )
        if not response:
            raise TriggerProviderCredentialValidationError(
                "Invalid credentials",
            )

    def get_supported_credential_types(self) -> list[CredentialType]:
        """
        Get supported credential types for this provider.

        :return: List of supported credential types
        """
        types = []
        if self.entity.oauth_schema:
            types.append(CredentialType.OAUTH2)
        if self.entity.credentials_schema:
            types.append(CredentialType.API_KEY)
        return types

    def get_credentials_schema(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
        """
        Get credentials schema by credential type

        :param credential_type: The type of credential (oauth or api_key)
        :return: List of provider config schemas
        """
        credential_type = CredentialType.of(credential_type) if isinstance(credential_type, str) else credential_type
        if credential_type == CredentialType.OAUTH2:
            return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
        if credential_type == CredentialType.API_KEY:
            return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
        if credential_type == CredentialType.UNAUTHORIZED:
            return []
        raise ValueError(f"Invalid credential type: {credential_type}")

    def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
        """
        Get credential schema config by credential type
        """
        return [x.to_basic_provider_config() for x in self.get_credentials_schema(credential_type)]

    def get_oauth_client_schema(self) -> list[ProviderConfig]:
        """
        Get OAuth client schema for this provider

        :return: List of OAuth client config schemas
        """
        return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []

    def get_properties_schema(self) -> list[BasicProviderConfig]:
        """
        Get properties schema for this provider

        :return: List of properties config schemas
        """
        return (
            [x.to_basic_provider_config() for x in self.entity.subscription_schema.properties_schema.copy()]
            if self.entity.subscription_schema.properties_schema
            else []
        )

    def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
        """
        Dispatch a trigger through plugin runtime

        :param user_id: User ID
        :param request: Flask request object
        :param subscription: Subscription
        :return: Dispatch response with triggers and raw HTTP response
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        response = manager.dispatch_event(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=str(provider_id),
            subscription=subscription.model_dump(),
            request=request,
        )
        return response

    def invoke_trigger(
        self,
        user_id: str,
        trigger_name: str,
        parameters: Mapping[str, Any],
        credentials: Mapping[str, str],
        credential_type: CredentialType,
        request: Request,
    ) -> TriggerInvokeResponse:
        """
        Execute a trigger through plugin runtime

        :param user_id: User ID
        :param trigger_name: Trigger name
        :param parameters: Trigger parameters
        :param credentials: Provider credentials
        :param credential_type: Credential type
        :param request: Request
        :return: Trigger execution result
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        return manager.invoke_trigger(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=str(provider_id),
            trigger=trigger_name,
            credentials=credentials,
            credential_type=credential_type,
            request=request,
            parameters=parameters,
        )

    def subscribe_trigger(
        self, user_id: str, endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str]
    ) -> Subscription:
        """
        Subscribe to a trigger through plugin runtime

        :param user_id: User ID
        :param endpoint: Subscription endpoint
        :param parameters: Subscription parameters
        :param credentials: Provider credentials
        :return: Subscription result
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        response = manager.subscribe(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=str(provider_id),
            credentials=credentials,
            endpoint=endpoint,
            parameters=parameters,
        )

        return Subscription.model_validate(response.subscription)

    def unsubscribe_trigger(
        self, user_id: str, subscription: Subscription, credentials: Mapping[str, str]
    ) -> Unsubscription:
        """
        Unsubscribe from a trigger through plugin runtime

        :param user_id: User ID
        :param subscription: Subscription metadata
        :param credentials: Provider credentials
        :return: Unsubscription result
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        response = manager.unsubscribe(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=str(provider_id),
            subscription=subscription,
            credentials=credentials,
        )

        return Unsubscription.model_validate(response.subscription)

    def refresh_trigger(self, subscription: Subscription, credentials: Mapping[str, str]) -> Subscription:
        """
        Refresh a trigger subscription through plugin runtime

        :param subscription: Subscription metadata
        :param credentials: Provider credentials
        :return: Refreshed subscription result
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        response = manager.refresh(
            tenant_id=self.tenant_id,
            user_id="system",  # System refresh
            provider=str(provider_id),
            subscription=subscription,
            credentials=credentials,
        )

        return Subscription.model_validate(response.subscription)


__all__ = ["PluginTriggerProviderController"]
@ -1,254 +0,0 @@
"""
Trigger Manager for loading and managing trigger providers and triggers
"""

import logging
from collections.abc import Mapping
from threading import Lock
from typing import Any, Optional

from flask import Request

import contexts
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeResponse
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.entities import (
    Subscription,
    SubscriptionSchema,
    TriggerEntity,
    Unsubscription,
)
from core.trigger.provider import PluginTriggerProviderController
from models.provider_ids import TriggerProviderID

logger = logging.getLogger(__name__)


class TriggerManager:
    """
    Manager for trigger providers and triggers
    """

    @classmethod
    def list_plugin_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
        """
        List all plugin trigger providers for a tenant

        :param tenant_id: Tenant ID
        :return: List of trigger provider controllers
        """
        manager = PluginTriggerManager()
        provider_entities = manager.fetch_trigger_providers(tenant_id)

        controllers = []
        for provider in provider_entities:
            try:
                controller = PluginTriggerProviderController(
                    entity=provider.declaration,
                    plugin_id=provider.plugin_id,
                    plugin_unique_identifier=provider.plugin_unique_identifier,
                    provider_id=TriggerProviderID(provider.provider),
                    tenant_id=tenant_id,
                )
                controllers.append(controller)
            except Exception:
                logger.exception("Failed to load trigger provider %s", provider.plugin_id)
                continue

        return controllers

    @classmethod
    def get_trigger_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderController:
        """
        Get a specific plugin trigger provider

        :param tenant_id: Tenant ID
        :param provider_id: Provider ID
        :return: Trigger provider controller
        :raises ValueError: if the provider is not found
        """
        # check if the context-local provider cache is initialized
        try:
            contexts.plugin_trigger_providers.get()
        except LookupError:
            contexts.plugin_trigger_providers.set({})
            contexts.plugin_trigger_providers_lock.set(Lock())

        plugin_trigger_providers = contexts.plugin_trigger_providers.get()
        provider_id_str = str(provider_id)
        if provider_id_str in plugin_trigger_providers:
            return plugin_trigger_providers[provider_id_str]

        with contexts.plugin_trigger_providers_lock.get():
            # double check after acquiring the lock (double-checked locking)
            plugin_trigger_providers = contexts.plugin_trigger_providers.get()
            if provider_id_str in plugin_trigger_providers:
                return plugin_trigger_providers[provider_id_str]

            manager = PluginTriggerManager()
            provider = manager.fetch_trigger_provider(tenant_id, provider_id)

            if not provider:
                raise ValueError(f"Trigger provider {provider_id} not found")

            try:
                controller = PluginTriggerProviderController(
                    entity=provider.declaration,
                    plugin_id=provider.plugin_id,
                    plugin_unique_identifier=provider.plugin_unique_identifier,
                    provider_id=provider_id,
                    tenant_id=tenant_id,
                )
                plugin_trigger_providers[provider_id_str] = controller
                return controller
            except Exception as e:
                logger.exception("Failed to load trigger provider")
                raise e

    @classmethod
    def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
        """
        List all trigger providers (plugin)

        :param tenant_id: Tenant ID
        :return: List of all trigger provider controllers
        """
        return cls.list_plugin_trigger_providers(tenant_id)

    @classmethod
    def list_triggers_by_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerEntity]:
        """
        List all triggers for a specific provider

        :param tenant_id: Tenant ID
        :param provider_id: Provider ID
        :return: List of trigger entities
        """
        provider = cls.get_trigger_provider(tenant_id, provider_id)
        return provider.get_triggers()

    @classmethod
    def get_trigger(cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str) -> Optional[TriggerEntity]:
        """
        Get a specific trigger

        :param tenant_id: Tenant ID
        :param provider_id: Provider ID
        :param trigger_name: Trigger name
        :return: Trigger entity or None
        """
        return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)

    @classmethod
    def invoke_trigger(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        trigger_name: str,
        parameters: Mapping[str, Any],
        credentials: Mapping[str, str],
        credential_type: CredentialType,
        request: Request,
    ) -> TriggerInvokeResponse:
        """
        Execute a trigger

        :param tenant_id: Tenant ID
        :param user_id: User ID
        :param provider_id: Provider ID
        :param trigger_name: Trigger name
        :param parameters: Trigger parameters
        :param credentials: Provider credentials
        :param credential_type: Credential type
        :param request: Request
        :return: Trigger execution result
        """
        provider = cls.get_trigger_provider(tenant_id, provider_id)
        trigger = provider.get_trigger(trigger_name)
        if not trigger:
            raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
        return provider.invoke_trigger(user_id, trigger_name, parameters, credentials, credential_type, request)

    @classmethod
    def subscribe_trigger(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        endpoint: str,
        parameters: Mapping[str, Any],
        credentials: Mapping[str, str],
    ) -> Subscription:
        """
        Subscribe to a trigger (e.g., register a webhook)

        :param tenant_id: Tenant ID
        :param user_id: User ID
        :param provider_id: Provider ID
        :param endpoint: Subscription endpoint
        :param parameters: Subscription parameters
        :param credentials: Provider credentials
        :return: Subscription result
        """
        provider = cls.get_trigger_provider(tenant_id, provider_id)
        return provider.subscribe_trigger(
            user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials
        )

    @classmethod
    def unsubscribe_trigger(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription: Subscription,
        credentials: Mapping[str, str],
    ) -> Unsubscription:
        """
        Unsubscribe from a trigger

        :param tenant_id: Tenant ID
        :param user_id: User ID
        :param provider_id: Provider ID
        :param subscription: Subscription metadata from the subscribe operation
        :param credentials: Provider credentials
        :return: Unsubscription result
        """
        provider = cls.get_trigger_provider(tenant_id, provider_id)
        return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)

    @classmethod
    def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> SubscriptionSchema:
        """
        Get the provider subscription schema

        :param tenant_id: Tenant ID
        :param provider_id: Provider ID
        :return: Subscription schema
        """
        return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()

    @classmethod
    def refresh_trigger(
        cls,
        tenant_id: str,
        provider_id: TriggerProviderID,
        subscription: Subscription,
        credentials: Mapping[str, str],
    ) -> Subscription:
        """
        Refresh a trigger subscription

        :param tenant_id: Tenant ID
        :param provider_id: Provider ID
        :param subscription: Subscription metadata from the subscribe operation
        :param credentials: Provider credentials
        :return: Refreshed subscription result
        """
        return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(subscription, credentials)


# Export
__all__ = ["TriggerManager"]
@ -1,145 +0,0 @@
from collections.abc import Mapping
from typing import Union

from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
from core.helper.provider_cache import ProviderCredentialsCache
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
from core.trigger.provider import PluginTriggerProviderController
from models.trigger import TriggerSubscription


class TriggerProviderCredentialsCache(ProviderCredentialsCache):
    """Cache for trigger provider credentials"""

    def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
        super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)

    def _generate_cache_key(self, **kwargs) -> str:
        tenant_id = kwargs["tenant_id"]
        provider_id = kwargs["provider_id"]
        credential_id = kwargs["credential_id"]
        return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"


class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
    """Cache for trigger provider OAuth client"""

    def __init__(self, tenant_id: str, provider_id: str):
        super().__init__(tenant_id=tenant_id, provider_id=provider_id)

    def _generate_cache_key(self, **kwargs) -> str:
        tenant_id = kwargs["tenant_id"]
        provider_id = kwargs["provider_id"]
        return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"


class TriggerProviderPropertiesCache(ProviderCredentialsCache):
    """Cache for trigger provider properties"""

    def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
        super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)

    def _generate_cache_key(self, **kwargs) -> str:
        tenant_id = kwargs["tenant_id"]
        provider_id = kwargs["provider_id"]
        subscription_id = kwargs["subscription_id"]
        return f"trigger_properties:tenant_id:{tenant_id}:provider_id:{provider_id}:subscription_id:{subscription_id}"


def create_trigger_provider_encrypter_for_subscription(
    tenant_id: str,
    controller: PluginTriggerProviderController,
    subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
        credential_id=subscription.id,
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=controller.get_credential_schema_config(subscription.credential_type),
        cache=cache,
    )
    return encrypter, cache
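

# Sketch (hypothetical objects): decrypting a stored subscription's credentials.
# `some_controller` would come from TriggerManager.get_trigger_provider and
# `some_subscription` is a persisted TriggerSubscription row; neither is defined here.
_encrypter, _cache = create_trigger_provider_encrypter_for_subscription(
    tenant_id=some_subscription.tenant_id,
    controller=some_controller,
    subscription=some_subscription,
)
_plain = _encrypter.decrypt(dict(some_subscription.credentials))  # cached after the first call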


def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
    cache = TriggerProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_id=provider_id,
        credential_id=subscription_id,
    )
    cache.delete()


def create_trigger_provider_encrypter_for_properties(
    tenant_id: str,
    controller: PluginTriggerProviderController,
    subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderPropertiesCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
        subscription_id=subscription.id,
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=controller.get_properties_schema(),
        cache=cache,
    )
    return encrypter, cache


def create_trigger_provider_encrypter(
    tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
        credential_id=credential_id,
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=controller.get_credential_schema_config(credential_type),
        cache=cache,
    )
    return encrypter, cache


def create_trigger_provider_oauth_encrypter(
    tenant_id: str, controller: PluginTriggerProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderOAuthClientParamsCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
        cache=cache,
    )
    return encrypter, cache


def masked_credentials(
    schemas: list[ProviderConfig],
    credentials: Mapping[str, str],
) -> Mapping[str, str]:
    masked_credentials = {}
    configs = {x.name: x.to_basic_provider_config() for x in schemas}
    for key, value in credentials.items():
        config = configs.get(key)
        if not config:
            masked_credentials[key] = value
            continue
        if config.type == BasicProviderConfig.Type.SECRET_INPUT:
            if len(value) <= 4:
                masked_credentials[key] = "*" * len(value)
            else:
                masked_credentials[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
        else:
            masked_credentials[key] = value
    return masked_credentials
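

# Masking sketch derived from the logic above: for a SECRET_INPUT field,
# masked_credentials(schemas, {"api_key": "sk-abcdef"}) yields {"api_key": "sk*****ef"}
# (two characters kept on each side), while non-secret fields pass through unchanged.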
@ -1,5 +0,0 @@
from configs import dify_config


def parse_endpoint_id(endpoint_id: str) -> str:
    return f"{dify_config.CONSOLE_API_URL}/triggers/plugin/{endpoint_id}"
@ -58,18 +58,6 @@ class NodeType(StrEnum):
    DOCUMENT_EXTRACTOR = "document-extractor"
    LIST_OPERATOR = "list-operator"
    AGENT = "agent"
    TRIGGER_WEBHOOK = "trigger-webhook"
    TRIGGER_SCHEDULE = "trigger-schedule"
    TRIGGER_PLUGIN = "trigger-plugin"

    @property
    def is_start_node(self) -> bool:
        return self in [
            NodeType.START,
            NodeType.TRIGGER_WEBHOOK,
            NodeType.TRIGGER_SCHEDULE,
            NodeType.TRIGGER_PLUGIN,
        ]
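
    # e.g. NodeType.TRIGGER_WEBHOOK.is_start_node is True, while NodeType.AGENT.is_start_node is False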


class NodeExecutionType(StrEnum):
@ -134,7 +122,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
    ERROR_STRATEGY = "error_strategy"  # nodes in continue-on-error mode return this field
    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
    DATASOURCE_INFO = "datasource_info"
    TRIGGER_INFO = "trigger_info"


class WorkflowNodeExecutionStatus(StrEnum):

@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
    completed: bool = Field(default=False)
    aborted: bool = Field(default=False)
    error: GraphExecutionErrorState | None = Field(default=None)
    node_executions: list[NodeExecutionState] = Field(default_factory=list)
    exceptions_count: int = Field(default=0)
    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])


def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@ -103,7 +104,8 @@ class GraphExecution:
    completed: bool = False
    aborted: bool = False
    error: Exception | None = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict)
    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
    exceptions_count: int = 0

    def start(self) -> None:
        """Mark the graph execution as started."""
@ -172,6 +174,7 @@ class GraphExecution:
            completed=self.completed,
            aborted=self.aborted,
            error=_serialize_error(self.error),
            exceptions_count=self.exceptions_count,
            node_executions=node_states,
        )

@ -195,6 +198,7 @@ class GraphExecution:
        self.completed = state.completed
        self.aborted = state.aborted
        self.error = _deserialize_error(state.error)
        self.exceptions_count = state.exceptions_count
        self.node_executions = {
            item.node_id: NodeExecution(
                node_id=item.node_id,
@ -205,3 +209,7 @@ class GraphExecution:
            )
            for item in state.node_executions
        }

    def record_node_failure(self) -> None:
        """Increment the count of node failures encountered during execution."""
        self.exceptions_count += 1
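
    # exceptions_count feeds the partial-success decision in GraphEngine further down
    # in this change: a run that completes with exceptions_count > 0 yields
    # GraphRunPartialSucceededEvent instead of GraphRunSucceededEvent.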

@ -3,11 +3,12 @@ Event handler implementations for different event types.
"""

import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final

from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
    GraphNodeEventBase,
@ -122,13 +123,15 @@ class EventHandler:
        """
        # Track execution in domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        is_initial_attempt = node_execution.retry_count == 0
        node_execution.mark_started(event.id)

        # Track in response coordinator for stream ordering
        self._response_coordinator.track_node_execution(event.node_id, event.id)

        # Collect the event
        self._event_collector.collect(event)
        # Collect the event only for the first attempt; retries remain silent
        if is_initial_attempt:
            self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunStreamChunkEvent) -> None:
@ -161,7 +164,7 @@ class EventHandler:
        node_execution.mark_taken()

        # Store outputs in variable pool
        self._store_node_outputs(event)
        self._store_node_outputs(event.node_id, event.node_run_result.outputs)

        # Forward to response coordinator and emit streaming events
        streaming_events = self._response_coordinator.intercept_event(event)
@ -191,7 +194,7 @@ class EventHandler:

        # Handle response node outputs
        if node.execution_type == NodeExecutionType.RESPONSE:
            self._update_response_outputs(event)
            self._update_response_outputs(event.node_run_result.outputs)

        # Collect the event
        self._event_collector.collect(event)
@ -207,6 +210,7 @@ class EventHandler:
        # Update domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_failed(event.error)
        self._graph_execution.record_node_failure()

        result = self._error_handler.handle_node_failure(event)

@ -227,10 +231,40 @@ class EventHandler:
        Args:
            event: The node exception event
        """
        # Node continues via fail-branch, so it's technically "succeeded"
        # Node continues via fail-branch/default-value, treat as completion
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_taken()

        # Persist outputs produced by the exception strategy (e.g. default values)
        self._store_node_outputs(event.node_id, event.node_run_result.outputs)

        node = self._graph.nodes[event.node_id]

        if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
            ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
        elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
            ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
                event.node_id, event.node_run_result.edge_source_handle
            )
        else:
            raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")

        for edge_event in edge_streaming_events:
            self._event_collector.collect(edge_event)

        for node_id in ready_nodes:
            self._state_manager.enqueue_node(node_id)
            self._state_manager.start_execution(node_id)

        # Update response outputs if applicable
        if node.execution_type == NodeExecutionType.RESPONSE:
            self._update_response_outputs(event.node_run_result.outputs)

        self._state_manager.finish_execution(event.node_id)

        # Collect the exception event for observers
        self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunRetryEvent) -> None:
        """
@ -242,21 +276,31 @@ class EventHandler:
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.increment_retry()

    def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
        # Finish the previous attempt before re-queuing the node
        self._state_manager.finish_execution(event.node_id)

        # Emit retry event for observers
        self._event_collector.collect(event)

        # Re-queue node for execution
        self._state_manager.enqueue_node(event.node_id)
        self._state_manager.start_execution(event.node_id)

    def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
        """
        Store node outputs in the variable pool.

        Args:
            node_id: The id of the node that produced the outputs
            outputs: The outputs to store
        """
        for variable_name, variable_value in event.node_run_result.outputs.items():
            self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
        for variable_name, variable_value in outputs.items():
            self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)

    def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
    def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
        """Update response outputs for response nodes."""
        # TODO: Design a mechanism for nodes to notify the engine about how to update outputs
        # in runtime state, rather than allowing nodes to directly access runtime state.
        for key, value in event.node_run_result.outputs.items():
        for key, value in outputs.items():
            if key == "answer":
                existing = self._graph_runtime_state.get_output("answer", "")
                if existing:

@ -5,6 +5,7 @@ Unified event manager for collecting and emitting events.
"""

import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import final

from core.workflow.graph_events import GraphEngineEvent
@ -51,43 +52,23 @@ class ReadWriteLock:
        """Release a write lock."""
        self._read_ready.release()

    def read_lock(self) -> "ReadLockContext":
    @contextmanager
    def read_lock(self):
        """Return a context manager for read locking."""
        return ReadLockContext(self)
        self.acquire_read()
        try:
            yield
        finally:
            self.release_read()

    def write_lock(self) -> "WriteLockContext":
    @contextmanager
    def write_lock(self):
        """Return a context manager for write locking."""
        return WriteLockContext(self)


@final
class ReadLockContext:
    """Context manager for read locks."""

    def __init__(self, lock: ReadWriteLock) -> None:
        self._lock = lock

    def __enter__(self) -> "ReadLockContext":
        self._lock.acquire_read()
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
        self._lock.release_read()


@final
class WriteLockContext:
    """Context manager for write locks."""

    def __init__(self, lock: ReadWriteLock) -> None:
        self._lock = lock

    def __enter__(self) -> "WriteLockContext":
        self._lock.acquire_write()
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
        self._lock.release_write()
        self.acquire_write()
        try:
            yield
        finally:
            self.release_write()
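
# Usage sketch of the @contextmanager form above: readers may overlap, writers are exclusive.
lock = ReadWriteLock()
with lock.read_lock():
    pass  # concurrent readers allowed here
with lock.write_lock():
    pass  # exclusive section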


@final

@ -23,6 +23,7 @@ from core.workflow.graph_events import (
    GraphNodeEventBase,
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)
@ -260,12 +261,23 @@ class GraphEngine:
            if self._graph_execution.error:
                raise self._graph_execution.error
            else:
                yield GraphRunSucceededEvent(
                    outputs=self._graph_runtime_state.outputs,
                )
                outputs = self._graph_runtime_state.outputs
                exceptions_count = self._graph_execution.exceptions_count
                if exceptions_count > 0:
                    yield GraphRunPartialSucceededEvent(
                        exceptions_count=exceptions_count,
                        outputs=outputs,
                    )
                else:
                    yield GraphRunSucceededEvent(
                        outputs=outputs,
                    )

        except Exception as e:
            yield GraphRunFailedEvent(error=str(e))
            yield GraphRunFailedEvent(
                error=str(e),
                exceptions_count=self._graph_execution.exceptions_count,
            )
            raise

        finally:

@@ -15,6 +15,7 @@ from core.workflow.graph_events import (
     GraphEngineEvent,
     GraphRunAbortedEvent,
     GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
     NodeRunExceptionEvent,

@@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
             if self.include_outputs and event.outputs:
                 self.logger.info("  Final outputs: %s", self._format_dict(event.outputs))

+        elif isinstance(event, GraphRunPartialSucceededEvent):
+            self.logger.warning("⚠️ Graph run partially succeeded")
+            if event.exceptions_count > 0:
+                self.logger.warning("  Total exceptions: %s", event.exceptions_count)
+            if self.include_outputs and event.outputs:
+                self.logger.info("  Final outputs: %s", self._format_dict(event.outputs))
+
         elif isinstance(event, GraphRunFailedEvent):
             self.logger.error("❌ Graph run failed: %s", event.error)
             if event.exceptions_count > 0:

@@ -138,6 +146,12 @@ class DebugLoggingLayer(GraphEngineLayer):
                 self.logger.info("  Partial outputs: %s", self._format_dict(event.outputs))

         # Node-level events
+        # Retry before Started because Retry subclasses Started; otherwise the Started branch would match first
+        elif isinstance(event, NodeRunRetryEvent):
+            self.retry_count += 1
+            self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
+            self.logger.warning("  Previous error: %s", event.error)
+
         elif isinstance(event, NodeRunStartedEvent):
             self.node_count += 1
             self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)

@@ -167,11 +181,6 @@ class DebugLoggingLayer(GraphEngineLayer):
             self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
             self.logger.warning("  Error: %s", event.error)

-        elif isinstance(event, NodeRunRetryEvent):
-            self.retry_count += 1
-            self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
-            self.logger.warning("  Previous error: %s", event.error)
-
         elif isinstance(event, NodeRunStreamChunkEvent):
             # Log stream chunks at debug level to avoid spam
             final_indicator = " (FINAL)" if event.is_final else ""
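
The reordering above matters because `isinstance` dispatch in an `elif` chain takes the first match: if `NodeRunRetryEvent` subclasses `NodeRunStartedEvent`, a retry event tested against the parent class first would be logged as a plain start. A minimal sketch of the pitfall (class bodies are stand-ins; the real definitions live elsewhere in the codebase):

class NodeRunStartedEvent: ...
class NodeRunRetryEvent(NodeRunStartedEvent): ...


def classify(event) -> str:
    # The subclass must be tested before its parent, as in the diff above.
    if isinstance(event, NodeRunRetryEvent):
        return "retry"
    if isinstance(event, NodeRunStartedEvent):
        return "started"
    return "other"


assert classify(NodeRunRetryEvent()) == "retry"  # would be "started" with the checks swapped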
@@ -147,4 +147,4 @@ class ExecutionLimitsLayer(GraphEngineLayer):
                 self.logger.debug("Abort command sent to engine")

         except Exception:
-            self.logger.exception("Failed to send abort command: %s")
+            self.logger.exception("Failed to send abort command")
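
The one-line fix above removes a dangling `%s` placeholder that had no argument to fill it; `logging.Logger.exception` already appends the active traceback, so the message needs no format argument. For example:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("limits")

try:
    raise RuntimeError("boom")
except Exception:
    # Correct: message only; the traceback is attached automatically.
    logger.exception("Failed to send abort command")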
@@ -1,9 +1,11 @@
+import contextvars
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from concurrent.futures import Future, ThreadPoolExecutor, as_completed
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, NewType, cast

+from flask import Flask, current_app
 from typing_extensions import TypeIs

 from core.variables import IntegerVariable, NoneSegment

@@ -19,6 +21,7 @@ from core.workflow.enums import (
 from core.workflow.graph_events import (
     GraphNodeEventBase,
     GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
     GraphRunSucceededEvent,
 )
 from core.workflow.node_events import (

@@ -34,6 +37,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
 from libs.datetime_utils import naive_utc_now
+from libs.flask_utils import preserve_flask_contexts

 from .exc import (
     InvalidIteratorValueError,

@@ -238,6 +242,8 @@ class IterationNode(Node):
                     self._execute_single_iteration_parallel,
                     index=index,
                     item=item,
+                    flask_app=current_app._get_current_object(),  # type: ignore
+                    context_vars=contextvars.copy_context(),
                 )
                 future_to_index[future] = index

@@ -280,26 +286,29 @@ class IterationNode(Node):
         self,
         index: int,
         item: object,
+        flask_app: Flask,
+        context_vars: contextvars.Context,
     ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
         """Execute a single iteration in parallel mode and return results."""
-        iter_start_at = datetime.now(UTC).replace(tzinfo=None)
-        events: list[GraphNodeEventBase] = []
-        outputs_temp: list[object] = []
+        with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
+            iter_start_at = datetime.now(UTC).replace(tzinfo=None)
+            events: list[GraphNodeEventBase] = []
+            outputs_temp: list[object] = []

-        graph_engine = self._create_graph_engine(index, item)
+            graph_engine = self._create_graph_engine(index, item)

-        # Collect events instead of yielding them directly
-        for event in self._run_single_iter(
-            variable_pool=graph_engine.graph_runtime_state.variable_pool,
-            outputs=outputs_temp,
-            graph_engine=graph_engine,
-        ):
-            events.append(event)
+            # Collect events instead of yielding them directly
+            for event in self._run_single_iter(
+                variable_pool=graph_engine.graph_runtime_state.variable_pool,
+                outputs=outputs_temp,
+                graph_engine=graph_engine,
+            ):
+                events.append(event)

-        # Get the output value from the temporary outputs list
-        output_value = outputs_temp[0] if outputs_temp else None
+            # Get the output value from the temporary outputs list
+            output_value = outputs_temp[0] if outputs_temp else None

-        return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
+            return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens

     def _handle_iteration_success(
         self,
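
The parallel path above hands each worker both the real Flask app object and a snapshot of the caller's `contextvars`, because neither `current_app` nor context variables propagate into `ThreadPoolExecutor` threads on their own. A sketch of the behavior a helper like `preserve_flask_contexts` plausibly provides (the `run_with_contexts` name is hypothetical; the real implementation lives in `libs.flask_utils` and may differ):

import contextvars

from flask import Flask


def run_with_contexts(fn, flask_app: Flask, ctx: contextvars.Context, *args, **kwargs):
    """Hypothetical stand-in for libs.flask_utils.preserve_flask_contexts.

    Replays the submitting thread's contextvars snapshot, then pushes a
    Flask app context so current_app resolves inside the worker thread.
    """
    def _call():
        with flask_app.app_context():
            return fn(*args, **kwargs)

    # Context.run executes the callable inside the copied snapshot.
    return ctx.run(_call)

# From the submitting thread, mirroring the diff above:
#   future = executor.submit(
#       run_with_contexts, work_fn,
#       current_app._get_current_object(), contextvars.copy_context(),
#       index, item,
#   )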
@@ -372,43 +381,16 @@ class IterationNode(Node):
         variable_mapping: dict[str, Sequence[str]] = {
             f"{node_id}.input_selector": typed_node_data.iterator_selector,
         }
+        iteration_node_ids = set()

-        # init graph
-        from core.workflow.entities import GraphInitParams, GraphRuntimeState
-        from core.workflow.graph import Graph
-        from core.workflow.nodes.node_factory import DifyNodeFactory
-
-        # Create minimal GraphInitParams for static analysis
-        graph_init_params = GraphInitParams(
-            tenant_id="",
-            app_id="",
-            workflow_id="",
-            graph_config=graph_config,
-            user_id="",
-            user_from="",
-            invoke_from="",
-            call_depth=0,
-        )
-
-        # Create minimal GraphRuntimeState for static analysis
-        from core.workflow.entities import VariablePool
-
-        graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(),
-            start_at=0,
-        )
-
-        # Create node factory for static analysis
-        node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
-
-        iteration_graph = Graph.init(
-            graph_config=graph_config,
-            node_factory=node_factory,
-            root_node_id=typed_node_data.start_node_id,
-        )
-
-        if not iteration_graph:
-            raise IterationGraphNotFoundError("iteration graph not found")
+        # Find all nodes that belong to this loop
+        nodes = graph_config.get("nodes", [])
+        for node in nodes:
+            node_data = node.get("data", {})
+            if node_data.get("iteration_id") == node_id:
+                in_iteration_node_id = node.get("id")
+                if in_iteration_node_id:
+                    iteration_node_ids.add(in_iteration_node_id)

         # Get node configs from graph_config instead of non-existent node_id_config_mapping
         node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}

@@ -444,9 +426,7 @@ class IterationNode(Node):
             variable_mapping.update(sub_node_variable_mapping)

         # remove variable out from iteration
-        variable_mapping = {
-            key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
-        }
+        variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}

         return variable_mapping

@@ -485,7 +465,7 @@ class IterationNode(Node):
             if isinstance(event, GraphNodeEventBase):
                 self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
                 yield event
-            elif isinstance(event, GraphRunSucceededEvent):
+            elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
                 result = variable_pool.get(self._node_data.output_selector)
                 if result is None:
                     outputs.append(None)
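
The simplification above replaces a full `Graph.init` (which needed dummy init params and a node factory just to learn which nodes sit inside the iteration) with a direct scan of `graph_config`, then filters out selectors whose source node is inside the iteration. A toy run of that filter, with made-up node IDs:

iteration_node_ids = {"llm-1", "code-2"}  # nodes whose data.iteration_id matches the iteration node

variable_mapping = {
    "iter-1.input_selector": ["start-1", "items"],  # source outside the iteration: kept
    "iter-1.llm-1.prompt": ["llm-1", "text"],       # source inside the iteration: dropped
}

variable_mapping = {
    key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids
}
assert variable_mapping == {"iter-1.input_selector": ["start-1", "items"]}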
@@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel):
     Retrieval Setting.
     """

-    search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
+    search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
     top_k: int
     score_threshold: float | None = 0.5
    score_threshold_enabled: bool = False
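
The literal rename above is a validation fix: with pydantic, a value outside the `Literal` set is rejected outright, so callers sending `"full_text_search"` would previously have failed. A quick check under that assumption (model trimmed to the relevant fields):

from typing import Literal

from pydantic import BaseModel, ValidationError


class RetrievalSetting(BaseModel):
    search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
    top_k: int


RetrievalSetting(search_method="full_text_search", top_k=3)  # accepted after the fix

try:
    RetrievalSetting(search_method="fulltext_search", top_k=3)  # the old spelling now fails
except ValidationError as e:
    print(e.errors()[0]["type"])  # e.g. "literal_error" under pydantic v2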
@@ -1,3 +1,4 @@
+import contextlib
 import json
 import logging
 from collections.abc import Callable, Generator, Mapping, Sequence

@@ -127,11 +128,13 @@ class LoopNode(Node):
         try:
             reach_break_condition = False
             if break_conditions:
-                _, _, reach_break_condition = condition_processor.process_conditions(
-                    variable_pool=self.graph_runtime_state.variable_pool,
-                    conditions=break_conditions,
-                    operator=logical_operator,
-                )
+                with contextlib.suppress(ValueError):
+                    _, _, reach_break_condition = condition_processor.process_conditions(
+                        variable_pool=self.graph_runtime_state.variable_pool,
+                        conditions=break_conditions,
+                        operator=logical_operator,
+                    )

             if reach_break_condition:
                 loop_count = 0
                 cost_tokens = 0
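
Wrapping the break-condition check in `contextlib.suppress(ValueError)` keeps the loop running when condition evaluation raises (for example, on a not-yet-populated variable), leaving `reach_break_condition` at its `False` default instead of aborting the node. The same pattern in isolation:

import contextlib


def check_conditions() -> bool:
    raise ValueError("variable not found")  # stand-in for a failing condition lookup


reach_break_condition = False
with contextlib.suppress(ValueError):
    reach_break_condition = check_conditions()

assert reach_break_condition is False  # the suppressed error leaves the default in place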
@@ -295,42 +298,11 @@ class LoopNode(Node):

         variable_mapping = {}

-        # init graph
-        from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
-        from core.workflow.graph import Graph
-        from core.workflow.nodes.node_factory import DifyNodeFactory
-
-        # Create minimal GraphInitParams for static analysis
-        graph_init_params = GraphInitParams(
-            tenant_id="",
-            app_id="",
-            workflow_id="",
-            graph_config=graph_config,
-            user_id="",
-            user_from="",
-            invoke_from="",
-            call_depth=0,
-        )
-
-        # Create minimal GraphRuntimeState for static analysis
-        graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(),
-            start_at=0,
-        )
-
-        # Create node factory for static analysis
-        node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
-
-        loop_graph = Graph.init(
-            graph_config=graph_config,
-            node_factory=node_factory,
-            root_node_id=typed_node_data.start_node_id,
-        )
-
-        if not loop_graph:
-            raise ValueError("loop graph not found")
-
-        # Get node configs from graph_config instead of non-existent node_id_config_mapping
+        # Extract loop node IDs statically from graph_config
+        loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
+
+        # Get node configs from graph_config
         node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
         for sub_node_id, sub_node_config in node_configs.items():
             if sub_node_config.get("data", {}).get("loop_id") != node_id:

@@ -371,12 +343,35 @@ class LoopNode(Node):
             variable_mapping[f"{node_id}.{loop_variable.label}"] = selector

         # remove variable out from loop
-        variable_mapping = {
-            key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
-        }
+        variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}

         return variable_mapping

+    @classmethod
+    def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
+        """
+        Extract node IDs that belong to a specific loop from graph configuration.
+
+        This method statically analyzes the graph configuration to find all nodes
+        that are part of the specified loop, without creating actual node instances.
+
+        :param graph_config: the complete graph configuration
+        :param loop_node_id: the ID of the loop node
+        :return: set of node IDs that belong to the loop
+        """
+        loop_node_ids = set()
+
+        # Find all nodes that belong to this loop
+        nodes = graph_config.get("nodes", [])
+        for node in nodes:
+            node_data = node.get("data", {})
+            if node_data.get("loop_id") == loop_node_id:
+                node_id = node.get("id")
+                if node_id:
+                    loop_node_ids.add(node_id)
+
+        return loop_node_ids
+
     @staticmethod
     def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
         """Get the appropriate segment type for a constant value."""
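
Since `_extract_loop_node_ids_from_config` only reads plain dicts, it is easy to exercise without any engine machinery. A usage sketch with a made-up two-node config, restating the method's logic standalone for illustration:

graph_config = {
    "nodes": [
        {"id": "loop-1", "data": {"type": "loop"}},
        {"id": "llm-1", "data": {"loop_id": "loop-1"}},
        {"id": "end-1", "data": {}},
    ]
}

# Equivalent of LoopNode._extract_loop_node_ids_from_config(graph_config, "loop-1"):
loop_node_ids = {
    node["id"]
    for node in graph_config.get("nodes", [])
    if node.get("data", {}).get("loop_id") == "loop-1" and node.get("id")
}
assert loop_node_ids == {"llm-1"}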
@@ -21,9 +21,6 @@ from core.workflow.nodes.question_classifier import QuestionClassifierNode
 from core.workflow.nodes.start import StartNode
 from core.workflow.nodes.template_transform import TemplateTransformNode
 from core.workflow.nodes.tool import ToolNode
-from core.workflow.nodes.trigger_plugin import TriggerPluginNode
-from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
-from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
 from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
 from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
 from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2

@@ -145,16 +142,4 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
         LATEST_VERSION: KnowledgeIndexNode,
         "1": KnowledgeIndexNode,
     },
-    NodeType.TRIGGER_WEBHOOK: {
-        LATEST_VERSION: TriggerWebhookNode,
-        "1": TriggerWebhookNode,
-    },
-    NodeType.TRIGGER_PLUGIN: {
-        LATEST_VERSION: TriggerPluginNode,
-        "1": TriggerPluginNode,
-    },
-    NodeType.TRIGGER_SCHEDULE: {
-        LATEST_VERSION: TriggerScheduleNode,
-        "1": TriggerScheduleNode,
-    },
 }
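
For reference, the mapping being trimmed above is a two-level registry: node type → version string → node class, with `LATEST_VERSION` as a fallback key. A minimal sketch of a lookup against such a registry (simplified types; the real mapping uses `NodeType` and `Node` from the workflow package, and the resolver name here is hypothetical):

from collections.abc import Mapping

LATEST_VERSION = "latest"


class CodeNode: ...


NODE_TYPE_CLASSES: Mapping[str, Mapping[str, type]] = {
    "code": {LATEST_VERSION: CodeNode, "1": CodeNode},
}


def resolve_node_class(node_type: str, version: str | None) -> type:
    versions = NODE_TYPE_CLASSES[node_type]
    # Fall back to the latest registered implementation when no version is pinned.
    return versions.get(version or LATEST_VERSION, versions[LATEST_VERSION])


assert resolve_node_class("code", None) is CodeNode
assert resolve_node_class("code", "1") is CodeNode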
@@ -1,3 +0,0 @@
-from .trigger_plugin_node import TriggerPluginNode
-
-__all__ = ["TriggerPluginNode"]

@@ -1,28 +0,0 @@
-from typing import Any, Optional
-
-from pydantic import Field
-
-from core.workflow.enums import ErrorStrategy
-from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
-
-
-class PluginTriggerData(BaseNodeData):
-    """Plugin trigger node data"""
-
-    title: str
-    desc: Optional[str] = None
-    plugin_id: str = Field(..., description="Plugin ID")
-    provider_id: str = Field(..., description="Provider ID")
-    trigger_name: str = Field(..., description="Trigger name")
-    subscription_id: str = Field(..., description="Subscription ID")
-    plugin_unique_identifier: str = Field(..., description="Plugin unique identifier")
-    parameters: dict[str, Any] = Field(default_factory=dict, description="Trigger parameters")
-
-    # Error handling
-    error_strategy: Optional[ErrorStrategy] = Field(
-        default=ErrorStrategy.FAIL_BRANCH, description="Error handling strategy"
-    )
-    retry_config: RetryConfig = Field(default_factory=lambda: RetryConfig(), description="Retry configuration")
-    default_value_dict: dict[str, Any] = Field(
-        default_factory=dict, description="Default values for outputs when error occurs"
-    )
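
One detail worth noting in the removed model: its mutable defaults (`parameters`, `retry_config`, `default_value_dict`) use `default_factory` so each instance gets a fresh object. The same idea in a stripped-down model (field name reused purely for illustration):

from typing import Any

from pydantic import BaseModel, Field


class TriggerDataSketch(BaseModel):
    # default_factory builds a new dict per instance; pydantic also deep-copies
    # plain mutable defaults, but default_factory makes the intent explicit.
    parameters: dict[str, Any] = Field(default_factory=dict)


a, b = TriggerDataSketch(), TriggerDataSketch()
a.parameters["k"] = "v"
assert b.parameters == {}  # instances do not share the mutable default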
Some files were not shown because too many files have changed in this diff.