Mirror of https://github.com/langgenius/dify.git (synced 2026-01-21 12:35:21 +08:00)

Compare commits: 10 commits, feat/trigg...pinecone
Commits:
594906c1ff
80f8245f2e
a12b437c16
12de554313
1f36c0c1c5
8b9297563c
1cbe9eedb6
90fc5a1f12
41dfdf1ac0
dd7de74aa6
@@ -1,6 +1,5 @@
#!/bin/bash

npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv
.github/workflows/autofix.yml (vendored, 2 changes)
@@ -2,8 +2,6 @@ name: autofix.ci
on:
  pull_request:
    branches: ["main"]
  push:
    branches: ["main"]

permissions:
  contents: read
.github/workflows/style.yml (vendored, 4 changes)
@@ -89,9 +89,7 @@ jobs:
      - name: Web style check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: |
          pnpm run lint
          pnpm run eslint
        run: pnpm run lint

  docker-compose-template:
    name: Docker Compose Template
.gitignore (vendored, 3 changes)
@@ -218,6 +218,3 @@ mise.toml
.roo/
api/.env.backup
/clickzetta

# mcp
.serena
@@ -59,7 +59,6 @@ pnpm test # Run Jest tests
- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead

### TypeScript/JavaScript
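A minimal illustration of the logging rule above (the handler is illustrative; the point is keeping the message static and passing the exception object instead of interpolating `str(e)`):

import logging

logger = logging.getLogger(__name__)

try:
    1 / 0
except ZeroDivisionError as e:
    # Bad: embeds str(e) in the message text
    # logger.exception(f"division failed: {str(e)}")
    # Good: static message; the exception and its traceback travel via exc_info
    logger.exception("division failed", exc_info=e)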
@@ -156,7 +156,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*

# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `pinecone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -361,6 +361,17 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false

# Pinecone configuration, only available when VECTOR_STORE is `pinecone`
PINECONE_API_KEY=your-pinecone-api-key
PINECONE_ENVIRONMENT=your-pinecone-environment
PINECONE_INDEX_NAME=dify-index
PINECONE_CLIENT_TIMEOUT=30
PINECONE_BATCH_SIZE=100
PINECONE_METRIC=cosine
PINECONE_PODS=1
PINECONE_POD_TYPE=s1

# Mail configuration, support: resend, smtp, sendgrid
MAIL_TYPE=
# If using SendGrid, use the 'from' field for authentication if necessary.
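A rough sketch of how these variables map onto the pod-based Pinecone client (assuming the v2-era `pinecone-client` package, which used `init(api_key, environment)`; the dimension and pod values are illustrative, not Dify's):

import os

import pinecone  # assumed: pinecone-client 2.x, pod-based API

pinecone.init(
    api_key=os.environ["PINECONE_API_KEY"],
    environment=os.environ["PINECONE_ENVIRONMENT"],
)

index_name = os.environ.get("PINECONE_INDEX_NAME", "dify-index")
if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        name=index_name,
        dimension=1536,  # illustrative; must match the embedding model's output size
        metric=os.environ.get("PINECONE_METRIC", "cosine"),
        pods=int(os.environ.get("PINECONE_PODS", "1")),
        pod_type=os.environ.get("PINECONE_POD_TYPE", "s1"),
    )

index = pinecone.Index(index_name)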
@@ -434,9 +445,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True

# Webhook request configuration
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760

# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false

@@ -505,12 +513,6 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
# Interval time in minutes for polling scheduled workflows (default: 1 min)
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0

# Position configuration
POSITION_TOOL_PINS=
api/.vscode/launch.json.example (vendored, 2 changes)
@@ -54,7 +54,7 @@
      "--loglevel",
      "DEBUG",
      "-Q",
      "dataset,generation,mail,ops_trace,app_deletion,workflow"
      "dataset,generation,mail,ops_trace,app_deletion"
    ]
  }
]
@@ -1207,55 +1207,6 @@ def setup_system_tool_oauth_client(provider, client_params):
    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
    """
    Setup system trigger oauth client
    """
    from core.plugin.entities.plugin import TriggerProviderID
    from models.trigger import TriggerOAuthSystemClient

    provider_id = TriggerProviderID(provider)
    provider_name = provider_id.provider_name
    plugin_id = provider_id.plugin_id

    try:
        # json validate
        click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
        client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
        click.echo(click.style("Client params validated successfully.", fg="green"))

        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
        click.echo(click.style("Client params encrypted successfully.", fg="green"))
    except Exception as e:
        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
        return

    deleted_count = (
        db.session.query(TriggerOAuthSystemClient)
        .filter_by(
            provider=provider_name,
            plugin_id=plugin_id,
        )
        .delete()
    )
    if deleted_count > 0:
        click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))

    oauth_client = TriggerOAuthSystemClient(
        provider=provider_name,
        plugin_id=plugin_id,
        encrypted_oauth_params=oauth_client_params,
    )
    db.session.add(oauth_client)
    db.session.commit()
    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
    """
    Find draft variables that reference non-existent apps.
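For reference, a sketch of exercising the command above with Click's test runner (hypothetical: it assumes the command is importable from Dify's `commands` module, and the provider ID and client params are placeholders):

from click.testing import CliRunner

from commands import setup_system_trigger_oauth_client  # assumed module path

runner = CliRunner()
result = runner.invoke(
    setup_system_trigger_oauth_client,
    [
        "--provider", "plugin_org/plugin_name/provider",  # placeholder provider ID
        "--client-params", '{"client_id": "x", "client_secret": "y"}',
    ],
)
print(result.output)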
@@ -147,17 +147,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
    )


class TriggerConfig(BaseSettings):
    """
    Configuration for trigger
    """

    WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
        description="Maximum allowed size for webhook request bodies in bytes",
        default=10485760,
    )


class PluginConfig(BaseSettings):
    """
    Plugin configs
@@ -882,22 +871,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
        description="Enable check upgradable plugin task",
        default=True,
    )
    ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
        description="Enable workflow schedule poller task",
        default=True,
    )
    WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
        description="Workflow schedule poller interval in minutes",
        default=1,
    )
    WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
        description="Maximum number of schedules to process in each poll batch",
        default=100,
    )
    WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
        description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
        default=0,
    )


class PositionConfig(BaseSettings):
@@ -1021,7 +994,6 @@ class FeatureConfig(
    AuthConfig,  # Changed from OAuthConfig to AuthConfig
    BillingConfig,
    CodeExecutionSandboxConfig,
    TriggerConfig,
    PluginConfig,
    MarketplaceConfig,
    DataSetConfig,
@@ -35,6 +35,7 @@ from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig
from .vdb.pgvectors_config import PGVectoRSConfig
from .vdb.pinecone_config import PineconeConfig
from .vdb.qdrant_config import QdrantConfig
from .vdb.relyt_config import RelytConfig
from .vdb.tablestore_config import TableStoreConfig
@@ -331,6 +332,7 @@ class MiddlewareConfig(
    PGVectorConfig,
    VastbaseVectorConfig,
    PGVectoRSConfig,
    PineconeConfig,
    QdrantConfig,
    RelytConfig,
    TencentVectorDBConfig,

api/configs/middleware/vdb/pinecone_config.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from typing import Optional

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class PineconeConfig(BaseSettings):
    """
    Configuration settings for Pinecone vector database
    """

    PINECONE_API_KEY: Optional[str] = Field(
        description="API key for authenticating with Pinecone service",
        default=None,
    )

    PINECONE_ENVIRONMENT: Optional[str] = Field(
        description="Pinecone environment (e.g., 'us-west1-gcp', 'us-east-1-aws')",
        default=None,
    )

    PINECONE_INDEX_NAME: Optional[str] = Field(
        description="Default Pinecone index name",
        default=None,
    )

    PINECONE_CLIENT_TIMEOUT: PositiveInt = Field(
        description="Timeout in seconds for Pinecone client operations (default is 30 seconds)",
        default=30,
    )

    PINECONE_BATCH_SIZE: PositiveInt = Field(
        description="Batch size for Pinecone operations (default is 100)",
        default=100,
    )

    PINECONE_METRIC: str = Field(
        description="Distance metric for Pinecone index (cosine, euclidean, dotproduct)",
        default="cosine",
    )
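A small sketch of how a `BaseSettings` subclass like this resolves values: pydantic-settings reads environment variables by field name and coerces them to the declared types (the variable values below are illustrative):

import os

# Illustrative values; in Dify these come from .env / the process environment.
os.environ["PINECONE_API_KEY"] = "pc-xxxx"
os.environ["PINECONE_BATCH_SIZE"] = "200"

config = PineconeConfig()
assert config.PINECONE_API_KEY == "pc-xxxx"
assert config.PINECONE_BATCH_SIZE == 200   # coerced to int by pydantic
assert config.PINECONE_METRIC == "cosine"  # falls back to the declared default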
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
    from core.model_runtime.entities.model_entities import AIModelEntity
    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
    from core.tools.plugin_tool.provider import PluginToolProviderController
    from core.trigger.provider import PluginTriggerProviderController
    from core.workflow.entities.variable_pool import VariablePool

@@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
    ContextVar("plugin_model_schemas")
)

plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
    ContextVar("plugin_trigger_providers")
)

plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
    ContextVar("plugin_trigger_providers_lock")
)
@@ -67,7 +67,6 @@ from .app import (
    workflow_draft_variable,
    workflow_run,
    workflow_statistic,
    workflow_trigger,
)

# Import auth controllers
@@ -181,6 +180,5 @@ from .workspace import (
    models,
    plugin,
    tool_providers,
    trigger_providers,
    workspace,
)
@@ -12,7 +12,6 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
@@ -126,11 +125,13 @@ class InstructionGenerateApi(Resource):
        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
        args = parser.parse_args()
        providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
        code_provider: type[CodeNodeProvider] | None = next(
            (p for p in providers if p.is_accept_language(args["language"])), None
        )
        code_template = (
            Python3CodeProvider.get_default_code()
            if args["language"] == "python"
            else (JavascriptCodeProvider.get_default_code())
            if args["language"] == "javascript"
            else ""
        )
        code_template = code_provider.get_default_code() if code_provider else ""
        try:
            # Generate from nothing for a workflow node
            if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
@@ -24,7 +24,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -39,7 +38,6 @@ from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.trigger_debug_service import TriggerDebugService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

logger = logging.getLogger(__name__)
@@ -808,132 +806,6 @@ class DraftWorkflowNodeLastRunApi(Resource):
        return node_exec


class DraftWorkflowTriggerNodeApi(Resource):
    """
    Single node debug - Polling API for trigger events
    Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
    """

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.WORKFLOW])
    def post(self, app_model: App, node_id: str):
        """
        Poll for trigger events and execute single node when event arrives
        """
        if not isinstance(current_user, Account) or not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("trigger_name", type=str, required=True, location="json")
        parser.add_argument("subscription_id", type=str, required=True, location="json")
        args = parser.parse_args()
        trigger_name = args["trigger_name"]
        subscription_id = args["subscription_id"]
        event = TriggerDebugService.poll_event(
            tenant_id=app_model.tenant_id,
            user_id=current_user.id,
            app_id=app_model.id,
            subscription_id=subscription_id,
            node_id=node_id,
            trigger_name=trigger_name,
        )
        if not event:
            return jsonable_encoder({"status": "waiting"})

        try:
            workflow_service = WorkflowService()
            draft_workflow = workflow_service.get_draft_workflow(app_model)
            if not draft_workflow:
                raise ValueError("Workflow not found")

            user_inputs = event.model_dump()
            node_execution = workflow_service.run_draft_workflow_node(
                app_model=app_model,
                draft_workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                account=current_user,
                query="",
                files=[],
            )
            return jsonable_encoder(node_execution)
        except Exception:
            logger.exception("Error running draft workflow trigger node")
            return jsonable_encoder(
                {
                    "status": "error",
                }
            ), 500


class DraftWorkflowTriggerRunApi(Resource):
    """
    Full workflow debug - Polling API for trigger events
    Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
    """

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.WORKFLOW])
    def post(self, app_model: App):
        """
        Poll for trigger events and execute full workflow when event arrives
        """
        if not isinstance(current_user, Account) or not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
        parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
        parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
        args = parser.parse_args()
        node_id = args["node_id"]
        trigger_name = args["trigger_name"]
        subscription_id = args["subscription_id"]

        event = TriggerDebugService.poll_event(
            tenant_id=app_model.tenant_id,
            user_id=current_user.id,
            app_id=app_model.id,
            subscription_id=subscription_id,
            node_id=node_id,
            trigger_name=trigger_name,
        )
        if not event:
            return jsonable_encoder({"status": "waiting"})

        workflow_args = {
            "inputs": event.model_dump(),
            "query": "",
            "files": [],
        }
        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            workflow_args["external_trace_id"] = external_trace_id

        try:
            response = AppGenerateService.generate(
                app_model=app_model,
                user=current_user,
                args=workflow_args,
                invoke_from=InvokeFrom.DEBUGGER,
                streaming=True,
            )
            return helper.compact_generate_response(response)
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)
        except Exception:
            logger.exception("Error running draft workflow trigger run")
            return jsonable_encoder(
                {
                    "status": "error",
                }
            ), 500


api.add_resource(
    DraftWorkflowApi,
    "/apps/<uuid:app_id>/workflows/draft",
@@ -958,14 +830,6 @@ api.add_resource(
    DraftWorkflowNodeRunApi,
    "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
    DraftWorkflowTriggerNodeApi,
    "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger",
)
api.add_resource(
    DraftWorkflowTriggerRunApi,
    "/apps/<uuid:app_id>/workflows/draft/trigger/run",
)
api.add_resource(
    AdvancedChatDraftRunIterationNodeApi,
    "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
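A sketch of how a front-end debugger might have consumed these polling endpoints (hypothetical client code; the endpoint path comes from the registrations above, everything else -- host, auth, payload values -- is illustrative):

import time

import requests  # assumed HTTP client; any would do

BASE = "http://localhost:5001/console/api"  # illustrative console API base


def poll_trigger_node(app_id: str, node_id: str, subscription_id: str, trigger_name: str) -> dict:
    """Poll the single-node trigger debug endpoint until an event fires."""
    url = f"{BASE}/apps/{app_id}/workflows/draft/nodes/{node_id}/trigger"
    payload = {"trigger_name": trigger_name, "subscription_id": subscription_id}
    while True:
        resp = requests.post(url, json=payload)  # auth headers omitted
        data = resp.json()
        if data.get("status") != "waiting":
            return data  # node execution result, or {"status": "error"}
        time.sleep(1)  # back off between polls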
@@ -27,9 +27,7 @@ class WorkflowAppLogApi(Resource):
        """
        parser = reqparse.RequestParser()
        parser.add_argument("keyword", type=str, location="args")
        parser.add_argument(
            "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
        )
        parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
        parser.add_argument(
            "created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
        )
@@ -1,249 +0,0 @@
import logging

from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.model import Account, AppMode
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger

logger = logging.getLogger(__name__)

from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService


class PluginTriggerApi(Resource):
    """Workflow Plugin Trigger API"""

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    def post(self, app_model):
        """Create plugin trigger"""
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=False, location="json")
        parser.add_argument("provider_id", type=str, required=False, location="json")
        parser.add_argument("trigger_name", type=str, required=False, location="json")
        parser.add_argument("subscription_id", type=str, required=False, location="json")
        args = parser.parse_args()

        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
        if not current_user.is_editor:
            raise Forbidden()

        plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
            app_id=app_model.id,
            tenant_id=current_user.current_tenant_id,
            node_id=args["node_id"],
            provider_id=args["provider_id"],
            trigger_name=args["trigger_name"],
            subscription_id=args["subscription_id"],
        )

        return jsonable_encoder(plugin_trigger)

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    def get(self, app_model):
        """Get plugin trigger"""
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
        args = parser.parse_args()

        plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
            app_id=app_model.id,
            node_id=args["node_id"],
        )

        return jsonable_encoder(plugin_trigger)

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    def put(self, app_model):
        """Update plugin trigger"""
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
        parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
        args = parser.parse_args()

        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
        if not current_user.is_editor:
            raise Forbidden()

        plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
            app_id=app_model.id,
            node_id=args["node_id"],
            subscription_id=args["subscription_id"],
        )

        return jsonable_encoder(plugin_trigger)

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    def delete(self, app_model):
        """Delete plugin trigger"""
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
        args = parser.parse_args()

        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
        if not current_user.is_editor:
            raise Forbidden()

        WorkflowPluginTriggerService.delete_plugin_trigger(
            app_id=app_model.id,
            node_id=args["node_id"],
        )

        return {"result": "success"}, 204


class WebhookTriggerApi(Resource):
    """Webhook Trigger API"""

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(webhook_trigger_fields)
    def get(self, app_model):
        """Get webhook trigger for a node"""
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
        args = parser.parse_args()

        node_id = args["node_id"]

        with Session(db.engine) as session:
            # Get webhook trigger for this app and node
            webhook_trigger = (
                session.query(WorkflowWebhookTrigger)
                .filter(
                    WorkflowWebhookTrigger.app_id == app_model.id,
                    WorkflowWebhookTrigger.node_id == node_id,
                )
                .first()
            )

            if not webhook_trigger:
                raise NotFound("Webhook trigger not found for this node")

            # Add computed fields for marshal_with
            base_url = dify_config.SERVICE_API_URL
            webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}"  # type: ignore
            webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}"  # type: ignore

            return webhook_trigger


class AppTriggersApi(Resource):
    """App Triggers list API"""

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(triggers_list_fields)
    def get(self, app_model):
        """Get app triggers list"""
        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None

        with Session(db.engine) as session:
            # Get all triggers for this app using select API
            triggers = (
                session.execute(
                    select(AppTrigger)
                    .where(
                        AppTrigger.tenant_id == current_user.current_tenant_id,
                        AppTrigger.app_id == app_model.id,
                    )
                    .order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
                )
                .scalars()
                .all()
            )

            # Add computed icon field for each trigger
            url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
            for trigger in triggers:
                if trigger.trigger_type == "trigger-plugin":
                    trigger.icon = url_prefix + trigger.provider_name + "/icon"  # type: ignore
                else:
                    trigger.icon = ""  # type: ignore

            return {"data": triggers}


class AppTriggerEnableApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(trigger_fields)
    def post(self, app_model):
        """Update app trigger (enable/disable)"""
        parser = reqparse.RequestParser()
        parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
        parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
        args = parser.parse_args()

        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
        if not current_user.is_editor:
            raise Forbidden()

        trigger_id = args["trigger_id"]

        with Session(db.engine) as session:
            # Find the trigger using select
            trigger = session.execute(
                select(AppTrigger).where(
                    AppTrigger.id == trigger_id,
                    AppTrigger.tenant_id == current_user.current_tenant_id,
                    AppTrigger.app_id == app_model.id,
                )
            ).scalar_one_or_none()

            if not trigger:
                raise NotFound("Trigger not found")

            # Update status based on enable_trigger boolean
            trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED

            session.commit()
            session.refresh(trigger)

            # Add computed icon field
            url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
            if trigger.trigger_type == "trigger-plugin":
                trigger.icon = url_prefix + trigger.provider_name + "/icon"  # type: ignore
            else:
                trigger.icon = ""  # type: ignore

            return trigger


api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
@@ -660,6 +660,7 @@ class DatasetRetrievalSettingApi(Resource):
                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
                | VectorType.PINECONE
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
@@ -711,6 +712,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
                | VectorType.PINECONE
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
@@ -516,20 +516,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
        parser.add_argument("provider", type=str, required=True, location="args")
        parser.add_argument("action", type=str, required=True, location="args")
        parser.add_argument("parameter", type=str, required=True, location="args")
        parser.add_argument("credential_id", type=str, required=False, location="args")
        parser.add_argument("provider_type", type=str, required=True, location="args")
        args = parser.parse_args()

        try:
            options = PluginParameterService.get_dynamic_select_options(
                tenant_id=tenant_id,
                user_id=user_id,
                plugin_id=args["plugin_id"],
                provider=args["provider"],
                action=args["action"],
                parameter=args["parameter"],
                credential_id=args["credential_id"],
                provider_type=args["provider_type"],
                tenant_id,
                user_id,
                args["plugin_id"],
                args["provider"],
                args["action"],
                args["parameter"],
                args["provider_type"],
            )
        except PluginDaemonClientSideError as e:
            raise ValueError(e)
@@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
@@ -1,589 +0,0 @@
import logging

from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService

logger = logging.getLogger(__name__)


class TriggerProviderListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """List all trigger providers for the current tenant"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))


class TriggerProviderInfoApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Get info for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        return jsonable_encoder(
            TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
        )


class TriggerSubscriptionListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """List all trigger subscriptions for the current tenant's provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            return jsonable_encoder(
                TriggerProviderService.list_trigger_provider_subscriptions(
                    tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
                )
            )
        except Exception as e:
            logger.exception("Error listing trigger providers", exc_info=e)
            raise


class TriggerSubscriptionBuilderCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        """Add a new subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
            subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                credential_type=credential_type,
            )
            return jsonable_encoder({"subscription_builder": subscription_builder})
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error adding provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderGetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, subscription_builder_id):
        """Get a subscription instance for a trigger provider"""
        return jsonable_encoder(
            TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
        )


class TriggerSubscriptionBuilderVerifyApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Verify a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
                subscription_builder_updater=SubscriptionBuilderUpdater(
                    credentials=args.get("credentials", None),
                ),
            )
            return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
            )
        except Exception as e:
            logger.exception("Error verifying provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Update a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        parser = reqparse.RequestParser()
        # The name of the subscription builder
        parser.add_argument("name", type=str, required=False, nullable=True, location="json")
        # The parameters of the subscription builder
        parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
        # The properties of the subscription builder
        parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()
        try:
            return jsonable_encoder(
                TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
                    tenant_id=user.current_tenant_id,
                    provider_id=TriggerProviderID(provider),
                    subscription_builder_id=subscription_builder_id,
                    subscription_builder_updater=SubscriptionBuilderUpdater(
                        name=args.get("name", None),
                        parameters=args.get("parameters", None),
                        properties=args.get("properties", None),
                        credentials=args.get("credentials", None),
                    ),
                )
            )
        except Exception as e:
            logger.exception("Error updating provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderLogsApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, subscription_builder_id):
        """Get the request logs for a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        try:
            logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
            return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
        except Exception as e:
            logger.exception("Error getting request logs for subscription builder", exc_info=e)
            raise


class TriggerSubscriptionBuilderBuildApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Build a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        # The name of the subscription builder
        parser.add_argument("name", type=str, required=False, nullable=True, location="json")
        # The parameters of the subscription builder
        parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
        # The properties of the subscription builder
        parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()
        try:
            TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
                subscription_builder_updater=SubscriptionBuilderUpdater(
                    name=args.get("name", None),
                    parameters=args.get("parameters", None),
                    properties=args.get("properties", None),
                ),
            )
            TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
            )
            return 200
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error building provider credential", exc_info=e)
            raise


class TriggerSubscriptionDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, subscription_id):
        """Delete a subscription instance"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            with Session(db.engine) as session:
                # Delete trigger provider subscription
                TriggerProviderService.delete_trigger_provider(
                    session=session,
                    tenant_id=user.current_tenant_id,
                    subscription_id=subscription_id,
                )
                # Delete plugin triggers
                WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
                    session=session,
                    tenant_id=user.current_tenant_id,
                    subscription_id=subscription_id,
                )
                session.commit()
            return {"result": "success"}
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error deleting provider credential", exc_info=e)
            raise


class TriggerOAuthAuthorizeApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Initiate OAuth authorization flow for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        try:
            provider_id = TriggerProviderID(provider)
            plugin_id = provider_id.plugin_id
            provider_name = provider_id.provider_name
            tenant_id = user.current_tenant_id

            # Get OAuth client configuration
            oauth_client_params = TriggerProviderService.get_oauth_client(
                tenant_id=tenant_id,
                provider_id=provider_id,
            )

            if oauth_client_params is None:
                raise Forbidden("No OAuth client configuration found for this trigger provider")

            # Create subscription builder
            subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
                tenant_id=tenant_id,
                user_id=user.id,
                provider_id=provider_id,
                credential_type=CredentialType.OAUTH2,
            )

            # Create OAuth handler and proxy context
            oauth_handler = OAuthHandler()
            context_id = OAuthProxyService.create_proxy_context(
                user_id=user.id,
                tenant_id=tenant_id,
                plugin_id=plugin_id,
                provider=provider_name,
                extra_data={
                    "subscription_builder_id": subscription_builder.id,
                },
            )

            # Build redirect URI for callback
            redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"

            # Get authorization URL
            authorization_url_response = oauth_handler.get_authorization_url(
                tenant_id=tenant_id,
                user_id=user.id,
                plugin_id=plugin_id,
                provider=provider_name,
                redirect_uri=redirect_uri,
                system_credentials=oauth_client_params,
            )

            # Create response with cookie
            response = make_response(
                jsonable_encoder(
                    {
                        "authorization_url": authorization_url_response.authorization_url,
                        "subscription_builder_id": subscription_builder.id,
                        "subscription_builder": subscription_builder,
                    }
                )
            )
            response.set_cookie(
                "context_id",
                context_id,
                httponly=True,
                samesite="Lax",
                max_age=OAuthProxyService.__MAX_AGE__,
            )

            return response

        except Exception as e:
            logger.exception("Error initiating OAuth flow", exc_info=e)
            raise


class TriggerOAuthCallbackApi(Resource):
    @setup_required
    def get(self, provider):
        """Handle OAuth callback for trigger provider"""
        context_id = request.cookies.get("context_id")
        if not context_id:
            raise Forbidden("context_id not found")

        # Use and validate proxy context
        context = OAuthProxyService.use_proxy_context(context_id)
        if context is None:
            raise Forbidden("Invalid context_id")

        # Parse provider ID
        provider_id = TriggerProviderID(provider)
        plugin_id = provider_id.plugin_id
        provider_name = provider_id.provider_name
        user_id = context.get("user_id")
        tenant_id = context.get("tenant_id")
        subscription_builder_id = context.get("subscription_builder_id")

        # Get OAuth client configuration
        oauth_client_params = TriggerProviderService.get_oauth_client(
            tenant_id=tenant_id,
            provider_id=provider_id,
        )

        if oauth_client_params is None:
            raise Forbidden("No OAuth client configuration found for this trigger provider")

        # Get OAuth credentials from callback
        oauth_handler = OAuthHandler()
        redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"

        credentials_response = oauth_handler.get_credentials(
            tenant_id=tenant_id,
            user_id=user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            redirect_uri=redirect_uri,
            system_credentials=oauth_client_params,
            request=request,
        )

        credentials = credentials_response.credentials
        expires_at = credentials_response.expires_at

        if not credentials:
            raise Exception("Failed to get OAuth credentials")

        # Update subscription builder
        TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
            tenant_id=tenant_id,
            provider_id=provider_id,
            subscription_builder_id=subscription_builder_id,
            subscription_builder_updater=SubscriptionBuilderUpdater(
                credentials=credentials,
                credential_expires_at=expires_at,
            ),
        )
        # Redirect to OAuth callback page
        return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")


class TriggerOAuthClientManageApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Get OAuth client configuration for a provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            provider_id = TriggerProviderID(provider)

            # Get custom OAuth client params if exists
            custom_params = TriggerProviderService.get_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )

            # Check if custom client is enabled
            is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )

            # Check if there's a system OAuth client
            system_client = TriggerProviderService.get_oauth_client(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )
            provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
            redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
            return jsonable_encoder(
                {
                    "configured": bool(custom_params or system_client),
                    "oauth_client_schema": provider_controller.get_oauth_client_schema(),
                    "custom_configured": bool(custom_params),
                    "custom_enabled": is_custom_enabled,
                    "redirect_uri": redirect_uri,
                    "params": custom_params if custom_params else {},
                }
            )

        except Exception as e:
            logger.exception("Error getting OAuth client", exc_info=e)
            raise

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        """Configure custom OAuth client for a provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            provider_id = TriggerProviderID(provider)
            return TriggerProviderService.save_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
                client_params=args.get("client_params"),
                enabled=args.get("enabled"),
            )

        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error configuring OAuth client", exc_info=e)
            raise

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider):
        """Remove custom OAuth client configuration"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            provider_id = TriggerProviderID(provider)

            return TriggerProviderService.delete_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error removing OAuth client", exc_info=e)
            raise


# Trigger Subscription
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
api.add_resource(
    TriggerSubscriptionDeleteApi,
    "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)

# Trigger Subscription Builder
api.add_resource(
    TriggerSubscriptionBuilderCreateApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
api.add_resource(
    TriggerSubscriptionBuilderGetApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderUpdateApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderVerifyApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderBuildApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
api.add_resource(
    TriggerSubscriptionBuilderLogsApi,
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)


# OAuth
api.add_resource(
    TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
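Taken together, the routes above imply a builder lifecycle of create -> update -> verify -> build. A hypothetical console client walking that lifecycle (paths come from the registrations above; host, auth, provider ID, and payloads are illustrative):

import requests  # assumed HTTP client

BASE = "http://localhost:5001/console/api"  # illustrative
PROVIDER = "plugin_org/plugin_name/provider"  # hypothetical provider ID

# 1. Create a builder (credential_type is optional; defaults to unauthorized).
builder = requests.post(
    f"{BASE}/workspaces/current/trigger-provider/{PROVIDER}/subscriptions/builder/create",
    json={"credential_type": "api_key"},
).json()["subscription_builder"]
bid = builder["id"]

# 2. Fill in name, parameters, and credentials.
requests.post(
    f"{BASE}/workspaces/current/trigger-provider/{PROVIDER}/subscriptions/builder/update/{bid}",
    json={"name": "my-subscription", "credentials": {"api_key": "xxx"}},
)

# 3. Verify the credentials, then 4. build the live subscription.
requests.post(f"{BASE}/workspaces/current/trigger-provider/{PROVIDER}/subscriptions/builder/verify/{bid}", json={})
requests.post(f"{BASE}/workspaces/current/trigger-provider/{PROVIDER}/subscriptions/builder/build/{bid}", json={})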
@ -9,9 +9,10 @@ from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
from models.model import App, AppMCPServer, AppMode, EndUser


class MCPRequestError(Exception):
@ -194,6 +195,50 @@ class MCPAppApi(Resource):
        except ValidationError as e:
            raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")

        mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
        response = mcp_server_handler.handle()
        return helper.compact_generate_response(response)
    def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
        """Get end user from existing session - optimized query"""
        return (
            session.query(EndUser)
            .where(EndUser.tenant_id == tenant_id)
            .where(EndUser.session_id == mcp_server_id)
            .where(EndUser.type == "mcp")
            .first()
        )

    def _create_end_user(
        self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
    ) -> EndUser:
        """Create end user in existing session"""
        end_user = EndUser(
            tenant_id=tenant_id,
            app_id=app_id,
            type="mcp",
            name=client_name,
            session_id=mcp_server_id,
        )
        session.add(end_user)
        session.flush()  # Use flush instead of commit to keep transaction open
        session.refresh(end_user)
        return end_user

    def _handle_mcp_request(
        self,
        app: App,
        mcp_server: AppMCPServer,
        mcp_request: mcp_types.ClientRequest,
        user_input_form: list[VariableEntity],
        session: Session,
        request_id: Union[int, str],
    ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
        """Handle MCP request and return response"""
        end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)

        if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
            client_info = mcp_request.root.params.clientInfo
            client_name = f"{client_info.name}@{client_info.version}"
            # Commit the session before creating end user to avoid transaction conflicts
            session.commit()
            with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
                end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)

        return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
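A note on the flush-then-refresh pattern in `_create_end_user` above: `session.flush()` issues the INSERT while the surrounding transaction stays open, and `session.refresh()` reads back server-generated columns such as the primary key. A minimal self-contained sketch; the model here is a hypothetical stand-in, not Dify's `EndUser`:

from sqlalchemy import Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DemoEndUser(Base):  # hypothetical stand-in for models.model.EndUser
    __tablename__ = "demo_end_users"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session, session.begin():
    user = DemoEndUser(name="client@1.0")
    session.add(user)
    session.flush()       # INSERT is sent; the transaction is still open
    session.refresh(user)
    print(user.id)        # generated primary key is available before commit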
@ -1,7 +1,7 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from enum import Enum
from functools import wraps
from typing import Optional

@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService


class WhereisUserArg(StrEnum):
class WhereisUserArg(Enum):
    """
    Enum for whereis_user_arg.
    """

    QUERY = auto()
    JSON = auto()
    FORM = auto()
    QUERY = "query"
    JSON = "json"
    FORM = "form"


class FetchUserArg(BaseModel):
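For reference on this hunk: `StrEnum` members created with `auto()` are themselves strings equal to the lower-cased member name, while plain `Enum` members only carry the string in `.value` and no longer compare equal to raw strings, so call sites must compare explicitly. An illustrative sketch, not part of the change:

from enum import Enum, StrEnum, auto


class OldStyle(StrEnum):
    QUERY = auto()  # auto() in a StrEnum yields the lower-cased name: "query"


class NewStyle(Enum):
    QUERY = "query"  # explicit value on a plain Enum


assert OldStyle.QUERY == "query"        # StrEnum members are str instances
assert NewStyle.QUERY != "query"        # plain Enum members are not
assert NewStyle.QUERY.value == "query"  # the string lives in .value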
@ -1,7 +0,0 @@
from flask import Blueprint

# Create trigger blueprint
bp = Blueprint("trigger", __name__, url_prefix="/triggers")

# Import routes after blueprint creation to avoid circular imports
from . import trigger, webhook
@ -1,41 +0,0 @@
import logging
import re

from flask import jsonify, request
from werkzeug.exceptions import NotFound

from controllers.trigger import bp
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger_service import TriggerService

logger = logging.getLogger(__name__)

UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
UUID_MATCHER = re.compile(UUID_PATTERN)


@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_endpoint(endpoint_id: str):
    """
    Handle endpoint trigger calls.
    """
    # endpoint_id must be UUID
    if not UUID_MATCHER.match(endpoint_id):
        raise NotFound("Invalid endpoint ID")
    handling_chain = [
        TriggerService.process_endpoint,
        TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
    ]
    try:
        for handler in handling_chain:
            response = handler(endpoint_id, request)
            if response:
                break
        if not response:
            raise NotFound("Endpoint not found")
        return response
    except ValueError as e:
        raise NotFound(str(e))
    except Exception as e:
        logger.exception("Webhook processing failed for %s", endpoint_id)
        return jsonify({"error": "Internal server error", "message": str(e)}), 500
@ -1,46 +0,0 @@
import logging

from flask import jsonify
from werkzeug.exceptions import NotFound, RequestEntityTooLarge

from controllers.trigger import bp
from services.webhook_service import WebhookService

logger = logging.getLogger(__name__)


@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook(webhook_id: str):
    """
    Handle webhook trigger calls.

    This endpoint receives webhook calls and processes them according to the
    configured webhook trigger settings.
    """
    try:
        # Get webhook trigger, workflow, and node configuration
        webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)

        # Extract request data
        webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

        # Validate request against node configuration
        validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
        if not validation_result["valid"]:
            return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400

        # Process webhook call (send to Celery)
        WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)

        # Return configured response
        response_data, status_code = WebhookService.generate_webhook_response(node_config)
        return jsonify(response_data), status_code

    except ValueError as e:
        raise NotFound(str(e))
    except RequestEntityTooLarge:
        raise
    except Exception as e:
        logger.exception("Webhook processing failed for %s", webhook_id)
        return jsonify({"error": "Internal server error", "message": str(e)}), 500
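For illustration, a webhook route like the one removed above is plain HTTP and can be exercised with the standard library; the host, port, and webhook id below are assumptions for the sketch, not values from this diff:

import json
from urllib.request import Request, urlopen

webhook_id = "123e4567-e89b-42d3-a456-426614174000"  # hypothetical id
req = Request(
    f"http://localhost:5001/triggers/webhook/{webhook_id}",  # assumed API host/port
    data=json.dumps({"hello": "world"}).encode(),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urlopen(req) as resp:
    # the body and status code come from the webhook node's configured response
    print(resp.status, resp.read().decode())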
@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            environment_variables=self._workflow.environment_variables,
            # Based on the definition of `VariableUnion`,
            # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
            conversation_variables=conversation_variables,
            conversation_variables=cast(list[VariableUnion], conversation_variables),
        )

        # init graph
@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
from enum import IntEnum, auto
from enum import Enum
from typing import Any, Optional

from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client


class PublishFrom(IntEnum):
    APPLICATION_MANAGER = auto()
    TASK_PIPELINE = auto()
class PublishFrom(Enum):
    APPLICATION_MANAGER = 1
    TASK_PIPELINE = 2


class AppQueueManager:
@ -54,8 +54,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
        root_node_id: Optional[str] = None,
    ) -> Generator[Mapping | str, None, None]: ...

    @overload
@ -70,8 +68,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
        root_node_id: Optional[str] = None,
    ) -> Mapping[str, Any]: ...

    @overload
@ -86,8 +82,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
        root_node_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(
@ -101,8 +95,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
        root_node_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
        files: Sequence[Mapping[str, Any]] = args.get("files") or []

@ -138,20 +130,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
            **extract_external_trace_id_from_args(args),
        }
        workflow_run_id = str(uuid.uuid4())
        if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
            # start node get inputs
            inputs = self._prepare_user_inputs(
                user_inputs=inputs,
                variables=app_config.variables,
                tenant_id=app_model.tenant_id,
                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
            )
        # init application generate entity
        application_generate_entity = WorkflowAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            file_upload_config=file_extra_config,
            inputs=inputs,
            inputs=self._prepare_user_inputs(
                user_inputs=inputs,
                variables=app_config.variables,
                tenant_id=app_model.tenant_id,
                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
            ),
            files=list(system_files),
            user_id=user.id,
            stream=streaming,
@ -170,10 +159,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        # Create session factory
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        # Create workflow execution(aka workflow run) repository
        if triggered_from is not None:
            # Use explicitly provided triggered_from (for async triggers)
            workflow_triggered_from = triggered_from
        elif invoke_from == InvokeFrom.DEBUGGER:
        if invoke_from == InvokeFrom.DEBUGGER:
            workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
        else:
            workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
@ -201,7 +187,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
            workflow_thread_pool_id=workflow_thread_pool_id,
            root_node_id=root_node_id,
        )

    def _generate(
@ -217,7 +202,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: bool = True,
        workflow_thread_pool_id: Optional[str] = None,
        variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
        root_node_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
        """
        Generate App response.

@ -255,7 +239,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
                "context": context,
                "workflow_thread_pool_id": workflow_thread_pool_id,
                "variable_loader": variable_loader,
                "root_node_id": root_node_id,
            },
        )
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -496,7 +478,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -34,7 +34,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -45,7 +44,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._root_node_id = root_node_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
@ -95,7 +93,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict, root_node_id=self._root_node_id)
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
@ -79,7 +79,7 @@ class WorkflowBasedAppRunner:
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> Graph:
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
@ -92,7 +92,7 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
@ -472,10 +472,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
        :param event: agent thought event
        :return:
        """
        with Session(db.engine, expire_on_commit=False) as session:
            agent_thought: Optional[MessageAgentThought] = (
                session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
            )
        agent_thought: Optional[MessageAgentThought] = (
            db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
        )

        if agent_thought:
            return AgentThoughtStreamResponse(
@ -192,9 +192,8 @@ class ProviderConfig(BasicProviderConfig):

    scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
    required: bool = False
    default: Optional[Union[int, str, float, bool, list]] = None
    default: Optional[Union[int, str, float, bool]] = None
    options: Optional[list[Option]] = None
    multiple: bool | None = False
    label: Optional[I18nObject] = None
    help: Optional[I18nObject] = None
    url: Optional[str] = None
@ -3,7 +3,7 @@ import base64
from libs import rsa


def obfuscated_token(token: str) -> str:
def obfuscated_token(token: str):
    if not token:
        return token
    if len(token) <= 8:
@ -1,6 +1,6 @@
from collections.abc import Sequence

import httpx
import requests
from yarl import URL

from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
        return []

    url = str(marketplace_api_url / "api/v1/plugins/batch")
    response = httpx.post(url, json={"plugin_ids": plugin_ids})
    response = requests.post(url, json={"plugin_ids": plugin_ids})
    response.raise_for_status()

    return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
        return []

    url = str(marketplace_api_url / "api/v1/plugins/batch")
    response = httpx.post(url, json={"plugin_ids": plugin_ids})
    response = requests.post(url, json={"plugin_ids": plugin_ids})
    response.raise_for_status()
    result: list[MarketplacePluginDeclaration] = []
    for plugin in response.json()["data"]["plugins"]:
@ -50,5 +50,5 @@

def record_install_plugin_event(plugin_unique_identifier: str):
    url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
    response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
    response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
    response.raise_for_status()
@ -1,128 +0,0 @@
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol

from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter


class ProviderConfigCache(Protocol):
    """
    Interface for provider configuration cache operations
    """

    def get(self) -> Optional[dict]:
        """Get cached provider configuration"""
        ...

    def set(self, config: dict[str, Any]) -> None:
        """Cache provider configuration"""
        ...

    def delete(self) -> None:
        """Delete cached provider configuration"""
        ...


class ProviderConfigEncrypter:
    tenant_id: str
    config: list[BasicProviderConfig]
    provider_config_cache: ProviderConfigCache

    def __init__(
        self,
        tenant_id: str,
        config: list[BasicProviderConfig],
        provider_config_cache: ProviderConfigCache,
    ):
        self.tenant_id = tenant_id
        self.config = config
        self.provider_config_cache = provider_config_cache

    def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
        """
        deep copy data
        """
        return deepcopy(data)

    def encrypt(self, data: dict[str, str]) -> dict[str, str]:
        """
        encrypt tool credentials with tenant id

        return a deep copy of credentials with encrypted values
        """
        data = self._deep_copy(data)

        # get fields need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
                    data[field_name] = encrypted

        return data

    def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        mask credentials

        return a deep copy of credentials with masked values
        """
        data = self._deep_copy(data)

        # get fields need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    if len(data[field_name]) > 6:
                        data[field_name] = (
                            data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
                        )
                    else:
                        data[field_name] = "*" * len(data[field_name])

        return data

    def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
        return self.mask_credentials(data)

    def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
        """
        decrypt tool credentials with tenant id

        return a deep copy of credentials with decrypted values
        """
        cached_credentials = self.provider_config_cache.get()
        if cached_credentials:
            return cached_credentials

        data = self._deep_copy(data)
        # get fields need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    with contextlib.suppress(Exception):
                        # if the value is None or empty string, skip decrypt
                        if not data[field_name]:
                            continue

                        data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])

        self.provider_config_cache.set(data)
        return data


def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
    return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
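A worked example of the masking rule in `mask_credentials` above: values longer than 6 characters keep the first and last two characters and star out the rest (the token below is made up):

secret = "sk-abcdef123456"  # 15 characters, hypothetical
masked = secret[:2] + "*" * (len(secret) - 4) + secret[-2:]
assert masked == "sk***********56"  # 2 + 11 stars + 2 = 15 characters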
@ -158,6 +158,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        return cast(
            Union[LLMResult, Generator],
            self._round_robin_invoke(
@ -186,6 +188,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        return cast(
            int,
            self._round_robin_invoke(
@ -210,6 +214,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
        return cast(
            TextEmbeddingResult,
            self._round_robin_invoke(
@ -231,6 +237,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
        return cast(
            list[int],
            self._round_robin_invoke(
@ -261,6 +269,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, RerankModel):
            raise Exception("Model type instance is not RerankModel")

        self.model_type_instance = cast(RerankModel, self.model_type_instance)
        return cast(
            RerankResult,
            self._round_robin_invoke(
@ -285,6 +295,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, ModerationModel):
            raise Exception("Model type instance is not ModerationModel")

        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
        return cast(
            bool,
            self._round_robin_invoke(
@ -306,6 +318,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, Speech2TextModel):
            raise Exception("Model type instance is not Speech2TextModel")

        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
        return cast(
            str,
            self._round_robin_invoke(
@ -329,6 +343,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        self.model_type_instance = cast(TTSModel, self.model_type_instance)
        return cast(
            Iterable[bytes],
            self._round_robin_invoke(
@ -388,6 +404,8 @@ class ModelInstance:
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        self.model_type_instance = cast(TTSModel, self.model_type_instance)
        return self.model_type_instance.get_tts_model_voices(
            model=self.model, credentials=self.credentials, language=language
        )
@ -13,7 +13,6 @@ from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity
from core.trigger.entities.entities import TriggerProviderEntity


class PluginInstallationSource(enum.StrEnum):
@ -63,7 +62,6 @@ class PluginCategory(enum.StrEnum):
    Model = "model"
    Extension = "extension"
    AgentStrategy = "agent-strategy"
    Trigger = "trigger"


class PluginDeclaration(BaseModel):
@ -71,7 +69,6 @@ class PluginDeclaration(BaseModel):
    tools: Optional[list[str]] = Field(default_factory=list[str])
    models: Optional[list[str]] = Field(default_factory=list[str])
    endpoints: Optional[list[str]] = Field(default_factory=list[str])
    triggers: Optional[list[str]] = Field(default_factory=list[str])

    class Meta(BaseModel):
        minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@ -92,7 +89,6 @@ class PluginDeclaration(BaseModel):
    repo: Optional[str] = Field(default=None)
    verified: bool = Field(default=False)
    tool: Optional[ToolProviderEntity] = None
    trigger: Optional[TriggerProviderEntity] = None
    model: Optional[ProviderEntity] = None
    endpoint: Optional[EndpointProviderDeclaration] = None
    agent_strategy: Optional[AgentStrategyProviderEntity] = None
@ -108,8 +104,6 @@ class PluginDeclaration(BaseModel):
            values["category"] = PluginCategory.Model
        elif values.get("agent_strategy"):
            values["category"] = PluginCategory.AgentStrategy
        elif values.get("trigger"):
            values["category"] = PluginCategory.Trigger
        else:
            values["category"] = PluginCategory.Extension
        return values
@ -190,10 +184,6 @@ class ToolProviderID(GenericProviderID):
        self.plugin_name = f"{self.provider_name}_tool"


class TriggerProviderID(GenericProviderID):
    pass


class PluginDependency(BaseModel):
    class Type(enum.StrEnum):
        Github = PluginInstallationSource.Github.value
@ -1,4 +1,3 @@
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@ -14,7 +13,6 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity

T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))

@ -198,48 +196,3 @@ class PluginListResponse(BaseModel):

class PluginDynamicSelectOptionsResponse(BaseModel):
    options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")


class PluginTriggerProviderEntity(BaseModel):
    provider: str
    plugin_unique_identifier: str
    plugin_id: str
    declaration: TriggerProviderEntity


class CredentialType(enum.StrEnum):
    API_KEY = "api-key"
    OAUTH2 = "oauth2"
    UNAUTHORIZED = "unauthorized"

    def get_name(self):
        if self == CredentialType.API_KEY:
            return "API KEY"
        elif self == CredentialType.OAUTH2:
            return "AUTH"
        elif self == CredentialType.UNAUTHORIZED:
            return "UNAUTHORIZED"
        else:
            return self.value.replace("-", " ").upper()

    def is_editable(self):
        return self == CredentialType.API_KEY

    def is_validate_allowed(self):
        return self == CredentialType.API_KEY

    @classmethod
    def values(cls):
        return [item.value for item in cls]

    @classmethod
    def of(cls, credential_type: str) -> "CredentialType":
        type_name = credential_type.lower()
        if type_name == "api-key":
            return cls.API_KEY
        elif type_name == "oauth2":
            return cls.OAUTH2
        elif type_name == "unauthorized":
            return cls.UNAUTHORIZED
        else:
            raise ValueError(f"Invalid credential type: {credential_type}")
@ -1,7 +1,5 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional

from flask import Response
from pydantic import BaseModel, ConfigDict, Field, field_validator

from core.entities.provider_entities import BasicProviderConfig
@ -239,34 +237,3 @@ class RequestFetchAppInfo(BaseModel):
    """

    app_id: str


class Event(BaseModel):
    variables: Mapping[str, Any]


class TriggerInvokeResponse(BaseModel):
    event: Event
    cancelled: Optional[bool] = False


class PluginTriggerDispatchResponse(BaseModel):
    triggers: list[str]
    raw_http_response: str


class TriggerSubscriptionResponse(BaseModel):
    subscription: dict[str, Any]


class TriggerValidateProviderCredentialsResponse(BaseModel):
    result: bool


class TriggerDispatchResponse:
    triggers: list[str]
    response: Response

    def __init__(self, triggers: list[str], response: Response):
        self.triggers = triggers
        self.response = response
@ -15,7 +15,6 @@ class DynamicSelectClient(BasePluginClient):
        provider: str,
        action: str,
        credentials: Mapping[str, Any],
        credential_type: str,
        parameter: str,
    ) -> PluginDynamicSelectOptionsResponse:
        """
@ -30,7 +29,6 @@ class DynamicSelectClient(BasePluginClient):
                "data": {
                    "provider": GenericProviderID(provider).provider_name,
                    "credentials": credentials,
                    "credential_type": credential_type,
                    "provider_action": action,
                    "parameter": parameter,
                },
@ -4,10 +4,10 @@ from typing import Any, Optional
from pydantic import BaseModel

from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter


class PluginToolManager(BasePluginClient):
@ -1,301 +0,0 @@
import binascii
from collections.abc import Mapping
from typing import Any

from flask import Request

from core.plugin.entities.plugin import GenericProviderID, TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import (
    PluginTriggerDispatchResponse,
    TriggerDispatchResponse,
    TriggerInvokeResponse,
    TriggerSubscriptionResponse,
    TriggerValidateProviderCredentialsResponse,
)
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.http_parser import deserialize_response, serialize_request
from core.trigger.entities.entities import Subscription


class PluginTriggerManager(BasePluginClient):
    def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
        """
        Fetch trigger providers for the given tenant.
        """

        def transformer(json_response: dict[str, Any]) -> dict:
            for provider in json_response.get("data", []):
                declaration = provider.get("declaration", {}) or {}
                provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
                for trigger in declaration.get("triggers", []):
                    trigger["identity"]["provider"] = provider_id

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/triggers",
            list[PluginTriggerProviderEntity],
            params={"page": 1, "page_size": 256},
            transformer=transformer,
        )

        for provider in response:
            provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

            # override the provider name for each trigger to plugin_id/provider_name
            for trigger in provider.declaration.triggers:
                trigger.identity.provider = provider.declaration.identity.name

        return response

    def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
        """
        Fetch trigger provider for the given tenant and plugin.
        """

        def transformer(json_response: dict[str, Any]) -> dict:
            data = json_response.get("data")
            if data:
                for trigger in data.get("declaration", {}).get("triggers", []):
                    trigger["identity"]["provider"] = str(provider_id)

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/trigger",
            PluginTriggerProviderEntity,
            params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
            transformer=transformer,
        )

        response.declaration.identity.name = str(provider_id)

        # override the provider name for each trigger to plugin_id/provider_name
        for trigger in response.declaration.triggers:
            trigger.identity.provider = str(provider_id)

        return response

    def invoke_trigger(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        trigger: str,
        credentials: Mapping[str, str],
        credential_type: CredentialType,
        request: Request,
        parameters: Mapping[str, Any],
    ) -> TriggerInvokeResponse:
        """
        Invoke a trigger with the given parameters.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/invoke",
            TriggerInvokeResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "trigger": trigger,
                    "credentials": credentials,
                    "credential_type": credential_type,
                    "raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
                    "parameters": parameters,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return TriggerInvokeResponse(event=resp.event)

        raise ValueError("No response received from plugin daemon for invoke trigger")

    def validate_provider_credentials(
        self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
    ) -> bool:
        """
        Validate the credentials of the trigger provider.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
            TriggerValidateProviderCredentialsResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp.result

        raise ValueError("No response received from plugin daemon for validate provider credentials")

    def dispatch_event(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        subscription: Mapping[str, Any],
        request: Request,
    ) -> TriggerDispatchResponse:
        """
        Dispatch an event to triggers.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
            PluginTriggerDispatchResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "subscription": subscription,
                    "raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return TriggerDispatchResponse(
                triggers=resp.triggers,
                response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())),
            )

        raise ValueError("No response received from plugin daemon for dispatch event")

    def subscribe(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        credentials: Mapping[str, str],
        endpoint: str,
        parameters: Mapping[str, Any],
    ) -> TriggerSubscriptionResponse:
        """
        Subscribe to a trigger.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/subscribe",
            TriggerSubscriptionResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "credentials": credentials,
                    "endpoint": endpoint,
                    "parameters": parameters,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp

        raise ValueError("No response received from plugin daemon for subscribe")

    def unsubscribe(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        subscription: Subscription,
        credentials: Mapping[str, str],
    ) -> TriggerSubscriptionResponse:
        """
        Unsubscribe from a trigger.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
            TriggerSubscriptionResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "subscription": subscription.model_dump(),
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp

        raise ValueError("No response received from plugin daemon for unsubscribe")

    def refresh(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        subscription: Subscription,
        credentials: Mapping[str, str],
    ) -> TriggerSubscriptionResponse:
        """
        Refresh a trigger subscription.
        """
        trigger_provider_id = GenericProviderID(provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/trigger/refresh",
            TriggerSubscriptionResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": trigger_provider_id.provider_name,
                    "subscription": subscription.model_dump(),
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": trigger_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp

        raise ValueError("No response received from plugin daemon for refresh")
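The `raw_http_request` field above carries the raw HTTP/1.1 byte stream hex-encoded so it can be embedded in a JSON body; a minimal sketch of that round trip with an assumed payload:

import binascii

raw = b"POST /hook HTTP/1.1\r\nHost: example.com\r\n\r\n{}"  # assumed request bytes
encoded = binascii.hexlify(raw).decode()  # plain hex text, JSON-safe
assert binascii.unhexlify(encoded.encode()) == raw  # the daemon side reverses it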
@ -1,159 +0,0 @@
from io import BytesIO

from flask import Request, Response
from werkzeug.datastructures import Headers


def serialize_request(request: Request) -> bytes:
    method = request.method
    path = request.full_path.rstrip("?")
    raw = f"{method} {path} HTTP/1.1\r\n".encode()

    for name, value in request.headers.items():
        raw += f"{name}: {value}\r\n".encode()

    raw += b"\r\n"

    body = request.get_data(as_text=False)
    if body:
        raw += body

    return raw


def deserialize_request(raw_data: bytes) -> Request:
    header_end = raw_data.find(b"\r\n\r\n")
    if header_end == -1:
        header_end = raw_data.find(b"\n\n")
        if header_end == -1:
            header_data = raw_data
            body = b""
        else:
            header_data = raw_data[:header_end]
            body = raw_data[header_end + 2 :]
    else:
        header_data = raw_data[:header_end]
        body = raw_data[header_end + 4 :]

    lines = header_data.split(b"\r\n")
    if len(lines) == 1 and b"\n" in lines[0]:
        lines = header_data.split(b"\n")

    if not lines or not lines[0]:
        raise ValueError("Empty HTTP request")

    request_line = lines[0].decode("utf-8", errors="ignore")
    parts = request_line.split(" ", 2)
    if len(parts) < 2:
        raise ValueError(f"Invalid request line: {request_line}")

    method = parts[0]
    full_path = parts[1]
    protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"

    if "?" in full_path:
        path, query_string = full_path.split("?", 1)
    else:
        path = full_path
        query_string = ""

    headers = Headers()
    for line in lines[1:]:
        if not line:
            continue
        line_str = line.decode("utf-8", errors="ignore")
        if ":" not in line_str:
            continue
        name, value = line_str.split(":", 1)
        headers.add(name, value.strip())

    host = headers.get("Host", "localhost")
    if ":" in host:
        server_name, server_port = host.rsplit(":", 1)
    else:
        server_name = host
        server_port = "80"

    environ = {
        "REQUEST_METHOD": method,
        "PATH_INFO": path,
        "QUERY_STRING": query_string,
        "SERVER_NAME": server_name,
        "SERVER_PORT": server_port,
        "SERVER_PROTOCOL": protocol,
        "wsgi.input": BytesIO(body),
        "wsgi.url_scheme": "http",
    }

    if "Content-Type" in headers:
        environ["CONTENT_TYPE"] = headers.get("Content-Type")

    if "Content-Length" in headers:
        environ["CONTENT_LENGTH"] = headers.get("Content-Length")
    elif body:
        environ["CONTENT_LENGTH"] = str(len(body))

    for name, value in headers.items():
        if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
            continue
        env_name = f"HTTP_{name.upper().replace('-', '_')}"
        environ[env_name] = value

    return Request(environ)


def serialize_response(response: Response) -> bytes:
    raw = f"HTTP/1.1 {response.status}\r\n".encode()

    for name, value in response.headers.items():
        raw += f"{name}: {value}\r\n".encode()

    raw += b"\r\n"

    body = response.get_data(as_text=False)
    if body:
        raw += body

    return raw


def deserialize_response(raw_data: bytes) -> Response:
    header_end = raw_data.find(b"\r\n\r\n")
    if header_end == -1:
        header_end = raw_data.find(b"\n\n")
        if header_end == -1:
            header_data = raw_data
            body = b""
        else:
            header_data = raw_data[:header_end]
            body = raw_data[header_end + 2 :]
    else:
        header_data = raw_data[:header_end]
        body = raw_data[header_end + 4 :]

    lines = header_data.split(b"\r\n")
    if len(lines) == 1 and b"\n" in lines[0]:
        lines = header_data.split(b"\n")

    if not lines or not lines[0]:
        raise ValueError("Empty HTTP response")

    status_line = lines[0].decode("utf-8", errors="ignore")
    parts = status_line.split(" ", 2)
    if len(parts) < 2:
        raise ValueError(f"Invalid status line: {status_line}")

    status_code = int(parts[1])

    response = Response(response=body, status=status_code)

    for line in lines[1:]:
        if not line:
            continue
        line_str = line.decode("utf-8", errors="ignore")
        if ":" not in line_str:
            continue
        name, value = line_str.split(":", 1)
        response.headers[name] = value.strip()

    return response
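A short usage sketch for the module above: `deserialize_request` rebuilds a working Flask `Request` from raw bytes by synthesizing a WSGI environ, so the usual accessors work on the result (input below is illustrative):

raw = b"GET /ping?x=1 HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"
request = deserialize_request(raw)
assert request.method == "GET"
assert request.path == "/ping"
assert request.args["x"] == "1"  # query string parsed from the request line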
@ -87,6 +87,7 @@ class PromptMessageUtil:
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@ -154,8 +154,8 @@ class ProviderManager:
        for provider_entity in provider_entities:
            # handle include, exclude
            if is_filtered(
                include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
                exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
                include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
                exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
                data=provider_entity,
                name_func=lambda x: x.provider,
            ):
@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI:
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
            if match.score >= score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
            if match.score >= score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
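The recurring `>` to `>=` change in these vector-store hunks makes the score threshold inclusive: a match scoring exactly at the threshold is now kept rather than dropped. A tiny illustration of the boundary:

score_threshold = 0.5
scores = [0.49, 0.5, 0.51]
assert [s for s in scores if s > score_threshold] == [0.51]        # old behavior
assert [s for s in scores if s >= score_threshold] == [0.5, 0.51]  # new behavior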
@ -229,7 +229,7 @@ class AnalyticdbVectorBySql:
        documents = []
        for record in cur:
            id, vector, score, page_content, metadata = record
            if score > score_threshold:
            if score >= score_threshold:
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,

@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
            if meta is not None:
                meta = json.loads(meta)
            score = row.get("score", 0.0)
            if score > score_threshold:
            if score >= score_threshold:
                meta["score"] = score
                doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
                docs.append(doc)

@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
            distance = distances[index]
            metadata = dict(metadatas[index])
            score = 1 - distance
            if score > score_threshold:
            if score >= score_threshold:
                metadata["score"] = score
                doc = Document(
                    page_content=documents[index],

@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
        docs = []
        for doc, score in docs_and_scores:
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
            if score > score_threshold:
            if score >= score_threshold:
                if doc.metadata is not None:
                    doc.metadata["score"] = score
                    docs.append(doc)

@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
        docs = []
        for doc, score in docs_and_scores:
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
            if score > score_threshold:
            if score >= score_threshold:
                if doc.metadata is not None:
                    doc.metadata["score"] = score
                    docs.append(doc)

@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
        docs = []
        for doc, score in docs_and_scores:
            score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
            if score > score_threshold:
            if score >= score_threshold:
                if doc.metadata is not None:
                    doc.metadata["score"] = score
                    docs.append(doc)

@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
            metadata, text, distance = record
            score = 1 - distance
            metadata["score"] = score
            if score > score_threshold:
            if score >= score_threshold:
                docs.append(Document(page_content=text, metadata=metadata))
        return docs

@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):

            metadata["score"] = hit["_score"]
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
            if hit["_score"] > score_threshold:
            if hit["_score"] >= score_threshold:
                doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                docs.append(doc)

@ -261,7 +261,7 @@ class OracleVector(BaseVector):
            metadata, text, distance = record
            score = 1 - distance
            metadata["score"] = score
            if score > score_threshold:
            if score >= score_threshold:
                docs.append(Document(page_content=text, metadata=metadata))
        conn.close()
        return docs

@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
            score = 1 - dis
            metadata["score"] = score
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
            if score > score_threshold:
            if score >= score_threshold:
                doc = Document(page_content=record.text, metadata=metadata)
                docs.append(doc)
        return docs

@ -195,7 +195,7 @@ class PGVector(BaseVector):
            metadata, text, distance = record
            score = 1 - distance
            metadata["score"] = score
            if score > score_threshold:
            if score >= score_threshold:
                docs.append(Document(page_content=text, metadata=metadata))
        return docs
341 api/core/rag/datasource/vdb/pinecone/pinecone_vector.py Normal file
@ -0,0 +1,341 @@
import hashlib
import json
import re
import time
from typing import Any, Optional

from pinecone import Pinecone, ServerlessSpec
from pydantic import BaseModel

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding


class PineconeConfig(BaseModel):
    """Pinecone configuration class"""

    api_key: str
    environment: str
    index_name: Optional[str] = None
    timeout: float = 30
    batch_size: int = 100
    metric: str = "cosine"


class PineconeVector(BaseVector):
    """Pinecone vector database concrete implementation class"""

    def __init__(self, collection_name: str, group_id: str, config: PineconeConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._group_id = group_id

        # Initialize Pinecone client with SSL configuration
        try:
            self._pc = Pinecone(
                api_key=config.api_key,
                # Configure SSL to handle connection issues
                ssl_ca_certs=None,  # Use system default CA certificates
            )
        except Exception:
            # Fall back to basic initialization if SSL config fails
            self._pc = Pinecone(api_key=config.api_key)

        # Normalize index name: lowercase, only a-z0-9- and <=45 chars
        base_name = collection_name.lower()
        base_name = re.sub(r"[^a-z0-9-]+", "-", base_name)  # replace invalid chars with '-'
        base_name = re.sub(r"-+", "-", base_name).strip("-")
        # Use a longer suffix to reduce collision risk
        suffix_len = 24  # 24 hex digits (96-bit entropy)
        if len(base_name) > 45:
            hash_suffix = hashlib.sha256(base_name.encode()).hexdigest()[:suffix_len]
            truncated_name = base_name[: 45 - (suffix_len + 1)].rstrip("-")
            self._index_name = f"{truncated_name}-{hash_suffix}"
        else:
            self._index_name = base_name
        # Guard against an empty name
        if not self._index_name:
            self._index_name = f"index-{hashlib.sha256(collection_name.encode()).hexdigest()[:suffix_len]}"
        self._index = None

    def get_type(self) -> str:
        """Return vector database type identifier"""
        return "pinecone"

    def _ensure_index_initialized(self) -> None:
        """Ensure that self._index is attached to an existing Pinecone index."""
        if self._index is not None:
            return
        existing_indexes = self._pc.list_indexes().names()
        if self._index_name in existing_indexes:
            self._index = self._pc.Index(self._index_name)
        else:
            raise ValueError("Index not initialized. Please ingest documents to create index.")

    def to_index_struct(self) -> dict:
        """Generate index structure dictionary"""
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self._collection_name},
        }

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create vector index"""
        if texts:
            # Get vector dimension
            vector_size = len(embeddings[0])

            # Create Pinecone index
            self.create_index(vector_size)

            # Add vector data
            self.add_texts(texts, embeddings, **kwargs)

    def create_index(self, dimension: int):
        """Create Pinecone index"""
        lock_name = f"vector_indexing_lock_{self._index_name}"

        with redis_client.lock(lock_name, timeout=30):
            # Check Redis cache
            index_exist_cache_key = f"vector_indexing_{self._index_name}"
            if redis_client.get(index_exist_cache_key):
                self._index = self._pc.Index(self._index_name)
                return

            # Check if index already exists
            existing_indexes = self._pc.list_indexes().names()

            if self._index_name not in existing_indexes:
                # Create new index using ServerlessSpec
                self._pc.create_index(
                    name=self._index_name,
                    dimension=dimension,
                    metric=self._client_config.metric,
                    spec=ServerlessSpec(cloud="aws", region=self._client_config.environment),
                )

                # Wait for index creation to complete
                while not self._pc.describe_index(self._index_name).status["ready"]:
                    time.sleep(1)

            # Attach the index instance (covers both newly created and existing indexes)
            self._index = self._pc.Index(self._index_name)

            # Set cache
            redis_client.set(index_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Batch add document vectors"""
        if not self._index:
            raise ValueError("Index not initialized. Call create() first.")

        uuids = self._get_uuids(documents)
        batch_size = self._client_config.batch_size
        added_ids = []

        # Batch processing
        for i in range(0, len(documents), batch_size):
            batch_documents = documents[i : i + batch_size]
            batch_embeddings = embeddings[i : i + batch_size]
            batch_uuids = uuids[i : i + batch_size]

            # Build Pinecone vector data (metadata must be primitives or list[str])
            vectors_to_upsert = []
            for doc, embedding, doc_id in zip(batch_documents, batch_embeddings, batch_uuids):
                raw_meta = doc.metadata or {}
                safe_meta: dict[str, Any] = {}
                # lift common identifiers to top-level fields for filtering
                for k, v in raw_meta.items():
                    if isinstance(v, (str, int, float, bool)):
                        safe_meta[k] = v
                    elif isinstance(v, list) and all(isinstance(x, str) for x in v):
                        safe_meta[k] = v
                    else:
                        safe_meta[k] = json.dumps(v, ensure_ascii=False)

                # keep content as string metadata
                safe_meta[Field.CONTENT_KEY.value] = doc.page_content
                # group id as string
                safe_meta[Field.GROUP_KEY.value] = str(self._group_id)

                vectors_to_upsert.append(
                    {
                        "id": doc_id,
                        "values": embedding,
                        "metadata": safe_meta,
                    }
                )

            # Batch insert to Pinecone
            self._index.upsert(vectors=vectors_to_upsert)
            added_ids.extend(batch_uuids)

        return added_ids

    def search_by_vector(self, query_vector: list[float], **kwargs) -> list[Document]:
        """Vector similarity search"""
        # Lazily attach to an existing index if needed
        self._ensure_index_initialized()

        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold", 0.0))

        # Build filter conditions
        filter_dict = {Field.GROUP_KEY.value: {"$eq": str(self._group_id)}}

        # Document scope filtering
        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            filter_dict["document_id"] = {"$in": document_ids_filter}

        # Execute search
        response = self._index.query(
            vector=query_vector,
            top_k=top_k,
            include_metadata=True,
            filter=filter_dict,
        )

        # Convert results
        docs = []
        filtered_count = 0
        for match in response.matches:
            if match.score >= score_threshold:
                page_content = match.metadata.get(Field.CONTENT_KEY.value, "")
                metadata = dict(match.metadata or {})
                metadata.pop(Field.CONTENT_KEY.value, None)
                metadata.pop(Field.GROUP_KEY.value, None)
                metadata["score"] = match.score

                doc = Document(page_content=page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
else:
|
||||
filtered_count += 1
|
||||
|
||||
# Sort by similarity score in descending order
|
||||
docs.sort(key=lambda x: x.metadata.get("score", 0), reverse=True)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs) -> list[Document]:
|
||||
"""Full-text search - Pinecone does not natively support it, returns empty list"""
|
||||
return []
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""Delete by metadata field"""
|
||||
self._ensure_index_initialized()
|
||||
|
||||
try:
|
||||
# Build filter conditions
|
||||
filter_dict = {
|
||||
Field.GROUP_KEY.value: {"$eq": self._group_id},
|
||||
f"{Field.METADATA_KEY.value}.{key}": {"$eq": value}
|
||||
}
|
||||
|
||||
# Pinecone delete operation
|
||||
self._index.delete(filter=filter_dict)
|
||||
except Exception as e:
|
||||
# Ignore delete errors
|
||||
pass
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
"""Batch delete by ID list"""
|
||||
self._ensure_index_initialized()
|
||||
|
||||
try:
|
||||
# Pinecone delete by ID
|
||||
self._index.delete(ids=ids)
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete all vector data for the entire dataset"""
|
||||
self._ensure_index_initialized()
|
||||
|
||||
try:
|
||||
# Delete all vectors by group_id
|
||||
filter_dict = {Field.GROUP_KEY.value: {"$eq": self._group_id}}
|
||||
self._index.delete(filter=filter_dict)
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""Check if document exists"""
|
||||
try:
|
||||
self._ensure_index_initialized()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Check if vector exists through query
|
||||
response = self._index.fetch(ids=[id])
|
||||
exists = id in response.vectors
|
||||
return exists
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
|
||||
class PineconeVectorFactory(AbstractVectorFactory):
|
||||
"""Pinecone vector database factory class"""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PineconeVector:
|
||||
"""Create PineconeVector instance"""
|
||||
|
||||
# Determine index name
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError("Dataset Collection Bindings does not exist!")
|
||||
else:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
# Set index structure
|
||||
if not dataset.index_struct_dict:
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict("pinecone", collection_name)
|
||||
)
|
||||
|
||||
# Create PineconeVector instance
|
||||
return PineconeVector(
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=PineconeConfig(
|
||||
api_key=dify_config.PINECONE_API_KEY or "",
|
||||
environment=dify_config.PINECONE_ENVIRONMENT or "",
|
||||
index_name=dify_config.PINECONE_INDEX_NAME,
|
||||
timeout=dify_config.PINECONE_CLIENT_TIMEOUT,
|
||||
batch_size=dify_config.PINECONE_BATCH_SIZE,
|
||||
metric=dify_config.PINECONE_METRIC,
|
||||
),
|
||||
)
|
||||
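
A minimal usage sketch for the adapter above. This is illustrative only: the API key, region, group id, and the 1536-dimension embeddings are placeholders, and in Dify the factory wires the config from `dify_config` rather than by hand.

    # Hypothetical round-trip, assuming a valid serverless API key and region.
    config = PineconeConfig(
        api_key="...", environment="us-east-1", index_name="dify-index",
        timeout=30, batch_size=100, metric="cosine",
    )
    vector = PineconeVector(collection_name="Vector_index_demo", group_id="dataset-id", config=config)
    docs = [Document(page_content="hello pinecone", metadata={"document_id": "doc-1"})]
    vector.create(docs, embeddings=[[0.1] * 1536])  # creates the index, then upserts the batch
    hits = vector.search_by_vector([0.1] * 1536, top_k=4, score_threshold=0.5)
    vector.delete_by_metadata_field("document_id", "doc-1")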
@@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
            metadata, text, distance = record
            score = 1 - distance
            metadata["score"] = score
-           if score > score_threshold:
+           if score >= score_threshold:
                docs.append(Document(page_content=text, metadata=metadata))
        return docs
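
The repeated `>` to `>=` changes in this compare all make the score threshold inclusive, so a hit whose score exactly equals the configured threshold is now kept instead of silently dropped. A standalone illustration of the boundary behaviour (not from the patch):

    matches = [0.95, 0.80, 0.62]
    score_threshold = 0.80
    kept_exclusive = [s for s in matches if s > score_threshold]   # [0.95]; the boundary hit is dropped
    kept_inclusive = [s for s in matches if s >= score_threshold]  # [0.95, 0.8]; the boundary hit is kept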
@@ -3,7 +3,7 @@ import os
 import uuid
 from collections.abc import Generator, Iterable, Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union, cast

 import qdrant_client
 from flask import current_app
@@ -369,7 +369,7 @@ class QdrantVector(BaseVector):
                continue
            metadata = result.payload.get(Field.METADATA_KEY.value) or {}
            # duplicate check score threshold
-           if result.score > score_threshold:
+           if result.score >= score_threshold:
                metadata["score"] = result.score
                doc = Document(
                    page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
@@ -426,6 +426,7 @@ class QdrantVector(BaseVector):

    def _reload_if_needed(self):
        if isinstance(self._client, QdrantLocal):
+           self._client = cast(QdrantLocal, self._client)
            self._client._load()

    @classmethod

@@ -233,7 +233,7 @@ class RelytVector(BaseVector):
        docs = []
        for document, score in results:
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
-           if 1 - score > score_threshold:
+           if 1 - score >= score_threshold:
                docs.append(document)
        return docs

@@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
        )
        documents = []
        for search_hit in search_response.search_hits:
-           if search_hit.score > score_threshold:
+           if search_hit.score >= score_threshold:
                ots_column_map = {}
                for col in search_hit.row[1]:
                    ots_column_map[col[0]] = col[1]

@@ -291,7 +291,7 @@ class TencentVector(BaseVector):
                score = 1 - result.get("score", 0.0)
            else:
                score = result.get("score", 0.0)
-           if score > score_threshold:
+           if score >= score_threshold:
                meta["score"] = score
                doc = Document(page_content=result.get(self.field_text), metadata=meta)
                docs.append(doc)

@@ -351,7 +351,7 @@ class TidbOnQdrantVector(BaseVector):
            metadata = result.payload.get(Field.METADATA_KEY.value) or {}
            # duplicate check score threshold
            score_threshold = kwargs.get("score_threshold") or 0.0
-           if result.score > score_threshold:
+           if result.score >= score_threshold:
                metadata["score"] = result.score
                doc = Document(
                    page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

@@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
            score = record.score
            if metadata is not None and text is not None:
                metadata["score"] = score
-               if score > score_threshold:
+               if score >= score_threshold:
                    docs.append(Document(page_content=text, metadata=metadata))
        return docs

@@ -86,6 +86,10 @@ class Vector:
                from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory

                return PGVectoRSFactory
+           case VectorType.PINECONE:
+               from core.rag.datasource.vdb.pinecone.pinecone_vector import PineconeVectorFactory
+
+               return PineconeVectorFactory
            case VectorType.QDRANT:
                from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory

@@ -31,3 +31,4 @@ class VectorType(StrEnum):
    HUAWEI_CLOUD = "huawei_cloud"
    MATRIXONE = "matrixone"
    CLICKZETTA = "clickzetta"
+   PINECONE = "pinecone"
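
Because `VectorType` is a `StrEnum`, the `VECTOR_STORE=pinecone` setting maps onto the new member by plain value lookup, which is what routes into the `case VectorType.PINECONE:` branch above. A standalone illustration (enum trimmed to the relevant member):

    from enum import StrEnum

    class VectorType(StrEnum):
        PINECONE = "pinecone"

    assert VectorType("pinecone") is VectorType.PINECONE
    assert VectorType.PINECONE == "pinecone"  # StrEnum members compare equal to their value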
@@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
            metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
            if metadata is not None:
                metadata = json.loads(metadata)
-           if result.score > score_threshold:
+           if result.score >= score_threshold:
                metadata["score"] = result.score
                doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
                docs.append(doc)

@@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
        for doc, score in docs_and_scores:
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
            # check score threshold
-           if score > score_threshold:
+           if score >= score_threshold:
                if doc.metadata is not None:
                    doc.metadata["score"] = score
                    docs.append(doc)

@@ -10,6 +10,23 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document


+def _format_cell_value(value) -> str:
+    if pd.isna(value):
+        return ""
+
+    if isinstance(value, (int, float)):
+        if isinstance(value, float):
+            if value.is_integer():
+                return str(int(value))
+            else:
+                formatted = f"{value:f}"
+                return formatted.rstrip('0').rstrip('.')
+        else:
+            return str(value)
+
+    return str(value)
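
The new helper normalizes pandas cell values before they are serialized into page content: integral floats lose the trailing `.0`, and other floats are rendered without scientific notation. Expected behaviour, shown as illustrative calls:

    _format_cell_value(3.0)           # "3", an integral float collapses to an int string
    _format_cell_value(0.00012)       # "0.00012", since f"{value:f}" avoids scientific notation
    _format_cell_value(42)            # "42"
    _format_cell_value(float("nan"))  # "", pd.isna() catches NaN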
class ExcelExtractor(BaseExtractor):
    """Load Excel files.

@@ -49,10 +66,12 @@ class ExcelExtractor(BaseExtractor):
                        row=cast(int, index) + 2, column=col_index + 1
                    )  # +2 to account for header and 1-based index
                    if cell.hyperlink:
-                       value = f"[{v}]({cell.hyperlink.target})"
+                       formatted_v = _format_cell_value(v)
+                       value = f"[{formatted_v}]({cell.hyperlink.target})"
                        page_content.append(f'"{k}":"{value}"')
                    else:
-                       page_content.append(f'"{k}":"{v}"')
+                       formatted_v = _format_cell_value(v)
+                       page_content.append(f'"{k}":"{formatted_v}"')
            documents.append(
                Document(page_content=";".join(page_content), metadata={"source": self._file_path})
            )
@@ -67,7 +86,8 @@ class ExcelExtractor(BaseExtractor):
                page_content = []
                for k, v in row.items():
                    if pd.notna(v):
-                       page_content.append(f'"{k}":"{v}"')
+                       formatted_v = _format_cell_value(v)
+                       page_content.append(f'"{k}":"{formatted_v}"')
                documents.append(
                    Document(page_content=";".join(page_content), metadata={"source": self._file_path})
                )

@@ -2,7 +2,7 @@

 import re
 from pathlib import Path
-from typing import Optional
+from typing import Optional, cast

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.extractor.helpers import detect_file_encodings
@@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
            markdown_tups.append((current_header, current_text))

        markdown_tups = [
-           (re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
+           (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
            for key, value in markdown_tups
        ]

@@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor):
                f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
            )

-       return data_source_binding.access_token
+       return cast(str, data_source_binding.access_token)

@@ -2,7 +2,7 @@

 import contextlib
 from collections.abc import Iterator
-from typing import Optional
+from typing import Optional, cast

 from core.rag.extractor.blob.blob import Blob
 from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
        plaintext_file_exists = False
        if self._file_cache_key:
            with contextlib.suppress(FileNotFoundError):
-               text = storage.load(self._file_cache_key).decode("utf-8")
+               text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
                plaintext_file_exists = True
                return [Document(page_content=text)]
        documents = list(self.load())

@@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
        for result in results:
            metadata = result.metadata
            metadata["score"] = result.score
-           if result.score > score_threshold:
+           if result.score >= score_threshold:
                doc = Document(page_content=result.page_content, metadata=metadata)
                docs.append(doc)
        return docs

@@ -162,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
        for result in results:
            metadata = result.metadata
            metadata["score"] = result.score
-           if result.score > score_threshold:
+           if result.score >= score_threshold:
                doc = Document(page_content=result.page_content, metadata=metadata)
                docs.append(doc)
        return docs

@@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor):
        for result in results:
            metadata = result.metadata
            metadata["score"] = result.score
-           if result.score > score_threshold:
+           if result.score >= score_threshold:
                doc = Document(page_content=result.page_content, metadata=metadata)
                docs.append(doc)
        return docs

@@ -4,8 +4,7 @@ from openai import BaseModel
 from pydantic import Field

 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.plugin.entities.plugin_daemon import CredentialType
-from core.tools.entities.tool_entities import ToolInvokeFrom
+from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom


 class ToolRuntime(BaseModel):

@@ -4,11 +4,11 @@ from typing import Any

 from core.entities.provider_entities import ProviderConfig
 from core.helper.module_import_helper import load_single_subclass_from_source
-from core.plugin.entities.plugin_daemon import CredentialType
 from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.tool_entities import (
+    CredentialType,
     OAuthSchema,
     ToolEntity,
     ToolProviderEntity,

@@ -4,10 +4,9 @@ from typing import Any, Literal, Optional
 from pydantic import BaseModel, Field, field_validator

 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin_daemon import CredentialType
 from core.tools.__base.tool import ToolParameter
 from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.entities.tool_entities import CredentialType, ToolProviderType


 class ToolApiEntity(BaseModel):

@@ -476,3 +476,36 @@ class ToolSelector(BaseModel):

    def to_plugin_parameter(self) -> dict[str, Any]:
        return self.model_dump()
+
+
+class CredentialType(enum.StrEnum):
+    API_KEY = "api-key"
+    OAUTH2 = "oauth2"
+
+    def get_name(self):
+        if self == CredentialType.API_KEY:
+            return "API KEY"
+        elif self == CredentialType.OAUTH2:
+            return "AUTH"
+        else:
+            return self.value.replace("-", " ").upper()
+
+    def is_editable(self):
+        return self == CredentialType.API_KEY
+
+    def is_validate_allowed(self):
+        return self == CredentialType.API_KEY
+
+    @classmethod
+    def values(cls):
+        return [item.value for item in cls]
+
+    @classmethod
+    def of(cls, credential_type: str) -> "CredentialType":
+        type_name = credential_type.lower()
+        if type_name == "api-key":
+            return cls.API_KEY
+        elif type_name == "oauth2":
+            return cls.OAUTH2
+        else:
+            raise ValueError(f"Invalid credential type: {credential_type}")
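
Behaviour of the relocated enum in isolation (illustrative calls, derived from the definition above):

    CredentialType.of("API-KEY")          # CredentialType.API_KEY, lookup is case-insensitive
    CredentialType.API_KEY.get_name()     # "API KEY"
    CredentialType.OAUTH2.is_editable()   # False, only api-key credentials are user-editable
    CredentialType.values()               # ["api-key", "oauth2"]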
@@ -37,7 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.helper.module_import_helper import load_single_subclass_from_source
 from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin_daemon import CredentialType
 from core.tools.__base.tool import Tool
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
 from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@@ -48,6 +47,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
+    CredentialType,
     ToolInvokeFrom,
     ToolParameter,
     ToolProviderType,
@@ -331,13 +331,16 @@ class ToolManager:
            if controller_tools is None or len(controller_tools) == 0:
                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

-           return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
-               runtime=ToolRuntime(
-                   tenant_id=tenant_id,
-                   credentials={},
-                   invoke_from=invoke_from,
-                   tool_invoke_from=tool_invoke_from,
-               )
+           return cast(
+               WorkflowTool,
+               controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
+                   runtime=ToolRuntime(
+                       tenant_id=tenant_id,
+                       credentials={},
+                       invoke_from=invoke_from,
+                       tool_invoke_from=tool_invoke_from,
+                   )
+               ),
            )
        elif provider_type == ToolProviderType.APP:
            raise NotImplementedError("app provider not implemented")
@@ -645,8 +648,8 @@ class ToolManager:
        for provider in builtin_providers:
            # handle include, exclude
            if is_filtered(
-               include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
-               exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
+               include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
+               exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
                data=provider,
                name_func=lambda x: x.identity.name,
            ):

@@ -1,23 +1,132 @@
-# Import generic components from provider_encryption module
-from core.helper.provider_encryption import (
-    ProviderConfigCache,
-    ProviderConfigEncrypter,
-    create_provider_encrypter,
-)
+import contextlib
+from copy import deepcopy
+from typing import Any, Optional, Protocol

 # Re-export for backward compatibility
 __all__ = [
     "ProviderConfigCache",
     "ProviderConfigEncrypter",
     "create_provider_encrypter",
     "create_tool_provider_encrypter",
 ]

 # Tool-specific imports
 from core.entities.provider_entities import BasicProviderConfig
 from core.helper import encrypter
 from core.helper.provider_cache import SingletonProviderCredentialsCache
 from core.tools.__base.tool_provider import ToolProviderController


+class ProviderConfigCache(Protocol):
+    """
+    Interface for provider configuration cache operations
+    """
+
+    def get(self) -> Optional[dict]:
+        """Get cached provider configuration"""
+        ...
+
+    def set(self, config: dict[str, Any]) -> None:
+        """Cache provider configuration"""
+        ...
+
+    def delete(self) -> None:
+        """Delete cached provider configuration"""
+        ...
+
+
+class ProviderConfigEncrypter:
+    tenant_id: str
+    config: list[BasicProviderConfig]
+    provider_config_cache: ProviderConfigCache
+
+    def __init__(
+        self,
+        tenant_id: str,
+        config: list[BasicProviderConfig],
+        provider_config_cache: ProviderConfigCache,
+    ):
+        self.tenant_id = tenant_id
+        self.config = config
+        self.provider_config_cache = provider_config_cache
+
+    def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
+        """
+        deep copy data
+        """
+        return deepcopy(data)
+
+    def encrypt(self, data: dict[str, str]) -> dict[str, str]:
+        """
+        encrypt tool credentials with tenant id
+
+        return a deep copy of credentials with encrypted values
+        """
+        data = self._deep_copy(data)
+
+        # get fields that need to be encrypted
+        fields = dict[str, BasicProviderConfig]()
+        for credential in self.config:
+            fields[credential.name] = credential
+
+        for field_name, field in fields.items():
+            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+                if field_name in data:
+                    encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
+                    data[field_name] = encrypted
+
+        return data
+
+    def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
+        """
+        mask tool credentials
+
+        return a deep copy of credentials with masked values
+        """
+        data = self._deep_copy(data)
+
+        # get fields that need to be masked
+        fields = dict[str, BasicProviderConfig]()
+        for credential in self.config:
+            fields[credential.name] = credential
+
+        for field_name, field in fields.items():
+            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+                if field_name in data:
+                    if len(data[field_name]) > 6:
+                        data[field_name] = (
+                            data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
+                        )
+                    else:
+                        data[field_name] = "*" * len(data[field_name])
+
+        return data
+
+    def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
+        """
+        decrypt tool credentials with tenant id
+
+        return a deep copy of credentials with decrypted values
+        """
+        cached_credentials = self.provider_config_cache.get()
+        if cached_credentials:
+            return cached_credentials
+
+        data = self._deep_copy(data)
+        # get fields that need to be decrypted
+        fields = dict[str, BasicProviderConfig]()
+        for credential in self.config:
+            fields[credential.name] = credential
+
+        for field_name, field in fields.items():
+            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+                if field_name in data:
+                    with contextlib.suppress(Exception):
+                        # if the value is None or an empty string, skip decryption
+                        if not data[field_name]:
+                            continue
+
+                        data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
+
+        self.provider_config_cache.set(data)
+        return data
+
+
+def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
+    return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
+
+
 def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
     cache = SingletonProviderCredentialsCache(
         tenant_id=tenant_id,
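
A sketch of the encrypt/mask/decrypt round-trip, using a throwaway in-memory cache that satisfies the `ProviderConfigCache` protocol. This is an assumption-laden illustration: the `BasicProviderConfig` constructor shape and the `api_key` field name are hypothetical, and Dify's `encrypter` token helpers are tenant-scoped, so this would only run inside an app context.

    class MemoryCache:  # hypothetical stand-in for a real provider cache
        def __init__(self):
            self._data = None
        def get(self):
            return self._data
        def set(self, config):
            self._data = config
        def delete(self):
            self._data = None

    enc, cache = create_provider_encrypter(
        tenant_id="tenant-1",
        config=[BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key")],  # assumed shape
        cache=MemoryCache(),
    )
    stored = enc.encrypt({"api_key": "sk-abcdef123456"})                # value encrypted with the tenant key
    masked = enc.mask_tool_credentials({"api_key": "sk-abcdef123456"})  # "sk" + "*" * 11 + "56"
    plain = enc.decrypt(stored)                                         # decrypted copy, then cached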
@@ -3,7 +3,7 @@ from collections.abc import Generator
 from datetime import date, datetime
 from decimal import Decimal
 from mimetypes import guess_extension
-from typing import Optional
+from typing import Optional, cast
 from uuid import UUID

 import numpy as np
@@ -159,7 +159,8 @@ class ToolFileMessageTransformer:

        elif message.type == ToolInvokeMessage.MessageType.JSON:
            if isinstance(message.message, ToolInvokeMessage.JsonMessage):
-               message.message.json_object = safe_json_value(message.message.json_object)
+               json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
+               json_msg.json_object = safe_json_value(json_msg.json_object)
                yield message
            else:
                yield message

@@ -129,14 +129,17 @@ class ModelInvocationUtils:
        db.session.commit()

        try:
-           response: LLMResult = model_instance.invoke_llm(
-               prompt_messages=prompt_messages,
-               model_parameters=model_parameters,
-               tools=[],
-               stop=[],
-               stream=False,
-               user=user_id,
-               callbacks=[],
+           response: LLMResult = cast(
+               LLMResult,
+               model_instance.invoke_llm(
+                   prompt_messages=prompt_messages,
+                   model_parameters=model_parameters,
+                   tools=[],
+                   stop=[],
+                   stream=False,
+                   user=user_id,
+                   callbacks=[],
+               ),
            )
        except InvokeRateLimitError as e:
            raise InvokeModelError(f"Invoke rate limit error: {e}")

@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from typing import Any, Optional
+from typing import Any, Optional, cast

 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.__base.tool import Tool
@@ -204,14 +204,14 @@ class WorkflowTool(Tool):
                    item = self._update_file_mapping(item)
                    file = build_from_mapping(
                        mapping=item,
-                       tenant_id=str(self.runtime.tenant_id),
+                       tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
                    )
                    files.append(file)
                elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
                    value = self._update_file_mapping(value)
                    file = build_from_mapping(
                        mapping=value,
-                       tenant_id=str(self.runtime.tenant_id),
+                       tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
                    )
                    files.append(file)
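
The recurring `typing.cast(...)` edits in the hunks above only narrow types for the static checker; `cast` returns its argument unchanged at runtime, so none of these hunks alter behaviour. A standalone illustration:

    from typing import cast

    value: object = b"payload"
    narrowed = cast(bytes, value)  # no runtime check or conversion happens here
    assert narrowed is value       # same object; only the static type changed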
@ -1 +0,0 @@
|
||||
# Core trigger module initialization
|
||||
@ -1,76 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.trigger.entities.entities import (
|
||||
SubscriptionSchema,
|
||||
TriggerCreationMethod,
|
||||
TriggerDescription,
|
||||
TriggerIdentity,
|
||||
TriggerParameter,
|
||||
)
|
||||
|
||||
|
||||
class TriggerProviderSubscriptionApiEntity(BaseModel):
|
||||
id: str = Field(description="The unique id of the subscription")
|
||||
name: str = Field(description="The name of the subscription")
|
||||
provider: str = Field(description="The provider id of the subscription")
|
||||
credential_type: CredentialType = Field(description="The type of the credential")
|
||||
credentials: dict = Field(description="The credentials of the subscription")
|
||||
endpoint: str = Field(description="The endpoint of the subscription")
|
||||
parameters: dict = Field(description="The parameters of the subscription")
|
||||
properties: dict = Field(description="The properties of the subscription")
|
||||
workflows_in_use: int = Field(description="The number of workflows using this subscription")
|
||||
|
||||
|
||||
class TriggerApiEntity(BaseModel):
|
||||
name: str = Field(description="The name of the trigger")
|
||||
identity: TriggerIdentity = Field(description="The identity of the trigger")
|
||||
description: TriggerDescription = Field(description="The description of the trigger")
|
||||
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
|
||||
|
||||
|
||||
class TriggerProviderApiEntity(BaseModel):
|
||||
author: str = Field(..., description="The author of the trigger provider")
|
||||
name: str = Field(..., description="The name of the trigger provider")
|
||||
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
|
||||
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
|
||||
supported_creation_methods: list[TriggerCreationMethod] = Field(
|
||||
default_factory=list,
|
||||
description="Supported creation methods for the trigger provider. like 'OAUTH', 'APIKEY', 'MANUAL'.",
|
||||
)
|
||||
|
||||
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
|
||||
oauth_client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth client"
|
||||
)
|
||||
subscription_schema: Optional[SubscriptionSchema] = Field(
|
||||
description="The subscription schema of the trigger provider"
|
||||
)
|
||||
triggers: list[TriggerApiEntity] = Field(description="The triggers of the trigger provider")
|
||||
|
||||
|
||||
class SubscriptionBuilderApiEntity(BaseModel):
|
||||
id: str = Field(description="The id of the subscription builder")
|
||||
name: str = Field(description="The name of the subscription builder")
|
||||
provider: str = Field(description="The provider id of the subscription builder")
|
||||
endpoint: str = Field(description="The endpoint id of the subscription builder")
|
||||
parameters: Mapping[str, Any] = Field(description="The parameters of the subscription builder")
|
||||
properties: Mapping[str, Any] = Field(description="The properties of the subscription builder")
|
||||
credentials: Mapping[str, str] = Field(description="The credentials of the subscription builder")
|
||||
credential_type: CredentialType = Field(description="The credential type of the subscription builder")
|
||||
|
||||
|
||||
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]
|
||||
@ -1,304 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import PluginParameterAutoGenerate, PluginParameterOption, PluginParameterTemplate
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class TriggerParameterType(StrEnum):
|
||||
"""The type of the parameter"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
APP_SELECTOR = "app-selector"
|
||||
OBJECT = "object"
|
||||
ARRAY = "array"
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
|
||||
class TriggerParameter(BaseModel):
|
||||
"""
|
||||
The parameter of the trigger
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
type: TriggerParameterType = Field(..., description="The type of the parameter")
|
||||
auto_generate: Optional[PluginParameterAutoGenerate] = Field(
|
||||
default=None, description="The auto generate of the parameter"
|
||||
)
|
||||
template: Optional[PluginParameterTemplate] = Field(default=None, description="The template of the parameter")
|
||||
scope: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
multiple: bool | None = Field(
|
||||
default=False,
|
||||
description="Whether the parameter is multiple select, only valid for select or dynamic-select type",
|
||||
)
|
||||
default: Union[int, float, str, list, None] = None
|
||||
min: Union[float, int, None] = None
|
||||
max: Union[float, int, None] = None
|
||||
precision: Optional[int] = None
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
description: Optional[I18nObject] = None
|
||||
|
||||
|
||||
class TriggerProviderIdentity(BaseModel):
|
||||
"""
|
||||
The identity of the trigger provider
|
||||
"""
|
||||
|
||||
author: str = Field(..., description="The author of the trigger provider")
|
||||
name: str = Field(..., description="The name of the trigger provider")
|
||||
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
|
||||
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||
|
||||
|
||||
class TriggerIdentity(BaseModel):
|
||||
"""
|
||||
The identity of the trigger
|
||||
"""
|
||||
|
||||
author: str = Field(..., description="The author of the trigger")
|
||||
name: str = Field(..., description="The name of the trigger")
|
||||
label: I18nObject = Field(..., description="The label of the trigger")
|
||||
provider: Optional[str] = Field(default=None, description="The provider of the trigger")
|
||||
|
||||
|
||||
class TriggerDescription(BaseModel):
|
||||
"""
|
||||
The description of the trigger
|
||||
"""
|
||||
|
||||
human: I18nObject = Field(..., description="Human readable description")
|
||||
llm: I18nObject = Field(..., description="LLM readable description")
|
||||
|
||||
|
||||
class TriggerEntity(BaseModel):
|
||||
"""
|
||||
The configuration of a trigger
|
||||
"""
|
||||
|
||||
identity: TriggerIdentity = Field(..., description="The identity of the trigger")
|
||||
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
|
||||
description: TriggerDescription = Field(..., description="The description of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(
|
||||
default=None, description="The output schema that this trigger produces"
|
||||
)
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionSchema(BaseModel):
|
||||
"""
|
||||
The subscription schema of the trigger provider
|
||||
"""
|
||||
|
||||
parameters_schema: list[TriggerParameter] | None = Field(
|
||||
default_factory=list,
|
||||
description="The parameters schema required to create a subscription",
|
||||
)
|
||||
|
||||
properties_schema: list[ProviderConfig] | None = Field(
|
||||
default_factory=list,
|
||||
description="The configuration schema stored in the subscription entity",
|
||||
)
|
||||
|
||||
def get_default_parameters(self) -> Mapping[str, Any]:
|
||||
"""Get the default parameters from the parameters schema"""
|
||||
if not self.parameters_schema:
|
||||
return {}
|
||||
return {param.name: param.default for param in self.parameters_schema if param.default}
|
||||
|
||||
def get_default_properties(self) -> Mapping[str, Any]:
|
||||
"""Get the default properties from the properties schema"""
|
||||
if not self.properties_schema:
|
||||
return {}
|
||||
return {prop.name: prop.default for prop in self.properties_schema if prop.default}
|
||||
|
||||
|
||||
class TriggerProviderEntity(BaseModel):
|
||||
"""
|
||||
The configuration of a trigger provider
|
||||
"""
|
||||
|
||||
identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="The credentials schema of the trigger provider",
|
||||
)
|
||||
oauth_schema: Optional[OAuthSchema] = Field(
|
||||
default=None,
|
||||
description="The OAuth schema of the trigger provider if OAuth is supported",
|
||||
)
|
||||
subscription_schema: SubscriptionSchema = Field(
|
||||
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
|
||||
)
|
||||
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
|
||||
|
||||
|
||||
class Subscription(BaseModel):
|
||||
"""
|
||||
Result of a successful trigger subscription operation.
|
||||
|
||||
Contains all information needed to manage the subscription lifecycle.
|
||||
"""
|
||||
|
||||
expires_at: int = Field(
|
||||
..., description="The timestamp when the subscription will expire, this for refresh the subscription"
|
||||
)
|
||||
|
||||
endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
|
||||
properties: Mapping[str, Any] = Field(
|
||||
..., description="Subscription data containing all properties and provider-specific information"
|
||||
)
|
||||
|
||||
|
||||
class Unsubscription(BaseModel):
|
||||
"""
|
||||
Result of a trigger unsubscription operation.
|
||||
|
||||
Provides detailed information about the unsubscription attempt,
|
||||
including success status and error details if failed.
|
||||
"""
|
||||
|
||||
success: bool = Field(..., description="Whether the unsubscription was successful")
|
||||
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="Human-readable message about the operation result. "
|
||||
"Success message for successful operations, "
|
||||
"detailed error information for failures.",
|
||||
)
|
||||
|
||||
|
||||
class RequestLog(BaseModel):
|
||||
id: str = Field(..., description="The id of the request log")
|
||||
endpoint: str = Field(..., description="The endpoint of the request log")
|
||||
request: dict = Field(..., description="The request of the request log")
|
||||
response: dict = Field(..., description="The response of the request log")
|
||||
created_at: datetime = Field(..., description="The created at of the request log")
|
||||
|
||||
|
||||
class SubscriptionBuilder(BaseModel):
|
||||
id: str = Field(..., description="The id of the subscription builder")
|
||||
name: str | None = Field(default=None, description="The name of the subscription builder")
|
||||
tenant_id: str = Field(..., description="The tenant id of the subscription builder")
|
||||
user_id: str = Field(..., description="The user id of the subscription builder")
|
||||
provider_id: str = Field(..., description="The provider id of the subscription builder")
|
||||
endpoint_id: str = Field(..., description="The endpoint id of the subscription builder")
|
||||
parameters: Mapping[str, Any] = Field(..., description="The parameters of the subscription builder")
|
||||
properties: Mapping[str, Any] = Field(..., description="The properties of the subscription builder")
|
||||
credentials: Mapping[str, str] = Field(..., description="The credentials of the subscription builder")
|
||||
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
|
||||
credential_expires_at: int | None = Field(
|
||||
default=None, description="The credential expires at of the subscription builder"
|
||||
)
|
||||
expires_at: int = Field(..., description="The expires at of the subscription builder")
|
||||
|
||||
def to_subscription(self) -> Subscription:
|
||||
return Subscription(
|
||||
expires_at=self.expires_at,
|
||||
endpoint=self.endpoint_id,
|
||||
properties=self.properties,
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionBuilderUpdater(BaseModel):
|
||||
name: str | None = Field(default=None, description="The name of the subscription builder")
|
||||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
|
||||
credentials: Mapping[str, str] | None = Field(
|
||||
default=None, description="The credentials of the subscription builder"
|
||||
)
|
||||
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
|
||||
credential_expires_at: int | None = Field(
|
||||
default=None, description="The credential expires at of the subscription builder"
|
||||
)
|
||||
expires_at: int | None = Field(default=None, description="The expires at of the subscription builder")
|
||||
|
||||
def update(self, subscription_builder: SubscriptionBuilder) -> None:
|
||||
if self.name:
|
||||
subscription_builder.name = self.name
|
||||
if self.parameters:
|
||||
subscription_builder.parameters = self.parameters
|
||||
if self.properties:
|
||||
subscription_builder.properties = self.properties
|
||||
if self.credentials:
|
||||
subscription_builder.credentials = self.credentials
|
||||
if self.credential_type:
|
||||
subscription_builder.credential_type = self.credential_type
|
||||
if self.credential_expires_at:
|
||||
subscription_builder.credential_expires_at = self.credential_expires_at
|
||||
if self.expires_at:
|
||||
subscription_builder.expires_at = self.expires_at
|
||||
|
||||
|
||||
class TriggerEventData(BaseModel):
|
||||
"""Event data dispatched to trigger sessions."""
|
||||
|
||||
subscription_id: str
|
||||
triggers: list[str]
|
||||
request_id: str
|
||||
timestamp: float
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class TriggerInputs(BaseModel):
|
||||
"""Standard inputs for trigger nodes."""
|
||||
|
||||
request_id: str
|
||||
trigger_name: str
|
||||
subscription_id: str
|
||||
|
||||
def to_workflow_args(self) -> dict[str, Any]:
|
||||
"""Convert to workflow arguments format."""
|
||||
return {"inputs": self.model_dump(), "files": []}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dict (alias for model_dump)."""
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class TriggerCreationMethod(StrEnum):
|
||||
OAUTH = "OAUTH"
|
||||
APIKEY = "APIKEY"
|
||||
MANUAL = "MANUAL"
|
||||
|
||||
|
||||
# Export all entities
|
||||
__all__ = [
|
||||
"OAuthSchema",
|
||||
"RequestLog",
|
||||
"Subscription",
|
||||
"SubscriptionBuilder",
|
||||
"TriggerCreationMethod",
|
||||
"TriggerDescription",
|
||||
"TriggerEntity",
|
||||
"TriggerEventData",
|
||||
"TriggerIdentity",
|
||||
"TriggerInputs",
|
||||
"TriggerParameter",
|
||||
"TriggerParameterType",
|
||||
"TriggerProviderEntity",
|
||||
"TriggerProviderIdentity",
|
||||
"Unsubscription",
|
||||
]
|
||||
@ -1,8 +0,0 @@
|
||||
class TriggerProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
class TriggerInvokeError(Exception):
|
||||
pass
|
||||
|
||||
class TriggerIgnoreEventError(TriggerInvokeError):
|
||||
pass
|
||||
@ -1,358 +0,0 @@
|
||||
"""
|
||||
Trigger Provider Controller for managing trigger providers
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Request
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import (
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeResponse,
|
||||
)
|
||||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProviderApiEntity
|
||||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
SubscriptionSchema,
|
||||
TriggerCreationMethod,
|
||||
TriggerEntity,
|
||||
TriggerProviderEntity,
|
||||
TriggerProviderIdentity,
|
||||
Unsubscription,
|
||||
)
|
||||
from core.trigger.errors import TriggerProviderCredentialValidationError
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginTriggerProviderController:
|
||||
"""
|
||||
Controller for plugin trigger providers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: TriggerProviderEntity,
|
||||
plugin_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
provider_id: TriggerProviderID,
|
||||
tenant_id: str,
|
||||
):
|
||||
"""
|
||||
Initialize plugin trigger provider controller
|
||||
|
||||
:param entity: Trigger provider entity
|
||||
:param plugin_id: Plugin ID
|
||||
:param plugin_unique_identifier: Plugin unique identifier
|
||||
:param provider_id: Provider ID
|
||||
:param tenant_id: Tenant ID
|
||||
"""
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.provider_id = provider_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_provider_id(self) -> TriggerProviderID:
|
||||
"""
|
||||
Get provider ID
|
||||
"""
|
||||
return self.provider_id
|
||||
|
||||
def to_api_entity(self) -> TriggerProviderApiEntity:
|
||||
"""
|
||||
Convert to API entity
|
||||
"""
|
||||
icon = (
|
||||
PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon)
|
||||
if self.entity.identity.icon
|
||||
else None
|
||||
)
|
||||
icon_dark = (
|
||||
PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon_dark)
|
||||
if self.entity.identity.icon_dark
|
||||
else None
|
||||
)
|
||||
supported_creation_methods = []
|
||||
if self.entity.oauth_schema:
|
||||
supported_creation_methods.append(TriggerCreationMethod.OAUTH)
|
||||
if self.entity.credentials_schema:
|
||||
supported_creation_methods.append(TriggerCreationMethod.APIKEY)
|
||||
if self.entity.subscription_schema:
|
||||
supported_creation_methods.append(TriggerCreationMethod.MANUAL)
|
||||
return TriggerProviderApiEntity(
|
||||
author=self.entity.identity.author,
|
||||
name=self.entity.identity.name,
|
||||
label=self.entity.identity.label,
|
||||
description=self.entity.identity.description,
|
||||
icon=icon,
|
||||
icon_dark=icon_dark,
|
||||
tags=self.entity.identity.tags,
|
||||
plugin_id=self.plugin_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
credentials_schema=self.entity.credentials_schema,
|
||||
oauth_client_schema=self.entity.oauth_schema.client_schema if self.entity.oauth_schema else [],
|
||||
subscription_schema=self.entity.subscription_schema,
|
||||
supported_creation_methods=supported_creation_methods,
|
||||
triggers=[
|
||||
TriggerApiEntity(
|
||||
name=trigger.identity.name,
|
||||
identity=trigger.identity,
|
||||
description=trigger.description,
|
||||
parameters=trigger.parameters,
|
||||
output_schema=trigger.output_schema,
|
||||
)
|
||||
for trigger in self.entity.triggers
|
||||
],
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self) -> TriggerProviderIdentity:
|
||||
"""Get provider identity"""
|
||||
return self.entity.identity
|
||||
|
||||
def get_triggers(self) -> list[TriggerEntity]:
|
||||
"""
|
||||
Get all triggers for this provider
|
||||
|
||||
:return: List of trigger entities
|
||||
"""
|
||||
return self.entity.triggers
|
||||
|
||||
def get_trigger(self, trigger_name: str) -> Optional[TriggerEntity]:
|
||||
"""
|
||||
Get a specific trigger by name
|
||||
|
||||
:param trigger_name: Trigger name
|
||||
:return: Trigger entity or None
|
||||
"""
|
||||
for trigger in self.entity.triggers:
|
||||
if trigger.identity.name == trigger_name:
|
||||
return trigger
|
||||
return None
|
||||
|
||||
def get_subscription_schema(self) -> SubscriptionSchema:
|
||||
"""
|
||||
Get subscription schema for this provider
|
||||
|
||||
:return: List of subscription config schemas
|
||||
"""
|
||||
return self.entity.subscription_schema
|
||||
|
||||
def validate_credentials(self, user_id: str, credentials: Mapping[str, str]) -> None:
|
||||
"""
|
||||
Validate credentials against schema
|
||||
|
||||
:param credentials: Credentials to validate
|
||||
:return: Validation response
|
||||
"""
|
||||
# First validate against schema
|
||||
for config in self.entity.credentials_schema:
|
||||
if config.required and config.name not in credentials:
|
||||
raise TriggerProviderCredentialValidationError(f"Missing required credential field: {config.name}")
|
||||
|
||||
# Then validate with the plugin daemon
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
response = manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
credentials=credentials,
|
||||
)
|
||||
if not response:
|
||||
raise TriggerProviderCredentialValidationError(
|
||||
"Invalid credentials",
|
||||
)
|
||||
|
||||
def get_supported_credential_types(self) -> list[CredentialType]:
|
||||
"""
|
||||
Get supported credential types for this provider.
|
||||
|
||||
:return: List of supported credential types
|
||||
"""
|
||||
types = []
|
||||
if self.entity.oauth_schema:
|
||||
types.append(CredentialType.OAUTH2)
|
||||
if self.entity.credentials_schema:
|
||||
types.append(CredentialType.API_KEY)
|
||||
return types
|
||||
|
||||
def get_credentials_schema(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get credentials schema by credential type
|
||||
|
||||
:param credential_type: The type of credential (oauth or api_key)
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
credential_type = CredentialType.of(credential_type) if isinstance(credential_type, str) else credential_type
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
|
||||
"""
|
||||
Get credential schema config by credential type
|
||||
"""
|
||||
return [x.to_basic_provider_config() for x in self.get_credentials_schema(credential_type)]
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get OAuth client schema for this provider
|
||||
|
||||
:return: List of OAuth client config schemas
|
||||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
def get_properties_schema(self) -> list[BasicProviderConfig]:
|
||||
"""
|
||||
Get properties schema for this provider
|
||||
|
||||
:return: List of properties config schemas
|
||||
"""
|
||||
return (
|
||||
[x.to_basic_provider_config() for x in self.entity.subscription_schema.properties_schema.copy()]
|
||||
if self.entity.subscription_schema.properties_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch a trigger through plugin runtime
|
||||
|
||||
:param user_id: User ID
|
||||
:param request: Flask request object
|
||||
:param subscription: Subscription
|
||||
:return: Dispatch response with triggers and raw HTTP response
|
||||
"""
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
|
||||
response = manager.dispatch_event(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
subscription=subscription.model_dump(),
|
||||
request=request,
|
||||
)
|
||||
return response
|
||||
|
||||
def invoke_trigger(
|
||||
self,
|
||||
user_id: str,
|
||||
trigger_name: str,
|
||||
parameters: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
) -> TriggerInvokeResponse:
|
||||
"""
|
||||
Execute a trigger through plugin runtime
|
||||
|
||||
:param user_id: User ID
|
||||
:param trigger_name: Trigger name
|
||||
:param parameters: Trigger parameters
|
||||
:param credentials: Provider credentials
|
||||
:param credential_type: Credential type
|
||||
:param request: Request
|
||||
:return: Trigger execution result
|
||||
"""
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
|
||||
return manager.invoke_trigger(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
trigger=trigger_name,
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
request=request,
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
    def subscribe_trigger(
        self, user_id: str, endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str]
    ) -> Subscription:
        """
        Subscribe to a trigger through plugin runtime

        :param user_id: User ID
        :param endpoint: Subscription endpoint
        :param parameters: Subscription parameters
        :param credentials: Provider credentials
        :return: Subscription result
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        response = manager.subscribe(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=str(provider_id),
            credentials=credentials,
            endpoint=endpoint,
            parameters=parameters,
        )

        return Subscription.model_validate(response.subscription)

    def unsubscribe_trigger(
        self, user_id: str, subscription: Subscription, credentials: Mapping[str, str]
    ) -> Unsubscription:
        """
        Unsubscribe from a trigger through plugin runtime

        :param user_id: User ID
        :param subscription: Subscription metadata
        :param credentials: Provider credentials
        :return: Unsubscription result
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        response = manager.unsubscribe(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=str(provider_id),
            subscription=subscription,
            credentials=credentials,
        )

        return Unsubscription.model_validate(response.subscription)

    def refresh_trigger(self, subscription: Subscription, credentials: Mapping[str, str]) -> Subscription:
        """
        Refresh a trigger subscription through plugin runtime

        :param subscription: Subscription metadata
        :param credentials: Provider credentials
        :return: Refreshed subscription result
        """
        manager = PluginTriggerManager()
        provider_id = self.get_provider_id()

        response = manager.refresh(
            tenant_id=self.tenant_id,
            user_id="system",  # System refresh
            provider=str(provider_id),
            subscription=subscription,
            credentials=credentials,
        )

        return Subscription.model_validate(response.subscription)


__all__ = ["PluginTriggerProviderController"]
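

# A hypothetical end-to-end sketch of the subscription lifecycle the controller
# exposes: register a webhook, renew it, then tear it down. The controller,
# identifiers, and credential values here are placeholders supplied by the
# caller; this is an illustration, not part of the module above.
def _demo_subscription_lifecycle(controller, user_id, endpoint, parameters, credentials):
    subscription = controller.subscribe_trigger(
        user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials
    )
    # Providers whose registrations expire can be renewed before expiry.
    subscription = controller.refresh_trigger(subscription, credentials)
    # Remove the registration when the trigger is deleted.
    return controller.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)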
@ -1,262 +0,0 @@
"""
Trigger Manager for loading and managing trigger providers and triggers
"""

import logging
from collections.abc import Mapping
from threading import Lock
from typing import Any, Optional

from flask import Request

import contexts
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import Event, TriggerInvokeResponse
from core.plugin.impl.exc import PluginInvokeError
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.entities import (
    Subscription,
    SubscriptionSchema,
    TriggerEntity,
    Unsubscription,
)
from core.trigger.provider import PluginTriggerProviderController

logger = logging.getLogger(__name__)


class TriggerManager:
    """
    Manager for trigger providers and triggers
    """

    @classmethod
    def list_plugin_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
        """
        List all plugin trigger providers for a tenant

        :param tenant_id: Tenant ID
        :return: List of trigger provider controllers
        """
        manager = PluginTriggerManager()
        provider_entities = manager.fetch_trigger_providers(tenant_id)

        controllers = []
        for provider in provider_entities:
            try:
                controller = PluginTriggerProviderController(
                    entity=provider.declaration,
                    plugin_id=provider.plugin_id,
                    plugin_unique_identifier=provider.plugin_unique_identifier,
                    provider_id=TriggerProviderID(provider.provider),
                    tenant_id=tenant_id,
                )
                controllers.append(controller)
            except Exception:
                logger.exception("Failed to load trigger provider %s", provider.plugin_id)
                continue

        return controllers

    @classmethod
    def get_trigger_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderController:
        """
        Get a specific plugin trigger provider

        :param tenant_id: Tenant ID
        :param provider_id: Provider ID
        :return: Trigger provider controller; raises ValueError if the provider is not found
"""
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.plugin_trigger_providers.get()
|
||||
except LookupError:
|
||||
contexts.plugin_trigger_providers.set({})
|
||||
contexts.plugin_trigger_providers_lock.set(Lock())
|
||||
|
||||
plugin_trigger_providers = contexts.plugin_trigger_providers.get()
|
||||
provider_id_str = str(provider_id)
|
||||
if provider_id_str in plugin_trigger_providers:
|
||||
return plugin_trigger_providers[provider_id_str]
|
||||
|
||||
with contexts.plugin_trigger_providers_lock.get():
|
||||
# double check
|
||||
plugin_trigger_providers = contexts.plugin_trigger_providers.get()
|
||||
if provider_id_str in plugin_trigger_providers:
|
||||
return plugin_trigger_providers[provider_id_str]
|
||||
|
||||
manager = PluginTriggerManager()
|
||||
provider = manager.fetch_trigger_provider(tenant_id, provider_id)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Trigger provider {provider_id} not found")
|
||||
|
||||
try:
|
||||
controller = PluginTriggerProviderController(
|
||||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
plugin_trigger_providers[provider_id_str] = controller
|
||||
return controller
|
||||
except Exception as e:
|
||||
logger.exception("Failed to load trigger provider")
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
|
||||
"""
|
||||
List all trigger providers (plugin)
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:return: List of all trigger provider controllers
|
||||
"""
|
||||
return cls.list_plugin_trigger_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def list_triggers_by_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerEntity]:
|
||||
"""
|
||||
List all triggers for a specific provider
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
:return: List of trigger entities
|
||||
"""
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
return provider.get_triggers()
|
||||
|
||||
@classmethod
|
||||
def get_trigger(cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str) -> Optional[TriggerEntity]:
|
||||
"""
|
||||
Get a specific trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:return: Trigger entity or None
|
||||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
|
||||
|
||||
@classmethod
|
||||
def invoke_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
parameters: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
) -> TriggerInvokeResponse:
|
||||
"""
|
||||
Execute a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param parameters: Trigger parameters
|
||||
:param credentials: Provider credentials
|
||||
:param credential_type: Credential type
|
||||
:param request: Request
|
||||
:return: Trigger execution result
|
||||
"""
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
trigger = provider.get_trigger(trigger_name)
|
||||
if not trigger:
|
||||
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
|
||||
try:
|
||||
return provider.invoke_trigger(user_id, trigger_name, parameters, credentials, credential_type, request)
|
||||
except PluginInvokeError as e:
|
||||
if e.get_error_type() == "TriggerIgnoreEventError":
|
||||
return TriggerInvokeResponse(event=Event(variables={}), cancelled=True)
|
||||
else:
|
||||
logger.exception("Failed to invoke trigger")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def subscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
endpoint: str,
|
||||
parameters: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
) -> Subscription:
|
||||
"""
|
||||
Subscribe to a trigger (e.g., register webhook)
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider_id: Provider ID
|
||||
:param endpoint: Subscription endpoint
|
||||
:param parameters: Subscription parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Subscription result
|
||||
"""
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
return provider.subscribe_trigger(
|
||||
user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unsubscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
) -> Unsubscription:
|
||||
"""
|
||||
Unsubscribe from a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider_id: Provider ID
|
||||
:param subscription: Subscription metadata from subscribe operation
|
||||
:param credentials: Provider credentials
|
||||
:return: Unsubscription result
|
||||
"""
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)
|
||||
|
||||
@classmethod
|
||||
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> SubscriptionSchema:
|
||||
"""
|
||||
Get provider subscription schema
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
        :return: Provider subscription schema
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
|
||||
|
||||
@classmethod
|
||||
def refresh_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
) -> Subscription:
|
||||
"""
|
||||
Refresh a trigger subscription
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
        :param subscription: Subscription metadata from subscribe operation
        :param credentials: Provider credentials
        :return: Refreshed subscription result
        """
        return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(subscription, credentials)


# Export
__all__ = ["TriggerManager"]
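

# A hypothetical usage sketch of the manager above; every argument is a
# placeholder supplied by the caller, and the Flask request would come from
# the webhook dispatch endpoint. invoke_trigger raises ValueError when the
# provider or trigger is unknown, so no None-handling is needed here.
def _demo_invoke_trigger(tenant_id, user_id, provider_id, trigger_name, credentials, credential_type, request):
    providers = TriggerManager.list_all_trigger_providers(tenant_id)
    assert any(str(p.get_provider_id()) == str(provider_id) for p in providers)
    return TriggerManager.invoke_trigger(
        tenant_id=tenant_id,
        user_id=user_id,
        provider_id=provider_id,
        trigger_name=trigger_name,
        parameters={},
        credentials=credentials,
        credential_type=credential_type,
        request=request,
    )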
@ -1,145 +0,0 @@
from collections.abc import Mapping
from typing import Union

from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
from core.helper.provider_cache import ProviderCredentialsCache
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
from core.trigger.provider import PluginTriggerProviderController
from models.trigger import TriggerSubscription


class TriggerProviderCredentialsCache(ProviderCredentialsCache):
    """Cache for trigger provider credentials"""

    def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
        super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)

    def _generate_cache_key(self, **kwargs) -> str:
        tenant_id = kwargs["tenant_id"]
        provider_id = kwargs["provider_id"]
        credential_id = kwargs["credential_id"]
        return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"


class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
    """Cache for trigger provider OAuth client"""

    def __init__(self, tenant_id: str, provider_id: str):
        super().__init__(tenant_id=tenant_id, provider_id=provider_id)

    def _generate_cache_key(self, **kwargs) -> str:
        tenant_id = kwargs["tenant_id"]
        provider_id = kwargs["provider_id"]
        return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"


class TriggerProviderPropertiesCache(ProviderCredentialsCache):
    """Cache for trigger provider properties"""

    def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
        super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)

    def _generate_cache_key(self, **kwargs) -> str:
        tenant_id = kwargs["tenant_id"]
        provider_id = kwargs["provider_id"]
        subscription_id = kwargs["subscription_id"]
        return f"trigger_properties:tenant_id:{tenant_id}:provider_id:{provider_id}:subscription_id:{subscription_id}"


def create_trigger_provider_encrypter_for_subscription(
    tenant_id: str,
    controller: PluginTriggerProviderController,
    subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
        credential_id=subscription.id,
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=controller.get_credential_schema_config(subscription.credential_type),
        cache=cache,
    )
    return encrypter, cache
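

# A hypothetical round-trip using the helper above. It assumes the returned
# encrypter exposes encrypt()/decrypt() over credential dicts (an assumption;
# only cache.delete() is shown in this file). The tenant, controller,
# subscription, and credential values are placeholders supplied by the caller.
def _demo_credentials_roundtrip(tenant_id, controller, subscription, raw_credentials):
    encrypter, cache = create_trigger_provider_encrypter_for_subscription(
        tenant_id=tenant_id,
        controller=controller,
        subscription=subscription,
    )
    stored = encrypter.encrypt(dict(raw_credentials))  # what would be persisted
    plaintext = encrypter.decrypt(stored)  # what a trigger invocation would use
    cache.delete()  # drop the cached plaintext, e.g. after credential rotation
    return stored, plaintext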


def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
    cache = TriggerProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_id=provider_id,
        credential_id=subscription_id,
    )
    cache.delete()


def create_trigger_provider_encrypter_for_properties(
    tenant_id: str,
    controller: PluginTriggerProviderController,
    subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderPropertiesCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
        subscription_id=subscription.id,
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=controller.get_properties_schema(),
        cache=cache,
    )
    return encrypter, cache


def create_trigger_provider_encrypter(
    tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
        credential_id=credential_id,
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=controller.get_credential_schema_config(credential_type),
        cache=cache,
    )
    return encrypter, cache


def create_trigger_provider_oauth_encrypter(
    tenant_id: str, controller: PluginTriggerProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = TriggerProviderOAuthClientParamsCache(
        tenant_id=tenant_id,
        provider_id=str(controller.get_provider_id()),
    )
    encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
        cache=cache,
    )
    return encrypter, cache


def masked_credentials(
    schemas: list[ProviderConfig],
    credentials: Mapping[str, str],
) -> Mapping[str, str]:
    masked_credentials = {}
    configs = {x.name: x.to_basic_provider_config() for x in schemas}
    for key, value in credentials.items():
        config = configs.get(key)
        if not config:
            masked_credentials[key] = value
            continue
        if config.type == BasicProviderConfig.Type.SECRET_INPUT:
            if len(value) <= 4:
                masked_credentials[key] = "*" * len(value)
            else:
                masked_credentials[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
        else:
            masked_credentials[key] = value
    return masked_credentials
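

# A minimal sketch of the masking rule above, inlined with plain strings so it
# runs standalone (the real function additionally consults the schema and only
# masks SECRET_INPUT fields). The secret value is made up: the first and last
# two characters survive, everything between becomes '*', and values of four
# characters or fewer are masked entirely.
value = "sk-abcdef"
masked = "*" * len(value) if len(value) <= 4 else value[:2] + "*" * (len(value) - 4) + value[-2:]
assert masked == "sk*****ef"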
@ -1,5 +0,0 @@
from configs import dify_config


def parse_endpoint_id(endpoint_id: str) -> str:
    return f"{dify_config.CONSOLE_API_URL}/triggers/plugin/{endpoint_id}"
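

# For illustration only: the same formatting with the base URL passed in
# explicitly, so the sketch runs without dify_config. The URL and endpoint id
# are hypothetical values.
def _demo_parse_endpoint_id(console_api_url: str, endpoint_id: str) -> str:
    return f"{console_api_url}/triggers/plugin/{endpoint_id}"


assert _demo_parse_endpoint_id("https://console.example.com", "abc123") == (
    "https://console.example.com/triggers/plugin/abc123"
)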
@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Annotated, TypeAlias
from typing import Annotated, TypeAlias, cast
from uuid import uuid4

from pydantic import Discriminator, Field, Tag
@ -86,7 +86,7 @@ class SecretVariable(StringVariable):

    @property
    def log(self) -> str:
        return encrypter.obfuscated_token(self.value)
        return cast(str, encrypter.obfuscated_token(self.value))


class NoneVariable(NoneSegment, Variable):
@ -25,7 +25,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
    TOTAL_PRICE = "total_price"
    CURRENCY = "currency"
    TOOL_INFO = "tool_info"
    TRIGGER_INFO = "trigger_info"
    AGENT_LOG = "agent_log"
    ITERATION_ID = "iteration_id"
    ITERATION_INDEX = "iteration_index"
@ -135,12 +135,12 @@ class Graph(BaseModel):

        # fetch root node
        if not root_node_id:
            # if no root node id, use any start node (START or trigger types) as root node
            # if no root node id, use the START type node as root node
            root_node_id = next(
                (
                    node_config.get("id")
                    for node_config in root_node_configs
                    if NodeType(node_config.get("data", {}).get("type", "")).is_start_node
                    if node_config.get("data", {}).get("type", "") == NodeType.START.value
                ),
                None,
            )
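
# A self-contained sketch of the narrowed selection above: only a node whose
# type is exactly "start" (NodeType.START's value) now qualifies as the root,
# where the old predicate also accepted trigger-style start nodes. The node
# ids and the trigger type string are made up for illustration.
node_configs = [
    {"id": "n1", "data": {"type": "trigger-webhook"}},
    {"id": "n2", "data": {"type": "start"}},
]
root = next(
    (c.get("id") for c in node_configs if c.get("data", {}).get("type", "") == "start"),
    None,
)
assert root == "n2"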
@ -374,7 +374,7 @@ class GraphEngine:
            if len(sub_edge_mappings) == 0:
                continue

            edge = sub_edge_mappings[0]
            edge = cast(GraphEdge, sub_edge_mappings[0])
            if edge.run_condition is None:
                logger.warning("Edge %s run condition is None", edge.target_node_id)
                continue
@ -153,7 +153,7 @@ class AgentNode(BaseNode):
                messages=message_stream,
                tool_info={
                    "icon": self.agent_strategy_icon,
                    "agent_strategy": self._node_data.agent_strategy_name,
                    "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
                },
                parameters_for_log=parameters_for_log,
                user_id=self.user_id,
@ -394,7 +394,8 @@ class AgentNode(BaseNode):
            current_plugin = next(
                plugin
                for plugin in plugins
                if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
                if f"{plugin.plugin_id}/{plugin.name}"
                == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
            )
            icon = current_plugin.declaration.icon
        except StopIteration:
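
# A generic, self-contained sketch of the next(...)/StopIteration pattern used
# above, with made-up plugin records: the first matching plugin wins, and a
# missing match raises StopIteration (handled by the except branch) instead of
# silently yielding None.
class _Plugin:
    def __init__(self, plugin_id: str, name: str):
        self.plugin_id = plugin_id
        self.name = name


plugins = [_Plugin("org/pkg", "strategy-a"), _Plugin("org/pkg", "strategy-b")]
try:
    current = next(p for p in plugins if f"{p.plugin_id}/{p.name}" == "org/pkg/strategy-b")
except StopIteration:
    current = None
assert current is plugins[1]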