Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice
2025-12-30 10:20:42 +08:00
232 changed files with 18692 additions and 2696 deletions

View File

@ -155,6 +155,7 @@ class AppDslService:
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
and "/blob/" in parsed_url.path
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")

View File

@ -14,7 +14,8 @@ from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService

View File

@ -21,7 +21,7 @@ from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -141,7 +141,7 @@ class AsyncWorkflowService:
trigger_log_repo.update(trigger_log)
session.commit()
raise InvokeRateLimitError(
raise WorkflowQuotaLimitError(
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
) from e

View File

@ -26,7 +26,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
"limit": 1,
"scrapeOptions": {"onlyMainContent": True},
}
response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
response = self._post_request(self._build_url("v1/crawl"), options, headers)
if response.status_code == 200:
return True
else:
@ -35,15 +35,17 @@ class FirecrawlAuth(ApiKeyAuthBase):
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _build_url(self, path: str) -> str:
# ensure exactly one slash between base and path, regardless of user-provided base_url
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
try:
payload = response.json()
except json.JSONDecodeError:
payload = {}
error_message = payload.get("error") or payload.get("message") or (response.text or "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")

View File

@ -1,8 +1,13 @@
import json
import logging
import os
from collections.abc import Sequence
from typing import Literal
import httpx
from pydantic import TypeAdapter
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
from enums.cloud_plan import CloudPlan
@ -11,6 +16,15 @@ from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models import Account, TenantAccountJoin, TenantAccountRole
logger = logging.getLogger(__name__)
class SubscriptionPlan(TypedDict):
"""Tenant subscriptionplan information."""
plan: str
expiration_date: int
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
@ -18,6 +32,11 @@ class BillingService:
compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
# Redis key prefix for tenant plan cache
_PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
# Cache TTL: 10 minutes
_PLAN_CACHE_TTL = 600
@classmethod
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
@ -239,3 +258,135 @@ class BillingService:
def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
payload = {"account_id": account_id, "click_id": click_id}
return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)
@classmethod
def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
"""
Bulk fetch billing subscription plan via billing API.
Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
Returns:
Mapping of tenant_id -> {plan: str, expiration_date: int}
"""
results: dict[str, SubscriptionPlan] = {}
subscription_adapter = TypeAdapter(SubscriptionPlan)
chunk_size = 200
for i in range(0, len(tenant_ids), chunk_size):
chunk = tenant_ids[i : i + chunk_size]
try:
resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
data = resp.get("data", {})
for tenant_id, plan in data.items():
try:
subscription_plan = subscription_adapter.validate_python(plan)
results[tenant_id] = subscription_plan
except Exception:
logger.exception(
"get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
)
continue
except Exception:
logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
continue
return results
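The 200-tenant cap per request is enforced by plain slicing; for example:

tenant_ids = [f"t{i}" for i in range(450)]
chunks = [tenant_ids[i : i + 200] for i in range(0, len(tenant_ids), 200)]
assert [len(c) for c in chunks] == [200, 200, 50]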
@classmethod
def _make_plan_cache_key(cls, tenant_id: str) -> str:
return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
@classmethod
def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
"""
Bulk fetch billing subscription plan with cache to reduce billing API loads in batch job scenarios.
NOTE: if you need strong data consistency, use get_plan_bulk instead.
Returns:
Mapping of tenant_id -> {plan: str, expiration_date: int}
"""
tenant_plans: dict[str, SubscriptionPlan] = {}
if not tenant_ids:
return tenant_plans
subscription_adapter = TypeAdapter(SubscriptionPlan)
# Step 1: Batch fetch from Redis cache using mget
redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
try:
cached_values = redis_client.mget(redis_keys)
if len(cached_values) != len(tenant_ids):
raise Exception(
"get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
)
# Map cached values back to tenant_ids
cache_misses: list[str] = []
for tenant_id, cached_value in zip(tenant_ids, cached_values):
if cached_value:
try:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
plan_dict = json.loads(json_str)
subscription_plan = subscription_adapter.validate_python(plan_dict)
tenant_plans[tenant_id] = subscription_plan
except Exception:
logger.exception(
"get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
)
cache_misses.append(tenant_id)
else:
cache_misses.append(tenant_id)
logger.info(
"get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
len(tenant_plans),
len(cache_misses),
)
except Exception:
logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
cache_misses = list(tenant_ids)
# Step 2: Fetch missing plans from billing API
if cache_misses:
bulk_plans = BillingService.get_plan_bulk(cache_misses)
if bulk_plans:
plans_to_cache: dict[str, SubscriptionPlan] = {}
for tenant_id, subscription_plan in bulk_plans.items():
tenant_plans[tenant_id] = subscription_plan
plans_to_cache[tenant_id] = subscription_plan
# Step 3: Batch update Redis cache using pipeline
if plans_to_cache:
try:
pipe = redis_client.pipeline()
for tenant_id, subscription_plan in plans_to_cache.items():
redis_key = cls._make_plan_cache_key(tenant_id)
# Serialize dict to JSON string
json_str = json.dumps(subscription_plan)
pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
pipe.execute()
logger.info(
"get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
len(plans_to_cache),
)
except Exception:
logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
return tenant_plans
@classmethod
def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
data = resp.get("data", [])
tenant_whitelist = []
for item in data:
tenant_whitelist.append(item["tenant_id"])
return tenant_whitelist
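A hypothetical caller in a batch job would use the cached variant, assuming BILLING_API_URL points at a reachable billing service:

plans = BillingService.get_plan_bulk_with_cache(["t1", "t2", "t3"])
for tenant_id in ["t1", "t2", "t3"]:
    plan = plans.get(tenant_id)
    if plan is None:
        continue  # unknown tenant, or both cache and API lookups failed
    print(tenant_id, plan["plan"], plan["expiration_date"])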

View File

@ -6,7 +6,9 @@ from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@ -202,6 +204,7 @@ class ConversationService:
user: Union[Account, EndUser] | None,
limit: int,
last_id: str | None,
variable_name: str | None = None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
@ -212,7 +215,25 @@ class ConversationService:
.order_by(ConversationVariable.created_at)
)
with Session(db.engine) as session:
# Apply variable_name filter if provided
if variable_name:
# Escape LIKE wildcards in the user-supplied variable name
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(
func.json_extract(ConversationVariable.data, "$.name").ilike(
f"%{escaped_variable_name}%", escape="\\"
)
)
elif dify_config.DB_TYPE == "postgresql":
stmt = stmt.where(
func.json_extract_path_text(ConversationVariable.data, "name").ilike(
f"%{escaped_variable_name}%", escape="\\"
)
)
with session_factory.create_session() as session:
if last_id:
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
if not last_variable:
@ -279,7 +300,7 @@ class ConversationService:
.where(ConversationVariable.id == variable_id)
)
with Session(db.engine) as session:
with session_factory.create_session() as session:
existing_variable = session.scalar(stmt)
if not existing_variable:
raise ConversationVariableNotExistsError()
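The replace chain above escapes LIKE wildcards before the term is interpolated into the pattern; the same idea in isolation (helper name is illustrative):

def escape_like(term: str) -> str:
    # escape the escape character first, then the SQL LIKE wildcards
    return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

assert escape_like("100%_done") == "100\\%\\_done"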

View File

@ -3458,7 +3458,7 @@ class SegmentService:
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc())
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total

View File

@ -110,5 +110,5 @@ class EnterpriseService:
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
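Switching from a JSON body to query parameters sidesteps servers and proxies that ignore or reject DELETE request bodies; the resulting request is effectively DELETE /webapp/clean?appId=<app_id>.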

View File

@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
yaml_content: str | None = None

View File

@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
pass
class InvokeRateLimitError(Exception):
"""Raised when rate limit is exceeded for workflow invocations."""
class WorkflowQuotaLimitError(Exception):
"""Raised when workflow execution quota is exceeded (for async/background workflows)."""
pass
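With this split, quota exhaustion and per-invocation rate limiting become distinct failure modes with distinct import paths. A sketch of a caller inside the codebase (the entry point is a hypothetical stand-in):

from services.errors.app import WorkflowQuotaLimitError
from services.errors.llm import InvokeRateLimitError

def trigger_async_workflow():
    # hypothetical entry point standing in for the async workflow service
    raise WorkflowQuotaLimitError("workflow execution quota reached")

try:
    trigger_async_workflow()
except WorkflowQuotaLimitError:
    print("tenant exhausted its background-workflow quota")
except InvokeRateLimitError:
    print("per-invocation rate limit hit")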

View File

@ -105,3 +105,49 @@ class PluginParameterService:
)
.options
)
@staticmethod
def get_dynamic_select_options_with_credentials(
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
action: str,
parameter: str,
credential_id: str,
credentials: Mapping[str, Any],
) -> Sequence[PluginParameterOption]:
"""
Get dynamic select options using provided credentials directly.
Used for edit mode when credentials have been modified but not yet saved.
Security: credential_id is validated against tenant_id to ensure
users can only access their own credentials.
"""
from constants import HIDDEN_VALUE
# Get original subscription to replace hidden values (with tenant_id check for security)
original_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
if not original_subscription:
raise ValueError(f"Subscription {credential_id} not found")
# Replace [__HIDDEN__] with original values
resolved_credentials: dict[str, Any] = {
key: (original_subscription.credentials.get(key) if value == HIDDEN_VALUE else value)
for key, value in credentials.items()
}
return (
DynamicSelectClient()
.fetch_dynamic_select_options(
tenant_id,
user_id,
plugin_id,
provider,
action,
resolved_credentials,
original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
parameter,
)
.options
)
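The dict comprehension above is the usual masked-secret round-trip: the UI sends [__HIDDEN__] for fields it left untouched, and the stored value is restored before use. In isolation (names illustrative):

HIDDEN_VALUE = "[__HIDDEN__]"

def resolve_credentials(submitted: dict, original: dict) -> dict:
    # restore stored values for fields the UI left masked
    return {k: (original.get(k) if v == HIDDEN_VALUE else v) for k, v in submitted.items()}

assert resolve_credentials(
    {"api_key": HIDDEN_VALUE, "region": "eu"},
    {"api_key": "sk-123", "region": "us"},
) == {"api_key": "sk-123", "region": "eu"}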

View File

@ -7,7 +7,6 @@ from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
@ -178,9 +177,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -322,9 +318,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -347,9 +340,6 @@ class ApiToolManageService:
db.session.delete(provider)
db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod

View File

@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -205,9 +204,6 @@ class BuiltinToolManageService:
db_provider.name = name
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
@ -286,12 +282,10 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
@ -409,9 +403,6 @@ class BuiltinToolManageService:
)
cache.delete()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -434,8 +425,6 @@ class BuiltinToolManageService:
target_provider.is_default = True
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod

View File

@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
class ProviderUrlValidationData(BaseModel):
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
current_server_url_hash: str
headers: dict[str, str]
timeout: float | None
sse_read_timeout: float | None
class MCPToolManageService:
"""Service class for managing MCP tools and providers."""
@ -166,9 +174,6 @@ class MCPToolManageService:
self._session.add(mcp_tool)
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
@ -192,7 +197,7 @@ class MCPToolManageService:
Update an MCP provider.
Args:
validation_result: Pre-validation result from validate_server_url_change.
validation_result: Pre-validation result from validate_server_url_standalone.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
@ -251,8 +256,6 @@ class MCPToolManageService:
# Flush changes to database
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier)
@ -261,9 +264,6 @@ class MCPToolManageService:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:
@ -319,8 +319,14 @@ class MCPToolManageService:
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
# Update database with retrieved tools (ensure description is a non-null string)
tools_payload = []
for tool in tools:
data = tool.model_dump()
if data.get("description") is None:
data["description"] = ""
tools_payload.append(data)
db_provider.tools = json.dumps(tools_payload)
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
@ -546,30 +552,39 @@ class MCPToolManageService:
)
return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
"""Attempt to reconnect to MCP provider with new server URL."""
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
"""
Get provider data required for URL validation.
This method performs a database read and should be called within a session.
Returns:
ProviderUrlValidationData: Data needed for standalone URL validation
"""
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = provider.to_entity()
headers = provider_entity.headers
return ProviderUrlValidationData(
current_server_url_hash=provider.server_url_hash,
headers=provider_entity.headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
)
try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def validate_server_url_change(
self, *, tenant_id: str, provider_id: str, new_server_url: str
@staticmethod
def validate_server_url_standalone(
*,
tenant_id: str,
new_server_url: str,
validation_data: ProviderUrlValidationData,
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations
outside of the database transaction.
This method performs network operations and MUST be called OUTSIDE of any database session
to avoid holding locks during network I/O.
Args:
tenant_id: Tenant ID for encryption
new_server_url: The new server URL to validate
validation_data: Provider data obtained from get_provider_for_url_validation
Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
@ -579,25 +594,30 @@ class MCPToolManageService:
return ServerUrlValidationResult(needs_validation=False)
# Validate URL format
if not self._is_valid_url(new_server_url):
parsed = urlparse(new_server_url)
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check if URL is actually different
if new_server_url_hash == provider.server_url_hash:
if new_server_url_hash == validation_data.current_server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
needs_validation=False,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
)
# Perform validation by attempting to connect
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
# Perform network validation - the expensive operation that must run outside the session
reconnect_result = MCPToolManageService._reconnect_with_url(
server_url=new_server_url,
headers=validation_data.headers,
timeout=validation_data.timeout,
sse_read_timeout=validation_data.sse_read_timeout,
)
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
@ -606,6 +626,60 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
@staticmethod
def reconnect_with_url(
*,
server_url: str,
headers: dict[str, str],
timeout: float | None,
sse_read_timeout: float | None,
) -> ReconnectResult:
return MCPToolManageService._reconnect_with_url(
server_url=server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
)
@staticmethod
def _reconnect_with_url(
*,
server_url: str,
headers: dict[str, str],
timeout: float | None,
sse_read_timeout: float | None,
) -> ReconnectResult:
"""
Attempt to connect to MCP server with given URL.
This is a static method that performs network I/O without database access.
"""
from core.mcp.mcp_client import MCPClient
try:
with MCPClient(
server_url=server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
# Ensure tool descriptions are non-null in payload
tools_payload = []
for t in tools:
d = t.model_dump()
if d.get("description") is None:
d["description"] = ""
tools_payload.append(d)
return ReconnectResult(
authed=True,
tools=json.dumps(tools_payload),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
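Taken together, the refactor splits the old validate_server_url_change into a session-bound read, a session-free network step, and a session-bound write. A sketch of how the pieces are meant to compose per the docstrings above (session and service wiring is illustrative; method names are from this diff):

# 1) read provider data inside a short-lived session
data = service.get_provider_for_url_validation(tenant_id=tenant_id, provider_id=provider_id)

# 2) network I/O with no database session held
result = MCPToolManageService.validate_server_url_standalone(
    tenant_id=tenant_id,
    new_server_url=new_server_url,
    validation_data=data,
)

# 3) persist the outcome inside a fresh session by passing
#    validation_result=result to update_provider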

View File

@ -1,6 +1,5 @@
import logging
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
@ -16,14 +15,6 @@ class ToolCommonService:
:return: the list of tool providers
"""
# Try to get from cache first
cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
if cached_result is not None:
logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
return cached_result
# Cache miss - fetch from database
logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
# add icon
@ -32,7 +23,4 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers]
# Cache the result
ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
return result

View File

@ -7,7 +7,6 @@ from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@ -68,34 +67,31 @@ class WorkflowToolManageService:
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
session.add(workflow_tool_provider)
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
with Session(db.engine, expire_on_commit=False) as session, session.begin():
session.add(workflow_tool_provider)
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod
@ -183,9 +179,6 @@ class WorkflowToolManageService:
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod
@ -248,9 +241,6 @@ class WorkflowToolManageService:
db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod

View File

@ -94,16 +94,23 @@ class TriggerProviderService:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for subscription in subscriptions:
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(
encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
credential_encrypter.mask_credentials(dict(credential_encrypter.decrypt(subscription.credentials)))
)
subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.properties = dict(
properties_encrypter.mask_credentials(dict(properties_encrypter.decrypt(subscription.properties)))
)
subscription.parameters = dict(subscription.parameters)
count = workflows_in_use_map.get(subscription.id)
subscription.workflows_in_use = count if count is not None else 0
@ -209,6 +216,101 @@ class TriggerProviderService:
logger.exception("Failed to add trigger provider")
raise ValueError(str(e))
@classmethod
def update_trigger_subscription(
cls,
tenant_id: str,
subscription_id: str,
name: str | None = None,
properties: Mapping[str, Any] | None = None,
parameters: Mapping[str, Any] | None = None,
credentials: Mapping[str, Any] | None = None,
credential_expires_at: int | None = None,
expires_at: int | None = None,
) -> None:
"""
Update an existing trigger subscription.
:param tenant_id: Tenant ID
:param subscription_id: Subscription instance ID
:param name: Optional new name for this subscription
:param properties: Optional new properties
:param parameters: Optional new parameters
:param credentials: Optional new credentials
:param credential_expires_at: Optional new credential expiration timestamp
:param expires_at: Optional new expiration timestamp
:return: None
"""
with Session(db.engine, expire_on_commit=False) as session:
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Trigger subscription {subscription_id} not found")
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Check for name uniqueness if name is being updated
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update properties if provided
if properties is not None:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
# Handle hidden values - preserve original encrypted values
original_properties = properties_encrypter.decrypt(subscription.properties)
new_properties: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_properties.get(key, UNKNOWN_VALUE)
for key, value in properties.items()
}
subscription.properties = dict(properties_encrypter.encrypt(new_properties))
# Update parameters if provided
if parameters is not None:
subscription.parameters = dict(parameters)
# Update credentials if provided
if credentials is not None:
credential_type = CredentialType.of(subscription.credential_type)
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
subscription.credentials = dict(credential_encrypter.encrypt(dict(credentials)))
# Update credential expiration timestamp if provided
if credential_expires_at is not None:
subscription.credential_expires_at = credential_expires_at
# Update expiration timestamp if provided
if expires_at is not None:
subscription.expires_at = expires_at
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
@classmethod
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
"""
@ -257,17 +359,18 @@ class TriggerProviderService:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
if is_auto_created:
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
@ -280,8 +383,8 @@ class TriggerProviderService:
except Exception as e:
logger.exception("Error unsubscribing trigger", exc_info=e)
# Clear cache
session.delete(subscription)
# Clear cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
@ -688,3 +791,188 @@ class TriggerProviderService:
)
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
return subscription
@classmethod
def verify_subscription_credentials(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription_id: str,
credentials: Mapping[str, Any],
) -> dict[str, Any]:
"""
Verify credentials for an existing subscription without updating it.
This is used in edit mode to validate new credentials before rebuild.
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider identifier
:param subscription_id: Subscription ID
:param credentials: New credentials to verify
:return: dict with 'verified' boolean
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
subscription = cls.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
# For API Key, validate the new credentials
if credential_type == CredentialType.API_KEY:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
try:
provider_controller.validate_credentials(user_id, credentials=new_credentials)
return {"verified": True}
except Exception as e:
raise ValueError(f"Invalid credentials: {e}") from e
return {"verified": True}
@classmethod
def rebuild_trigger_subscription(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
subscription_id: str,
credentials: Mapping[str, Any],
parameters: Mapping[str, Any],
name: str | None = None,
) -> None:
"""
Rebuild an existing trigger subscription.
The subscription is re-created from the rebuild request while keeping the same
subscription_id and endpoint_id, so the webhook URL remains unchanged.
:param tenant_id: Tenant ID
:param name: Name for the subscription
:param subscription_id: Subscription ID
:param provider_id: Provider identifier
:param credentials: Credentials for the subscription
:param parameters: Parameters for the subscription
:return: None
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
with Session(db.engine, expire_on_commit=False) as session:
try:
# Get subscription within the transaction
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
# Decrypt existing credentials for merging
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
merged_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
user_id = subscription.user_id
# TODO: try invoking the update API of the plugin trigger provider first.
# FALLBACK: if the update API is not implemented,
# delete the previous subscription and create a new one.
# Unsubscribe the previous subscription (external call, but we'll handle errors)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=decrypted_credentials,
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
# Continue anyway - the subscription might already be deleted externally
# Create a new subscription with the same subscription_id and endpoint_id (external call)
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=merged_credentials,
credential_type=credential_type,
)
# Update the subscription in the same transaction
# Inline update logic to reuse the same session
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing and existing.id != subscription.id:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update parameters
subscription.parameters = dict(parameters)
# Update credentials with merged (and encrypted) values
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
# Update properties
if new_subscription.properties:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
# Update expiration timestamp
if new_subscription.expires_at is not None:
subscription.expires_at = new_subscription.expires_at
# Commit the transaction
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
except Exception as e:
# Rollback on any error
session.rollback()
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
raise

View File

@ -453,11 +453,12 @@ class TriggerSubscriptionBuilderService:
if not subscription_builder:
return None
# response to validation endpoint
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
)
try:
# response to validation endpoint
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id,
provider_id=TriggerProviderID(subscription_builder.provider_id),
)
dispatch_response: TriggerDispatchResponse = controller.dispatch(
request=request,
subscription=subscription_builder.to_subscription(),

View File

@ -863,10 +863,18 @@ class WebhookService:
not_found_in_cache.append(node_id)
continue
with Session(db.engine) as session:
try:
# lock the concurrent webhook trigger creation
redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
lock = redis_client.lock(lock_key, timeout=10)
lock_acquired = False
try:
# acquire the lock with blocking and timeout
lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
if not lock_acquired:
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
with Session(db.engine) as session:
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@ -903,11 +911,16 @@ class WebhookService:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
# release the lock only if it was acquired
if lock_acquired:
try:
lock.release()
except Exception:
logger.exception("Failed to release lock for webhook sync, app %s", app.id)
@classmethod
def generate_webhook_id(cls) -> str:
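The acquire/release discipline in the webhook fix is the standard redis-py lock pattern; a minimal standalone sketch, assuming a configured Redis connection:

from redis import Redis

redis_client = Redis()  # assumed connection settings
lock = redis_client.lock("webhook:sync:lock", timeout=10)  # lock value auto-expires after 10s
if not lock.acquire(blocking=True, blocking_timeout=10):
    raise RuntimeError("could not acquire lock")
try:
    pass  # critical section: read and modify webhook trigger rows
finally:
    lock.release()  # raises LockNotOwnedError if the lock already expired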