Compare commits


1 Commit

Author SHA1 Message Date
b1859c280f fix: model-credential 2025-08-28 16:45:27 +08:00
771 changed files with 5648 additions and 38675 deletions

View File

@@ -1,19 +0,0 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@@ -1,6 +1,5 @@
#!/bin/bash
npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv

View File

@@ -1,7 +1,13 @@
name: Run Pytest
on:
workflow_call:
pull_request:
branches:
- main
paths:
- api/**
- docker/**
- .github/workflows/api-tests.yml
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}

View File

@@ -1,9 +1,10 @@
name: autofix.ci
on:
workflow_call:
pull_request:
branches: ["main"]
branches: [ "main" ]
push:
branches: ["main"]
branches: [ "main" ]
permissions:
contents: read
@@ -17,7 +18,7 @@ jobs:
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
python-version: "3.12"
- run: |
cd api
uv sync --dev
@@ -28,7 +29,6 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- name: mdformat
run: |
uvx mdformat .
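
For context, the ast-grep step above mechanically migrates legacy SQLAlchemy `Query.filter()` calls to the 2.0-style `.where()`. A minimal before/after sketch (illustrative only; `session` and `User` stand in for the codebase's real session and models):

```python
# Illustrative only: `session` is a SQLAlchemy Session, `User` an ORM model.

# Before: legacy Query.filter(), matched by the ast-grep pattern
active_users = session.query(User).filter(User.is_active.is_(True)).all()

# After: 2.0-style .where(), produced by the --rewrite rule
active_users = session.query(User).where(User.is_active.is_(True)).all()
```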

View File

@@ -1,7 +1,13 @@
name: DB Migration Test
on:
workflow_call:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/migrations/**
- .github/workflows/db-migration-test.yml
concurrency:
group: db-migration-test-${{ github.ref }}
@@ -27,12 +33,6 @@ jobs:
- name: Install dependencies
run: uv sync --project api
- name: Ensure offline migrations are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env
run: |

View File

@@ -1,78 +0,0 @@
name: Main CI Pipeline
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: write
pull-requests: write
checks: write
statuses: write
concurrency:
group: main-ci-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
# Check which paths were changed to determine which tests to run
check-changes:
name: Check Changed Files
runs-on: ubuntu-latest
outputs:
api-changed: ${{ steps.changes.outputs.api }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: changes
with:
filters: |
api:
- 'api/**'
- 'docker/**'
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
- '.github/workflows/vdb-tests.yml'
- 'api/uv.lock'
- 'api/pyproject.toml'
migration:
- 'api/migrations/**'
- '.github/workflows/db-migration-test.yml'
# Run tests in parallel
api-tests:
name: API Tests
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
style-check:
name: Style Check
uses: ./.github/workflows/style.yml
vdb-tests:
name: VDB Tests
needs: check-changes
if: needs.check-changes.outputs.vdb-changed == 'true'
uses: ./.github/workflows/vdb-tests.yml
db-migration-test:
name: DB Migration Test
needs: check-changes
if: needs.check-changes.outputs.migration-changed == 'true'
uses: ./.github/workflows/db-migration-test.yml

View File

@@ -1,7 +1,9 @@
name: Style check
on:
workflow_call:
pull_request:
branches:
- main
concurrency:
group: style-${{ github.head_ref || github.run_id }}
@@ -44,10 +46,21 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: |
uv run --directory api ruff --version
uv run --directory api ruff check ./
uv run --directory api ruff format --check ./
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
web-style:
name: Web Style
runs-on: ubuntu-latest
@@ -89,9 +102,7 @@ jobs:
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
pnpm run lint
pnpm run eslint
run: pnpm run lint
docker-compose-template:
name: Docker Compose Template

View File

@@ -1,7 +1,15 @@
name: Run VDB Tests
on:
workflow_call:
pull_request:
branches:
- main
paths:
- api/core/rag/datasource/**
- docker/**
- .github/workflows/vdb-tests.yml
- api/uv.lock
- api/pyproject.toml
concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}

View File

@@ -1,7 +1,11 @@
name: Web Tests
on:
workflow_call:
pull_request:
branches:
- main
paths:
- web/**
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}

3 .gitignore vendored
View File

@@ -218,6 +218,3 @@ mise.toml
.roo/
api/.env.backup
/clickzetta
# mcp
.serena

View File

@@ -1,34 +0,0 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@@ -59,7 +59,6 @@ pnpm test # Run Jest tests
- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
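
A minimal sketch of the logging rule above (`risky_operation` is a hypothetical stand-in):

```python
import logging

logger = logging.getLogger(__name__)


def risky_operation() -> None:
    # Hypothetical stand-in for real work that may raise
    raise RuntimeError("boom")


try:
    risky_operation()
except Exception as e:
    # Discouraged by the guideline: logger.exception("Task failed: %s", str(e))
    # Preferred: pass the exception object explicitly
    logger.exception("Task failed", exc_info=e)
```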
### TypeScript/JavaScript
@@ -87,4 +86,3 @@ pnpm test # Run Jest tests
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.

View File

@@ -434,9 +434,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True
# Webhook request configuration
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
@@ -505,12 +502,6 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
# Interval time in minutes for polling scheduled workflows (default: 1 min)
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
# Position configuration
POSITION_TOOL_PINS=

View File

@@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion,workflow"
"dataset,generation,mail,ops_trace,app_deletion"
]
}
]

View File

@@ -1207,55 +1207,6 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
"""
Setup system trigger oauth client
"""
from core.plugin.entities.plugin import TriggerProviderID
from models.trigger import TriggerOAuthSystemClient
provider_id = TriggerProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = TriggerOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
"""
Find draft variables that reference non-existent apps.

View File

@@ -147,17 +147,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
)
class TriggerConfig(BaseSettings):
"""
Configuration for trigger
"""
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
description="Maximum allowed size for webhook request bodies in bytes",
default=10485760,
)
class PluginConfig(BaseSettings):
"""
Plugin configs
@@ -882,22 +871,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable check upgradable plugin task",
default=True,
)
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
description="Enable workflow schedule poller task",
default=True,
)
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
description="Workflow schedule poller interval in minutes",
default=1,
)
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
description="Maximum number of schedules to process in each poll batch",
default=100,
)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
default=0,
)
class PositionConfig(BaseSettings):
@@ -1021,7 +994,6 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
TriggerConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,

View File

@@ -8,7 +8,6 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
from core.workflow.entities.variable_pool import VariablePool
@@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
ContextVar("plugin_trigger_providers")
)
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_trigger_providers_lock")
)
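
`RecyclableContextVar` wraps the stdlib `contextvars.ContextVar` used throughout this hunk; for reference, the underlying primitive behaves like this (names are illustrative):

```python
from contextvars import ContextVar

request_id: ContextVar[str] = ContextVar("request_id", default="-")


def handle() -> str:
    token = request_id.set("req-42")  # set for the current execution context
    try:
        return f"handling {request_id.get()}"
    finally:
        request_id.reset(token)  # restore the previous value


print(handle())  # handling req-42
```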

View File

@@ -67,11 +67,10 @@ from .app import (
workflow_draft_variable,
workflow_run,
workflow_statistic,
workflow_trigger,
)
# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
# Import billing controllers
from .billing import billing, compliance
@@ -181,6 +180,5 @@ from .workspace import (
models,
plugin,
tool_providers,
trigger_providers,
workspace,
)

View File

@@ -12,7 +12,6 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
@@ -126,11 +125,13 @@ class InstructionGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":

View File

@@ -95,22 +95,18 @@ class ChatMessageListApi(Resource):
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
select(
exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
has_more = db.session.scalar(
select(
exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
)
else:
# If we don't have a full page, there are no more messages
has_more = False
)
history_messages = list(reversed(history_messages))
@@ -130,7 +126,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")

View File

@@ -24,7 +24,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -39,7 +38,6 @@ from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.trigger_debug_service import TriggerDebugService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@@ -808,132 +806,6 @@ class DraftWorkflowNodeLastRunApi(Resource):
return node_exec
class DraftWorkflowTriggerNodeApi(Resource):
"""
Single node debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Poll for trigger events and execute single node when event arrives
"""
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("trigger_name", type=str, required=True, location="json")
parser.add_argument("subscription_id", type=str, required=True, location="json")
args = parser.parse_args()
trigger_name = args["trigger_name"]
subscription_id = args["subscription_id"]
event = TriggerDebugService.poll_event(
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
subscription_id=subscription_id,
node_id=node_id,
trigger_name=trigger_name,
)
if not event:
return jsonable_encoder({"status": "waiting"})
try:
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
user_inputs = event.model_dump()
node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query="",
files=[],
)
return jsonable_encoder(node_execution)
except Exception:
logger.exception("Error running draft workflow trigger node")
return jsonable_encoder(
{
"status": "error",
}
), 500
class DraftWorkflowTriggerRunApi(Resource):
"""
Full workflow debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Poll for trigger events and execute full workflow when event arrives
"""
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
trigger_name = args["trigger_name"]
subscription_id = args["subscription_id"]
event = TriggerDebugService.poll_event(
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
subscription_id=subscription_id,
node_id=node_id,
trigger_name=trigger_name,
)
if not event:
return jsonable_encoder({"status": "waiting"})
workflow_args = {
"inputs": event.model_dump(),
"query": "",
"files": [],
}
external_trace_id = get_external_trace_id(request)
if external_trace_id:
workflow_args["external_trace_id"] = external_trace_id
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except Exception:
logger.exception("Error running draft workflow trigger run")
return jsonable_encoder(
{
"status": "error",
}
), 500
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
@@ -958,14 +830,6 @@ api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
DraftWorkflowTriggerNodeApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger",
)
api.add_resource(
DraftWorkflowTriggerRunApi,
"/apps/<uuid:app_id>/workflows/draft/trigger/run",
)
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",

View File

@@ -27,9 +27,7 @@ class WorkflowAppLogApi(Resource):
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)

View File

@@ -1,249 +0,0 @@
import logging
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.model import Account, AppMode
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
logger = logging.getLogger(__name__)
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
class PluginTriggerApi(Resource):
"""Workflow Plugin Trigger API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def post(self, app_model):
"""Create plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=False, location="json")
parser.add_argument("provider_id", type=str, required=False, location="json")
parser.add_argument("trigger_name", type=str, required=False, location="json")
parser.add_argument("subscription_id", type=str, required=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
node_id=args["node_id"],
provider_id=args["provider_id"],
trigger_name=args["trigger_name"],
subscription_id=args["subscription_id"],
)
return jsonable_encoder(plugin_trigger)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def get(self, app_model):
"""Get plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
app_id=app_model.id,
node_id=args["node_id"],
)
return jsonable_encoder(plugin_trigger)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def put(self, app_model):
"""Update plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
app_id=app_model.id,
node_id=args["node_id"],
subscription_id=args["subscription_id"],
)
return jsonable_encoder(plugin_trigger)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def delete(self, app_model):
"""Delete plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
WorkflowPluginTriggerService.delete_plugin_trigger(
app_id=app_model.id,
node_id=args["node_id"],
)
return {"result": "success"}, 204
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
def get(self, app_model):
"""Get webhook trigger for a node"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
node_id = args["node_id"]
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.filter(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.first()
)
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
# Add computed fields for marshal_with
base_url = dify_config.SERVICE_API_URL
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" # type: ignore
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" # type: ignore
return webhook_trigger
class AppTriggersApi(Resource):
"""App Triggers list API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
def get(self, app_model):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
select(AppTrigger)
.where(
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
)
.scalars()
.all()
)
# Add computed icon field for each trigger
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for trigger in triggers:
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
class AppTriggerEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
def post(self, app_model):
"""Update app trigger (enable/disable)"""
parser = reqparse.RequestParser()
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
trigger_id = args["trigger_id"]
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
AppTrigger.id == trigger_id,
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
).scalar_one_or_none()
if not trigger:
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return trigger
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@@ -1,187 +0,0 @@
from functools import wraps
from typing import cast
import flask_login
from flask import request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
from .. import api
def oauth_server_client_id_required(view):
@wraps(view)
def decorated(*args, **kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
raise BadRequest("client_id is required")
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app
return view(*args, **kwargs)
return decorated
def oauth_server_access_token_required(view):
@wraps(view)
def decorated(*args, **kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app")
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
if not account:
raise BadRequest("access_token or client_id is invalid")
kwargs["account"] = account
return view(*args, **kwargs)
return decorated
class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
return jsonable_encoder(
{
"app_icon": oauth_provider_app.app_icon,
"app_label": oauth_provider_app.app_label,
"scope": oauth_provider_app.scope,
}
)
class OAuthServerUserAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user)
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
return jsonable_encoder(
{
"code": code,
}
)
class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=False, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
class OAuthServerUserAccountApi(Resource):
@setup_required
@oauth_server_client_id_required
@oauth_server_access_token_required
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
return jsonable_encoder(
{
"name": account.name,
"email": account.email,
"avatar": account.avatar,
"interface_language": account.interface_language,
"timezone": account.timezone,
}
)
api.add_resource(OAuthServerAppApi, "/oauth/provider")
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")

View File

@@ -516,20 +516,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
parser.add_argument("provider", type=str, required=True, location="args")
parser.add_argument("action", type=str, required=True, location="args")
parser.add_argument("parameter", type=str, required=True, location="args")
parser.add_argument("credential_id", type=str, required=False, location="args")
parser.add_argument("provider_type", type=str, required=True, location="args")
args = parser.parse_args()
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args["plugin_id"],
provider=args["provider"],
action=args["action"],
parameter=args["parameter"],
credential_id=args["credential_id"],
provider_type=args["provider_type"],
tenant_id,
user_id,
args["plugin_id"],
args["provider"],
args["action"],
args["parameter"],
args["provider_type"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)

View File

@@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService

View File

@@ -1,589 +0,0 @@
import logging
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
logger = logging.getLogger(__name__)
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
)
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
)
except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e)
raise
class TriggerSubscriptionBuilderCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
credential_type=credential_type,
)
return jsonable_encoder({"subscription_builder": subscription_builder})
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error adding provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=args.get("credentials", None),
),
)
return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
)
except Exception as e:
logger.exception("Error verifying provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
),
)
)
except Exception as e:
logger.exception("Error updating provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
except Exception as e:
logger.exception("Error getting request logs for subscription builder", exc_info=e)
raise
class TriggerSubscriptionBuilderBuildApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
),
)
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
)
return 200
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error building provider credential", exc_info=e)
raise
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, subscription_id):
"""Delete a subscription instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
with Session(db.engine) as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
# Delete plugin triggers
WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
session.commit()
return {"result": "success"}
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error deleting provider credential", exc_info=e)
raise
class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
tenant_id = user.current_tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Create subscription builder
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=tenant_id,
user_id=user.id,
provider_id=provider_id,
credential_type=CredentialType.OAUTH2,
)
# Create OAuth handler and proxy context
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
extra_data={
"subscription_builder_id": subscription_builder.id,
},
)
# Build redirect URI for callback
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
# Get authorization URL
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
# Create response with cookie
response = make_response(
jsonable_encoder(
{
"authorization_url": authorization_url_response.authorization_url,
"subscription_builder_id": subscription_builder.id,
"subscription_builder": subscription_builder,
}
)
)
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
except Exception as e:
logger.exception("Error initiating OAuth flow", exc_info=e)
raise
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
# Use and validate proxy context
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
# Parse provider ID
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Get OAuth credentials from callback
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials:
raise Exception("Failed to get OAuth credentials")
# Update subscription builder
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=credentials,
credential_expires_at=expires_at,
),
)
# Redirect to OAuth callback page
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
# Get custom OAuth client params if exists
custom_params = TriggerProviderService.get_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if custom client is enabled
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if there's a system OAuth client
system_client = TriggerProviderService.get_oauth_client(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
return jsonable_encoder(
{
"configured": bool(custom_params or system_client),
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
"custom_configured": bool(custom_params),
"custom_enabled": is_custom_enabled,
"redirect_uri": redirect_uri,
"params": custom_params if custom_params else {},
}
)
except Exception as e:
logger.exception("Error getting OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error configuring OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
# Trigger Subscription
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
api.add_resource(
TriggerSubscriptionDeleteApi,
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
# Trigger Subscription Builder
api.add_resource(
TriggerSubscriptionBuilderCreateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
api.add_resource(
TriggerSubscriptionBuilderGetApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderUpdateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderVerifyApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderBuildApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderLogsApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
# OAuth
api.add_resource(
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")

View File

@@ -1,12 +1,8 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@@ -14,9 +10,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view: Callable[P, R]):
def billing_inner_api_only(view):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
abort(404)
@@ -30,9 +26,9 @@ def billing_inner_api_only(view: Callable[P, R]):
return decorated
def enterprise_inner_api_only(view: Callable[P, R]):
def enterprise_inner_api_only(view):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
abort(404)
@@ -82,9 +78,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
def plugin_inner_api_only(view: Callable[P, R]):
def plugin_inner_api_only(view):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)
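
This hunk strips the `ParamSpec`/`TypeVar` annotations from the inner-API decorators. The typed form being removed keeps the wrapped view's signature visible to type checkers; a minimal sketch (`guarded` and `ping` are hypothetical):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def guarded(view: Callable[P, R]) -> Callable[P, R]:
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # A real decorator would abort(404) here when the feature is disabled
        return view(*args, **kwargs)

    return decorated


@guarded
def ping(name: str) -> str:  # type checkers still see (name: str) -> str
    return f"pong {name}"
```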

View File

@@ -1,28 +1,20 @@
from typing import Optional, Union
from flask import Response
from flask_restx import Resource, reqparse
from pydantic import ValidationError
from sqlalchemy.orm import Session
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp import types
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
from core.mcp.types import ClientNotification, ClientRequest
from core.mcp.utils import create_mcp_error_response
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
class MCPRequestError(Exception):
"""Custom exception for MCP request processing errors"""
def __init__(self, error_code: int, message: str):
self.error_code = error_code
self.message = message
super().__init__(message)
def int_or_str(value):
"""Validate that a value is either an integer or string."""
if isinstance(value, (int, str)):
@@ -71,128 +63,76 @@ class MCPAppApi(Resource):
Raises:
ValidationError: Invalid request format or parameters
"""
# Parse and validate all arguments
args = mcp_request_parser.parse_args()
request_id: Optional[Union[int, str]] = args.get("id")
mcp_request = self._parse_mcp_request(args)
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server)
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
)
# Get user input form
user_input_form = self._get_user_input_form(app)
if server.status != AppMCPServerStatus.ACTIVE:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)
# Handle notification vs request differently
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session"""
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
app = session.query(App).where(App.id == mcp_server.app_id).first()
app = db.session.query(App).where(App.id == server.app_id).first()
if not app:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
)
return mcp_server, app
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
"""Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
def _process_mcp_message(
self,
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Process MCP message (notification or request)"""
if isinstance(mcp_request, mcp_types.ClientNotification):
return self._handle_notification(mcp_request)
else:
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)
def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
"""Handle MCP notification"""
# For notifications, only support init notification
if mcp_request.root.method != "notifications/initialized":
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
# Return HTTP 202 Accepted for notifications (no response body)
return Response("", status=202, content_type="application/json")
def _handle_request(
self,
mcp_request: mcp_types.ClientRequest,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Handle MCP request"""
if request_id is None:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")
result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
if result is None:
# This shouldn't happen for requests, but handle gracefully
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")
return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
"""Get and convert user input form"""
# Get raw user input form based on app mode
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if not app.workflow:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
workflow = app.workflow
if workflow is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
if not app.app_model_config:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
features_dict = app.app_model_config.to_dict()
raw_user_input_form = features_dict.get("user_input_form", [])
app_model_config = app.app_model_config
if app_model_config is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
# Convert to VariableEntity objects
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
converted_user_input_form: list[VariableEntity] = []
try:
return self._convert_user_input_form(raw_user_input_form)
for item in user_input_form:
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
converted_user_input_form.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
)
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
)
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
return VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)
except ValidationError:
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
except ValidationError as e:
try:
return mcp_types.ClientNotification.model_validate(args)
notification = ClientNotification.model_validate(args)
request = notification
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
)
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()
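Both versions share the same parsing strategy: validate the payload as a `ClientRequest` first and fall back to `ClientNotification` on `ValidationError`. A self-contained sketch of that try-one-model-then-the-other pattern, using simplified stand-in models rather than the real MCP types:

```python
from pydantic import BaseModel, ValidationError


class Request(BaseModel):  # stand-in for mcp_types.ClientRequest
    id: int | str
    method: str


class Notification(BaseModel):  # stand-in for mcp_types.ClientNotification
    method: str


def parse_message(args: dict) -> Request | Notification:
    try:
        return Request.model_validate(args)
    except ValidationError:
        try:
            return Notification.model_validate(args)
        except ValidationError as e:
            raise ValueError(f"Invalid MCP message: {e}")


print(parse_message({"id": 1, "method": "ping"}))              # parsed as Request
print(parse_message({"method": "notifications/initialized"}))  # falls back to Notification
```

Note that order matters here: a request payload would also satisfy the looser notification model, so the stricter model must be tried first.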

View File

@ -318,6 +318,10 @@ class DatasetApi(DatasetApiResource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)

View File

@ -1,6 +1,6 @@
from typing import Literal
from flask_login import current_user
from flask_login import current_user # type: ignore
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound

View File

@ -1,12 +1,12 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from enum import Enum
from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in
from flask_login import user_logged_in # type: ignore
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
class WhereisUserArg(StrEnum):
class WhereisUserArg(Enum):
"""
Enum for whereis_user_arg.
"""
QUERY = auto()
JSON = auto()
FORM = auto()
QUERY = "query"
JSON = "json"
FORM = "form"
class FetchUserArg(BaseModel):
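The two `WhereisUserArg` variants are equivalent: on a `StrEnum`, `auto()` assigns the lowercased member name as the value (Python 3.11+), which is exactly what the plain-`Enum` version spells out by hand. A quick demonstration:

```python
from enum import StrEnum, auto


class WhereisUserArg(StrEnum):
    QUERY = auto()
    JSON = auto()
    FORM = auto()


assert WhereisUserArg.QUERY == "query"   # auto() lowercases the member name
assert WhereisUserArg.JSON.value == "json"
print([m.value for m in WhereisUserArg])  # ['query', 'json', 'form']
```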

View File

@ -1,7 +0,0 @@
from flask import Blueprint
# Create trigger blueprint
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
# Import routes after blueprint creation to avoid circular imports
from . import trigger, webhook

View File

@ -1,41 +0,0 @@
import logging
import re
from flask import jsonify, request
from werkzeug.exceptions import NotFound
from controllers.trigger import bp
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger_service import TriggerService
logger = logging.getLogger(__name__)
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
UUID_MATCHER = re.compile(UUID_PATTERN)
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_endpoint(endpoint_id: str):
"""
Handle endpoint trigger calls.
"""
# endpoint_id must be UUID
if not UUID_MATCHER.match(endpoint_id):
raise NotFound("Invalid endpoint ID")
handling_chain = [
TriggerService.process_endpoint,
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
]
try:
for handler in handling_chain:
response = handler(endpoint_id, request)
if response:
break
if not response:
raise NotFound("Endpoint not found")
return response
except ValueError as e:
raise NotFound(str(e))
except Exception as e:
logger.exception("Webhook processing failed for {endpoint_id}")
return jsonify({"error": "Internal server error", "message": str(e)}), 500

View File

@ -1,46 +0,0 @@
import logging
from flask import jsonify
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from services.webhook_service import WebhookService
logger = logging.getLogger(__name__)
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook(webhook_id: str):
"""
Handle webhook trigger calls.
This endpoint receives webhook calls and processes them according to the
configured webhook trigger settings.
"""
try:
# Get webhook trigger, workflow, and node configuration
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
# Extract request data
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
# Validate request against node configuration
validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
if not validation_result["valid"]:
return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400
# Process webhook call (send to Celery)
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
# Return configured response
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code
except ValueError as e:
raise NotFound(str(e))
except RequestEntityTooLarge:
raise
except Exception as e:
logger.exception("Webhook processing failed for %s", webhook_id)
return jsonify({"error": "Internal server error", "message": str(e)}), 500

View File

@ -1,5 +1,5 @@
from flask_restx import Resource, reqparse
from jwt import InvalidTokenError
from jwt import InvalidTokenError # type: ignore
import services
from controllers.console.auth.error import (

View File

@ -1,16 +1,3 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationError
class MoreLikeThisConfig(BaseModel):
enabled: bool = False
model_config = ConfigDict(extra="allow")
class AppConfigModel(BaseModel):
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
model_config = ConfigDict(extra="allow")
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
@ -19,14 +6,31 @@ class MoreLikeThisConfigManager:
:param config: model config args
"""
validated_config, _ = cls.validate_and_set_defaults(config)
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
more_like_this = False
more_like_this_dict = config.get("more_like_this")
if more_like_this_dict:
if more_like_this_dict.get("enabled"):
more_like_this = True
return more_like_this
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)
"""
Validate and set defaults for more like this feature
:param config: app model config args
"""
if not config.get("more_like_this"):
config["more_like_this"] = {"enabled": False}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
config["more_like_this"]["enabled"] = False
if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type")
return config, ["more_like_this"]
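The removed implementation gets the same defaulting behaviour from Pydantic: a missing block is filled by `default_factory`, and `extra="allow"` lets unrelated config keys pass through untouched. A sketch showing the two conversion paths agree:

```python
from pydantic import BaseModel, ConfigDict, Field


class MoreLikeThisConfig(BaseModel):
    enabled: bool = False
    model_config = ConfigDict(extra="allow")


class AppConfigModel(BaseModel):
    more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
    model_config = ConfigDict(extra="allow")


print(AppConfigModel.model_validate({}).more_like_this.enabled)  # False (defaulted)
print(AppConfigModel.model_validate(
    {"more_like_this": {"enabled": True}, "opening_statement": "hi"}
).more_like_this.enabled)                                        # True (extras kept)
```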

View File

@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=conversation_variables,
conversation_variables=cast(list[VariableUnion], conversation_variables),
)
# init graph

View File

@ -1,5 +1,4 @@
import logging
import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
@ -144,7 +143,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._task_state = WorkflowTaskState()
@ -375,7 +373,7 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
@ -898,14 +896,7 @@ class AdvancedChatAppGenerateTaskPipeline:
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
message.answer = answer_text
message.answer = self._task_state.answer
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = self._task_state.metadata.model_dump_json()

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, final
@ -13,7 +14,6 @@ from core.workflow.repositories.draft_variable_repository import (
NoopDraftVariableSaver,
)
from factories import file_factory
from libs.orjson import orjson_dumps
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
@ -174,7 +174,7 @@ class BaseAppGenerator:
def gen():
for message in generator:
if isinstance(message, Mapping | dict):
yield f"data: {orjson_dumps(message)}\n\n"
yield f"data: {json.dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"

View File

@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
from enum import IntEnum, auto
from enum import Enum
from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
TASK_PIPELINE = auto()
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class AppQueueManager:

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
@ -52,7 +53,9 @@ from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowRun,
)
@ -61,10 +64,8 @@ class WorkflowResponseConverter:
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
) -> None:
self._application_generate_entity = application_generate_entity
self._user = user
def workflow_start_to_stream_response(
self,
@ -91,21 +92,27 @@ class WorkflowResponseConverter:
workflow_execution: WorkflowExecution,
) -> WorkflowFinishStreamResponse:
created_by = None
user = self._user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
elif isinstance(user, EndUser):
created_by = {
"id": user.id,
"user": user.session_id,
}
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
assert workflow_run is not None
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
account = session.scalar(stmt)
if account:
created_by = {
"id": account.id,
"name": account.name,
"email": account.email,
}
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
end_user = session.scalar(stmt)
if end_user:
created_by = {
"id": end_user.id,
"user": end_user.session_id,
}
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (

View File

@ -54,8 +54,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@ -70,8 +68,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@overload
@ -86,8 +82,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
@ -101,8 +95,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -138,20 +130,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
# get inputs for the start node
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=inputs,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=list(system_files),
user_id=user.id,
stream=streaming,
@ -170,10 +159,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
@ -201,7 +187,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
root_node_id=root_node_id,
)
def _generate(
@ -217,7 +202,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -255,7 +239,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
},
)
@ -452,7 +435,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
root_node_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
@ -496,7 +478,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
root_node_id=root_node_id,
)
try:

View File

@ -34,7 +34,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
root_node_id: Optional[str] = None,
) -> None:
super().__init__(
queue_manager=queue_manager,
@ -45,7 +44,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._root_node_id = root_node_id
def run(self) -> None:
"""
@ -95,7 +93,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict, root_node_id=self._root_node_id)
graph = self._init_graph(graph_config=self._workflow.graph_dict)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(

View File

@ -131,7 +131,6 @@ class WorkflowAppGenerateTaskPipeline:
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._application_generate_entity = application_generate_entity

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
@ -79,7 +79,7 @@ class WorkflowBasedAppRunner:
self._variable_loader = variable_loader
self._app_id = app_id
def _init_graph(self, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> Graph:
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
@ -92,7 +92,7 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")

View File

@ -118,7 +118,7 @@ class QueueIterationNextEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
"""iteratoin run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
@ -201,7 +201,7 @@ class QueueLoopNextEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
"""iteratoin run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
duration: Optional[float] = None
@ -382,7 +382,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""loop id if node is in loop"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
"""iteratoin run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None

View File

@ -472,10 +472,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:
return AgentThoughtStreamResponse(

View File

@ -1,8 +1,8 @@
class TaskPipelineError(ValueError):
pass
class RecordNotFoundError(TaskPipelineError):
def __init__(self, record_name: str, record_id: str):
super().__init__(f"{record_name} with id {record_id} not found")

View File

@ -192,7 +192,7 @@ class ProviderConfig(BasicProviderConfig):
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str, float, bool, list]] = None
default: Optional[Union[int, str, float, bool]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None

View File

@ -88,7 +88,6 @@ def to_prompt_message_content(
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

View File

@ -3,7 +3,7 @@ import base64
from libs import rsa
def obfuscated_token(token: str) -> str:
def obfuscated_token(token: str):
if not token:
return token
if len(token) <= 8:

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
import httpx
import requests
from yarl import URL
from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()
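For these calls the two HTTP clients are drop-in replacements: both expose `post(url, json=...)` and `raise_for_status()`. A sketch of the `httpx` form used on the removed side (URL and payload are hypothetical):

```python
import httpx

response = httpx.post(
    "https://marketplace.example.com/api/v1/plugins/batch",  # hypothetical URL
    json={"plugin_ids": ["langgenius/github"]},
    timeout=10.0,
)
response.raise_for_status()
print(response.json())
```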

View File

@ -1,128 +0,0 @@
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get the fields that need to be encrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get the fields that need to be masked
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
return self.mask_credentials(data)
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get the fields that need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or an empty string, skip decryption
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
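The deleted `mask_credentials` keeps the first and last two characters of any secret longer than six characters and stars the rest; shorter values are fully starred. A standalone sketch of that rule:

```python
def mask_secret(value: str) -> str:
    # keep two characters on each side when the secret is long enough
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)


print(mask_secret("sk-abcdef123456"))  # 'sk' + 11 * '*' + '56'
print(mask_secret("short"))            # '*****'
```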

View File

@ -4,259 +4,224 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types as mcp_types
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.mcp.utils import create_mcp_error_response
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
mcp_server: AppMCPServer,
end_user: EndUser | None = None,
request_id: int | str = 1,
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
class MCPServerStreamableHTTPRequestHandler:
"""
Handle MCP request and return JSON-RPC response
Args:
app: The Dify app instance
request: The JSON-RPC request message
user_input_form: List of variable entities for the app
mcp_server: The MCP server configuration
end_user: Optional end user
request_id: The request ID
Returns:
JSON-RPC response or error
Applies to the MCP streamable HTTP server in stateless HTTP mode
"""
request_type = type(request.root)
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app
self.request = request
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
"""Create success response with business result data"""
return mcp_types.JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
)
@property
def request_type(self):
return type(self.request.root)
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
"""Create error response with error code and message"""
from core.mcp.types import ErrorData
error_data = ErrorData(code=code, message=message)
return mcp_types.JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=error_data,
)
# Request handler mapping using functional approach
request_handlers = {
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
mcp_types.ListToolsRequest: lambda: handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
),
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
mcp_types.PingRequest: lambda: handle_ping(),
}
try:
# Dispatch request to appropriate handler
handler = request_handlers.get(request_type)
if handler:
return create_success_response(handler())
else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
except ValueError as e:
logger.exception("Invalid params")
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
def handle_ping() -> mcp_types.EmptyResult:
"""Handle ping request"""
return mcp_types.EmptyResult()
def handle_initialize(description: str) -> mcp_types.InitializeResult:
"""Handle initialize request"""
capabilities = mcp_types.ServerCapabilities(
tools=mcp_types.ToolsCapability(listChanged=False),
)
return mcp_types.InitializeResult(
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
instructions=description,
)
def handle_list_tools(
app_name: str,
app_mode: str,
user_input_form: list[VariableEntity],
description: str,
parameters_dict: dict[str, str],
) -> mcp_types.ListToolsResult:
"""Handle list tools request"""
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
return mcp_types.ListToolsResult(
tools=[
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=parameter_schema,
)
],
)
def handle_call_tool(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
end_user: EndUser | None,
) -> mcp_types.CallToolResult:
"""Handle call tool request"""
request_obj = cast(mcp_types.CallToolRequest, request.root)
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
if not end_user:
raise ValueError("End user not found")
response = AppGenerateService.generate(
app,
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT.value,
)
answer = extract_answer_from_response(app, response)
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> dict[str, Any]:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": parameters,
"required": required,
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}
@property
def capabilities(self):
return types.ServerCapabilities(
tools=types.ToolsCapability(listChanged=False),
)
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW.value:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION.value:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
def response(self, response: types.Result | str):
if isinstance(response, str):
sse_content = f"event: ping\ndata: {response}\n\n".encode()
yield sse_content
return
json_response = types.JSONRPCResponse(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
def extract_answer_from_response(app: App, response: Any) -> str:
"""Extract answer from app generate response"""
answer = ""
yield sse_content
if isinstance(response, RateLimitGenerator):
answer = process_streaming_response(response)
elif isinstance(response, Mapping):
answer = process_mapping_response(app, response)
else:
logger.warning("Unexpected response type: %s", type(response))
def error_response(self, code: int, message: str, data=None):
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
return create_mcp_error_response(request_id, code, message, data)
return answer
def handle(self):
handle_map = {
types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool,
types.InitializedNotification: self.handle_notification,
types.PingRequest: self.handle_ping,
}
try:
if self.request_type in handle_map:
return self.response(handle_map[self.request_type]())
else:
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
except ValueError as e:
logger.exception("Invalid params")
return self.error_response(INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
def handle_notification(self):
return "ping"
def process_streaming_response(response: RateLimitGenerator) -> str:
"""Process streaming response for agent chat mode"""
answer = ""
for item in response.generator:
if isinstance(item, str) and item.startswith("data: "):
try:
json_str = item[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
def handle_ping(self):
return types.EmptyResult()
def initialize(self):
request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
if not self.end_user:
end_user = EndUser(
tenant_id=self.app.tenant_id,
app_id=self.app.id,
type="mcp",
name=client_name,
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,
)
db.session.add(end_user)
db.session.commit()
return types.InitializeResult(
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
instructions=self.mcp_server.description,
)
def list_tools(self):
if not self.end_user:
raise ValueError("User not found")
return types.ListToolsResult(
tools=[
types.Tool(
name=self.app.name,
description=self.mcp_server.description,
inputSchema=self.parameter_schema,
)
],
)
def invoke_tool(self):
if not self.end_user:
raise ValueError("User not found")
request = cast(types.CallToolRequest, self.request.root)
args = request.params.arguments or {}
if self.app.mode in {AppMode.WORKFLOW.value}:
args = {"inputs": args}
elif self.app.mode in {AppMode.COMPLETION.value}:
args = {"query": "", "inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(
self.app,
self.end_user,
args,
InvokeFrom.SERVICE_API,
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
)
answer = ""
if isinstance(response, RateLimitGenerator):
for item in response.generator:
data = item
if isinstance(data, str) and data.startswith("data: "):
try:
json_str = data[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
if isinstance(response, Mapping):
if self.app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
answer = response["answer"]
elif self.app.mode in {AppMode.WORKFLOW.value}:
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode")
# Images are not supported yet
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
def retrieve_end_user(self):
return (
db.session.query(EndUser)
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
parameters[item.variable] = {}
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
return answer
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW.value:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))
def convert_input_form_to_parameters(
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> tuple[dict[str, dict[str, Any]], list[str]]:
"""Convert user input form to parameter schema"""
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
parameters[item.variable] = {}
if item.required:
required.append(item.variable)
# if the workflow was republished, the parameters may not have changed,
# so we should not raise an error here
description = parameters_dict.get(item.variable, "")
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required
if item.required:
required.append(item.variable)
# if the workflow was republished, the parameters may not have changed,
# so we should not raise an error here
try:
description = self.mcp_server.parameters_dict[item.variable]
except KeyError:
description = ""
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required
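`_convert_input_form_to_parameters` maps form variables into a JSON-Schema-like `properties`/`required` pair, skipping file and external-data variables. A sketch of the output for a hypothetical form with one optional select and one required text input:

```python
user_input_form = [  # hypothetical form variables
    {"variable": "tone", "type": "select", "options": ["formal", "casual"], "required": False},
    {"variable": "topic", "type": "text-input", "required": True},
]

parameters: dict[str, dict] = {}
required: list[str] = []
for item in user_input_form:
    prop: dict = {"description": ""}
    if item["type"] in ("text-input", "paragraph"):
        prop["type"] = "string"
    elif item["type"] == "select":
        prop["type"] = "string"
        prop["enum"] = item["options"]
    elif item["type"] == "number":
        prop["type"] = "float"
    parameters[item["variable"]] = prop
    if item["required"]:
        required.append(item["variable"])

# "tone" gets an enum of its options; only "topic" lands in required
print({"type": "object", "properties": parameters, "required": required})
```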

View File

@ -138,5 +138,5 @@ def create_mcp_error_response(
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = json_data.encode()
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content
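`create_mcp_error_response` wraps a JSON-RPC 2.0 error object in an SSE `message` event. A sketch of the resulting frame, assuming `-32600` (the JSON-RPC invalid-request code) as the error code:

```python
import json

error = {
    "jsonrpc": "2.0",
    "id": 1,
    "error": {"code": -32600, "message": "Server Not Found"},  # -32600 = invalid request
}
frame = f"event: message\ndata: {json.dumps(error)}\n\n".encode()
print(frame.decode(), end="")
```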

View File

@ -31,65 +31,6 @@ class TokenBufferMemory:
self.conversation = conversation
self.model_instance = model_instance
def _build_prompt_message_with_files(
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
) -> PromptMessage:
"""
Build prompt message with files.
:param message_files: list of MessageFile objects
:param text_content: text content of the message
:param message: Message object
:param app_record: app record
:param is_user_message: whether this is a user message
:return: PromptMessage
"""
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.HIGH
if file_extra_config and app_record:
# Build files directly without filtering by belongs_to
file_objs = [
file_factory.build_from_message_file(
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
)
for message_file in message_files
]
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else:
file_objs = []
if not file_objs:
if is_user_message:
return UserPromptMessage(content=text_content)
else:
return AssistantPromptMessage(content=text_content)
else:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message_contents.append(TextPromptMessageContent(data=text_content))
if is_user_message:
return UserPromptMessage(content=prompt_message_contents)
else:
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
@ -126,46 +67,52 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
# Process user message with files
user_files = (
db.session.query(MessageFile)
.where(
MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
)
.all()
)
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
)
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else:
file_objs = []
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
if user_files:
user_prompt_message = self._build_prompt_message_with_files(
message_files=user_files,
text_content=message.query,
message=message,
app_record=app_record,
is_user_message=True,
)
prompt_messages.append(user_prompt_message)
else:
prompt_messages.append(UserPromptMessage(content=message.query))
# Process assistant message with files
assistant_files = (
db.session.query(MessageFile)
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
.all()
)
if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(
message_files=assistant_files,
text_content=message.answer,
message=message,
app_record=app_record,
is_user_message=False,
)
prompt_messages.append(assistant_prompt_message)
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
prompt_messages.append(AssistantPromptMessage(content=message.answer))
if not prompt_messages:
return []
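Both versions walk the stored conversation oldest-first and emit a user message followed by an assistant message for each exchange, attaching files where configured. A stripped-down sketch of that assembly, with file handling omitted and stand-in message classes:

```python
from dataclasses import dataclass


@dataclass
class UserPromptMessage:  # stand-in for the real prompt message classes
    content: str


@dataclass
class AssistantPromptMessage:
    content: str


def build_history(messages: list[dict]) -> list:
    prompt_messages: list = []
    for message in messages:
        # each stored exchange contributes a user turn, then an assistant turn
        prompt_messages.append(UserPromptMessage(content=message["query"]))
        prompt_messages.append(AssistantPromptMessage(content=message["answer"]))
    return prompt_messages


print(build_history([{"query": "hi", "answer": "hello"}]))
```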

View File

@ -158,6 +158,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@ -186,6 +188,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@ -210,6 +214,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@ -231,6 +237,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@ -261,6 +269,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@ -285,6 +295,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@ -306,6 +318,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@ -329,6 +343,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@ -388,6 +404,8 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)
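The re-added `cast` lines look redundant after the `isinstance` guard, but they pin the narrowed type of a reassignable instance attribute so type checkers keep it across later calls. A minimal sketch of the pattern, with hypothetical classes in place of the real model types:

```python
from typing import cast


class BaseModelType:  # hypothetical stand-ins for the model type classes
    pass


class LargeLanguageModelType(BaseModelType):
    def invoke_llm(self) -> str:
        return "ok"


class Instance:
    def __init__(self, model_type_instance: BaseModelType):
        self.model_type_instance = model_type_instance

    def invoke_llm(self) -> str:
        if not isinstance(self.model_type_instance, LargeLanguageModelType):
            raise Exception("Model type instance is not LargeLanguageModelType")
        # reassigning through cast() keeps the attribute narrowed for the checker
        self.model_type_instance = cast(LargeLanguageModelType, self.model_type_instance)
        return self.model_type_instance.invoke_llm()


print(Instance(LargeLanguageModelType()).invoke_llm())
```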

View File

@ -87,7 +87,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
@property
def data(self):

View File

@ -43,7 +43,7 @@ class GPT2Tokenizer:
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")

View File

@ -330,7 +330,7 @@ class OpsTraceManager:
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
else:
if tracing_provider is None:
if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()

View File

@ -375,16 +375,16 @@ Here is the extra instruction you need to follow:
# merge lines into messages with max tokens
messages: list[str] = []
for line in new_lines:
for i in new_lines: # type: ignore
if len(messages) == 0:
messages.append(line)
messages.append(i) # type: ignore
else:
if len(messages[-1]) + len(line) < max_tokens * 0.5:
messages[-1] += line
if get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
messages.append(line)
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
messages[-1] += i # type: ignore
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
messages.append(i) # type: ignore
else:
messages[-1] += line
messages[-1] += i # type: ignore
summaries = []
for i in range(len(messages)):
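
Editor's note: apart from the line-to-i rename and the type: ignore annotations, the hunk changes the merge test from a raw length comparison to a token count. A self-contained sketch of the resulting greedy merge, where get_prompt_tokens is a hypothetical counter standing in for the real one:

def merge_lines(new_lines: list[str], max_tokens: int) -> list[str]:
    def get_prompt_tokens(text: str) -> int:
        # hypothetical counter; the real one measures actual prompt tokens
        return len(text) // 4

    messages: list[str] = []
    for line in new_lines:
        if len(messages) == 0:
            messages.append(line)
        elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
            messages.append(line)  # budget nearly reached: start a new message
        else:
            messages[-1] += line  # still within budget: extend the last message
    return messages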

View File

@ -13,7 +13,6 @@ from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity
from core.trigger.entities.entities import TriggerProviderEntity
class PluginInstallationSource(enum.StrEnum):
@ -63,7 +62,6 @@ class PluginCategory(enum.StrEnum):
Model = "model"
Extension = "extension"
AgentStrategy = "agent-strategy"
Trigger = "trigger"
class PluginDeclaration(BaseModel):
@ -71,7 +69,6 @@ class PluginDeclaration(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list[str])
triggers: Optional[list[str]] = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@ -92,7 +89,6 @@ class PluginDeclaration(BaseModel):
repo: Optional[str] = Field(default=None)
verified: bool = Field(default=False)
tool: Optional[ToolProviderEntity] = None
trigger: Optional[TriggerProviderEntity] = None
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None
@ -108,8 +104,6 @@ class PluginDeclaration(BaseModel):
values["category"] = PluginCategory.Model
elif values.get("agent_strategy"):
values["category"] = PluginCategory.AgentStrategy
elif values.get("trigger"):
values["category"] = PluginCategory.Trigger
else:
values["category"] = PluginCategory.Extension
return values
@ -190,10 +184,6 @@ class ToolProviderID(GenericProviderID):
self.plugin_name = f"{self.provider_name}_tool"
class TriggerProviderID(GenericProviderID):
pass
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value
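
Editor's note: with the trigger branch deleted, PluginDeclaration's category inference falls through to Extension for anything that is not a tool, model, or agent strategy. Reduced to a plain function (the tool branch sits above the shown hunk and is assumed from context; string values follow the enum members in this diff):

def infer_category(values: dict) -> str:
    # mirrors the model_validator once the trigger branch is removed
    if values.get("tools"):
        return "tool"
    elif values.get("models"):
        return "model"
    elif values.get("agent_strategy"):
        return "agent-strategy"
    else:
        return "extension"

assert infer_category({"models": ["gpt"]}) == "model"
assert infer_category({}) == "extension"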

View File

@ -1,4 +1,3 @@
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@ -14,7 +13,6 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
@ -198,48 +196,3 @@ class PluginListResponse(BaseModel):
class PluginDynamicSelectOptionsResponse(BaseModel):
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
class PluginTriggerProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
declaration: TriggerProviderEntity
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
UNAUTHORIZED = "unauthorized"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
elif self == CredentialType.UNAUTHORIZED:
return "UNAUTHORIZED"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
elif type_name == "unauthorized":
return cls.UNAUTHORIZED
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@ -1,7 +1,5 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from flask import Response
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.provider_entities import BasicProviderConfig
@ -239,33 +237,3 @@ class RequestFetchAppInfo(BaseModel):
"""
app_id: str
class Event(BaseModel):
variables: Mapping[str, Any]
class TriggerInvokeResponse(BaseModel):
event: Event
class PluginTriggerDispatchResponse(BaseModel):
triggers: list[str]
raw_http_response: str
class TriggerSubscriptionResponse(BaseModel):
subscription: dict[str, Any]
class TriggerValidateProviderCredentialsResponse(BaseModel):
result: bool
class TriggerDispatchResponse:
triggers: list[str]
response: Response
def __init__(self, triggers: list[str], response: Response):
self.triggers = triggers
self.response = response

View File

@ -15,7 +15,6 @@ class DynamicSelectClient(BasePluginClient):
provider: str,
action: str,
credentials: Mapping[str, Any],
credential_type: str,
parameter: str,
) -> PluginDynamicSelectOptionsResponse:
"""
@ -30,7 +29,6 @@ class DynamicSelectClient(BasePluginClient):
"data": {
"provider": GenericProviderID(provider).provider_name,
"credentials": credentials,
"credential_type": credential_type,
"provider_action": action,
"parameter": parameter,
},

View File

@ -4,10 +4,10 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient):

View File

@ -1,301 +0,0 @@
import binascii
from collections.abc import Mapping
from typing import Any
from flask import Request
from core.plugin.entities.plugin import GenericProviderID, TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import (
PluginTriggerDispatchResponse,
TriggerDispatchResponse,
TriggerInvokeResponse,
TriggerSubscriptionResponse,
TriggerValidateProviderCredentialsResponse,
)
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.http_parser import deserialize_response, serialize_request
from core.trigger.entities.entities import Subscription
class PluginTriggerManager(BasePluginClient):
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
"""
Fetch trigger providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
for trigger in declaration.get("triggers", []):
trigger["identity"]["provider"] = provider_id
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/triggers",
list[PluginTriggerProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each trigger to plugin_id/provider_name
for trigger in provider.declaration.triggers:
trigger.identity.provider = provider.declaration.identity.name
return response
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
"""
Fetch trigger provider for the given tenant and plugin.
"""
def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
if data:
for trigger in data.get("declaration", {}).get("triggers", []):
trigger["identity"]["provider"] = str(provider_id)
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/trigger",
PluginTriggerProviderEntity,
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
transformer=transformer,
)
response.declaration.identity.name = str(provider_id)
# override the provider name for each trigger to plugin_id/provider_name
for trigger in response.declaration.triggers:
trigger.identity.provider = str(provider_id)
return response
def invoke_trigger(
self,
tenant_id: str,
user_id: str,
provider: str,
trigger: str,
credentials: Mapping[str, str],
credential_type: CredentialType,
request: Request,
parameters: Mapping[str, Any],
) -> TriggerInvokeResponse:
"""
Invoke a trigger with the given parameters.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/invoke",
TriggerInvokeResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"trigger": trigger,
"credentials": credentials,
"credential_type": credential_type,
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
"parameters": parameters,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return TriggerInvokeResponse(event=resp.event)
raise ValueError("No response received from plugin daemon for invoke trigger")
def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
) -> bool:
"""
Validate the credentials of the trigger provider.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
TriggerValidateProviderCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp.result
raise ValueError("No response received from plugin daemon for validate provider credentials")
def dispatch_event(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: Mapping[str, Any],
request: Request,
) -> TriggerDispatchResponse:
"""
Dispatch an event to triggers.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
PluginTriggerDispatchResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"subscription": subscription,
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return TriggerDispatchResponse(
triggers=resp.triggers,
response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())),
)
raise ValueError("No response received from plugin daemon for dispatch event")
def subscribe(
self,
tenant_id: str,
user_id: str,
provider: str,
credentials: Mapping[str, str],
endpoint: str,
parameters: Mapping[str, Any],
) -> TriggerSubscriptionResponse:
"""
Subscribe to a trigger.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/subscribe",
TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"credentials": credentials,
"endpoint": endpoint,
"parameters": parameters,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for subscribe")
def unsubscribe(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: Subscription,
credentials: Mapping[str, str],
) -> TriggerSubscriptionResponse:
"""
Unsubscribe from a trigger.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for unsubscribe")
def refresh(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: Subscription,
credentials: Mapping[str, str],
) -> TriggerSubscriptionResponse:
"""
Refresh a trigger subscription.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/refresh",
TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh")
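
Editor's note: every method in the deleted PluginTriggerManager ends the same way, iterating the daemon's streamed response, returning the first item, and raising if the stream was empty. The shared shape, distilled (names and the generic plumbing are illustrative):

from collections.abc import Iterator
from typing import TypeVar

T = TypeVar("T")

def first_or_raise(stream: Iterator[T], action: str) -> T:
    # return the first streamed payload; an empty stream is a protocol error
    for item in stream:
        return item
    raise ValueError(f"No response received from plugin daemon for {action}")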

View File

@ -1,159 +0,0 @@
from io import BytesIO
from flask import Request, Response
from werkzeug.datastructures import Headers
def serialize_request(request: Request) -> bytes:
method = request.method
path = request.full_path.rstrip("?")
raw = f"{method} {path} HTTP/1.1\r\n".encode()
for name, value in request.headers.items():
raw += f"{name}: {value}\r\n".encode()
raw += b"\r\n"
body = request.get_data(as_text=False)
if body:
raw += body
return raw
def deserialize_request(raw_data: bytes) -> Request:
header_end = raw_data.find(b"\r\n\r\n")
if header_end == -1:
header_end = raw_data.find(b"\n\n")
if header_end == -1:
header_data = raw_data
body = b""
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 2 :]
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 4 :]
lines = header_data.split(b"\r\n")
if len(lines) == 1 and b"\n" in lines[0]:
lines = header_data.split(b"\n")
if not lines or not lines[0]:
raise ValueError("Empty HTTP request")
request_line = lines[0].decode("utf-8", errors="ignore")
parts = request_line.split(" ", 2)
if len(parts) < 2:
raise ValueError(f"Invalid request line: {request_line}")
method = parts[0]
full_path = parts[1]
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
if "?" in full_path:
path, query_string = full_path.split("?", 1)
else:
path = full_path
query_string = ""
headers = Headers()
for line in lines[1:]:
if not line:
continue
line_str = line.decode("utf-8", errors="ignore")
if ":" not in line_str:
continue
name, value = line_str.split(":", 1)
headers.add(name, value.strip())
host = headers.get("Host", "localhost")
if ":" in host:
server_name, server_port = host.rsplit(":", 1)
else:
server_name = host
server_port = "80"
environ = {
"REQUEST_METHOD": method,
"PATH_INFO": path,
"QUERY_STRING": query_string,
"SERVER_NAME": server_name,
"SERVER_PORT": server_port,
"SERVER_PROTOCOL": protocol,
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "http",
}
if "Content-Type" in headers:
environ["CONTENT_TYPE"] = headers.get("Content-Type")
if "Content-Length" in headers:
environ["CONTENT_LENGTH"] = headers.get("Content-Length")
elif body:
environ["CONTENT_LENGTH"] = str(len(body))
for name, value in headers.items():
if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
continue
env_name = f"HTTP_{name.upper().replace('-', '_')}"
environ[env_name] = value
return Request(environ)
def serialize_response(response: Response) -> bytes:
raw = f"HTTP/1.1 {response.status}\r\n".encode()
for name, value in response.headers.items():
raw += f"{name}: {value}\r\n".encode()
raw += b"\r\n"
body = response.get_data(as_text=False)
if body:
raw += body
return raw
def deserialize_response(raw_data: bytes) -> Response:
header_end = raw_data.find(b"\r\n\r\n")
if header_end == -1:
header_end = raw_data.find(b"\n\n")
if header_end == -1:
header_data = raw_data
body = b""
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 2 :]
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 4 :]
lines = header_data.split(b"\r\n")
if len(lines) == 1 and b"\n" in lines[0]:
lines = header_data.split(b"\n")
if not lines or not lines[0]:
raise ValueError("Empty HTTP response")
status_line = lines[0].decode("utf-8", errors="ignore")
parts = status_line.split(" ", 2)
if len(parts) < 2:
raise ValueError(f"Invalid status line: {status_line}")
status_code = int(parts[1])
response = Response(response=body, status=status_code)
for line in lines[1:]:
if not line:
continue
line_str = line.decode("utf-8", errors="ignore")
if ":" not in line_str:
continue
name, value = line_str.split(":", 1)
response.headers[name] = value.strip()
return response
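
Editor's note: a quick round-trip check for the deleted helpers, assuming the removed module is still importable; the route, method, and payload are made up for illustration:

# assumes the module deleted above is still on the path
from core.plugin.utils.http_parser import deserialize_request, serialize_request
from flask import Flask, request

app = Flask(__name__)
with app.test_request_context("/hook?x=1", method="POST", data=b"payload"):
    raw = serialize_request(request)    # full HTTP/1.1 request as bytes
    rebuilt = deserialize_request(raw)  # fresh Request over a synthetic WSGI environ
    assert rebuilt.method == "POST"
    assert rebuilt.args["x"] == "1"
    assert rebuilt.get_data() == b"payload"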

View File

@ -87,6 +87,7 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

View File

@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@ -154,8 +154,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
data=provider_entity,
name_func=lambda x: x.provider,
):
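
Editor's note: the casts here only fix the element type of the include/exclude sets; behavior is unchanged. For readers unfamiliar with the helper, a reduced sketch of what an allow/deny filter of this shape typically does (the signature is taken from the call site, the semantics are an assumption, not the verified body of core.helper.position_helper.is_filtered):

from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")

def is_filtered_sketch(include_set: set[str], exclude_set: set[str],
                       data: T, name_func: Callable[[T], str]) -> bool:
    name = name_func(data)
    if include_set and name not in include_set:
        return True  # allow-list present and this provider is not on it
    if exclude_set and name in exclude_set:
        return True  # deny-list hit
    return False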

View File

@ -24,7 +24,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
"top_k": 2,
"score_threshold_enabled": False,
}

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document

View File

@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
top_k = kwargs.get("top_k", 2)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config

View File

@ -48,7 +48,7 @@ class OpenSearchConfig(BaseModel):
return values
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3
import boto3 # type: ignore
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),

View File

@ -6,8 +6,8 @@ from contextlib import contextmanager
from typing import Any
import psycopg2.errors
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
from flask import current_app
@ -426,6 +426,7 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod

View File

@ -4,7 +4,7 @@ import os
from typing import Optional, cast
import pandas as pd
from openpyxl import load_workbook
from openpyxl import load_workbook # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -2,7 +2,7 @@
import re
from pathlib import Path
from typing import Optional
from typing import Optional, cast
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))
markdown_tups = [
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
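
Editor's note: the cast(str, key) only quiets the type checker; the comprehension's behavior is unchanged. On sample data it strips '#' markers from headers and drops HTML tags from the body text:

import re

markdown_tups = [("## Setup", "run <code>make</code> first")]
cleaned = [
    (re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
    for key, value in markdown_tups
]
assert cleaned == [("Setup", "run make first")]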

View File

@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor):
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return data_source_binding.access_token
return cast(str, data_source_binding.access_token)

View File

@ -2,7 +2,7 @@
import contextlib
from collections.abc import Iterator
from typing import Optional
from typing import Optional, cast
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = storage.load(self._file_cache_key).decode("utf-8")
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())

View File

@ -3,7 +3,7 @@ import contextlib
import logging
from typing import Optional
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
"top_k": 2,
"score_threshold_enabled": False,
}
@ -647,7 +647,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
@ -743,7 +743,7 @@ class DatasetRetrieval:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 4,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,

View File

@ -144,7 +144,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
"""Text splitter that uses HuggingFace tokenizer to count length."""
try:
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase # type: ignore
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

View File

@ -4,8 +4,7 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel):

View File

@ -4,11 +4,11 @@ from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,

View File

@ -4,10 +4,9 @@ from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel):

View File

@ -476,3 +476,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")
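
Editor's note: this is the enum relocated from core.plugin.entities.plugin_daemon, minus the UNAUTHORIZED member. Its expected behavior, as a usage sketch:

from core.tools.entities.tool_entities import CredentialType

assert CredentialType.of("API-KEY") is CredentialType.API_KEY  # lookup is case-insensitive
assert CredentialType.OAUTH2.get_name() == "AUTH"
assert CredentialType.API_KEY.is_editable() and CredentialType.API_KEY.is_validate_allowed()
assert CredentialType.values() == ["api-key", "oauth2"]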

View File

@ -37,7 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import Tool
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -48,6 +47,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
@ -331,13 +331,16 @@ class ToolManager:
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
@ -645,8 +648,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):

View File

@ -181,7 +181,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,
top_k=retrieval_model.get("top_k") or 2,
)
if documents:
all_documents.extend(documents)
@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

View File

@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 4
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool

View File

@ -1,23 +1,132 @@
# Import generic components from provider_encryption module
from core.helper.provider_encryption import (
ProviderConfigCache,
ProviderConfigEncrypter,
create_provider_encrypter,
)
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol
# Re-export for backward compatibility
__all__ = [
"ProviderConfigCache",
"ProviderConfigEncrypter",
"create_provider_encrypter",
"create_tool_provider_encrypter",
]
# Tool-specific imports
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
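
Editor's note: the masking rule keeps the first and last two characters of any secret longer than six characters and stars everything in between, so the output has the same length without leaking the middle. Worked through by hand:

secret = "sk-abcdef123456"          # 15 characters
masked = secret[:2] + "*" * (len(secret) - 4) + secret[-2:]
assert masked == "sk***********56"  # 2 + 11 + 2 characters
short = "abc123"                    # 6 characters or fewer: fully starred
assert "*" * len(short) == "******"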
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
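
Editor's note: because ProviderConfigCache is a Protocol, any object with get/set/delete satisfies it structurally, so a throwaway in-memory implementation is enough to drive the encrypter in tests. A sketch, not the project's real cache:

from typing import Any, Optional

class InMemoryConfigCache:
    """Structurally satisfies ProviderConfigCache; nothing persists."""

    def __init__(self) -> None:
        self._config: Optional[dict] = None

    def get(self) -> Optional[dict]:
        return self._config

    def set(self, config: dict[str, Any]) -> None:
        self._config = config

    def delete(self) -> None:
        self._config = None

# usage (names from this file): encrypter, cache = create_provider_encrypter(
#     tenant_id, configs, InMemoryConfigCache())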
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,

View File

@ -3,7 +3,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional
from typing import Optional, cast
from uuid import UUID
import numpy as np
@ -159,7 +159,8 @@ class ToolFileMessageTransformer:
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
message.message.json_object = safe_json_value(message.message.json_object)
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
yield message
else:
yield message

View File

@ -129,14 +129,17 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

View File

@ -6,7 +6,7 @@ from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle

View File

@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional
from typing import Any, Optional, cast
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@ -204,14 +204,14 @@ class WorkflowTool(Tool):
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(self.runtime.tenant_id),
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(self.runtime.tenant_id),
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
)
files.append(file)

View File

@ -1 +0,0 @@
# Core trigger module initialization

Some files were not shown because too many files have changed in this diff.