Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -2,7 +2,7 @@ import logging
from collections.abc import Mapping
from enum import StrEnum
from threading import Lock
from typing import Any, Optional
from typing import Any
from httpx import Timeout, post
from pydantic import BaseModel
@ -24,8 +24,8 @@ class CodeExecutionError(Exception):
class CodeExecutionResponse(BaseModel):
class Data(BaseModel):
stdout: Optional[str] = None
error: Optional[str] = None
stdout: str | None = None
error: str | None = None
code: int
message: str

View File

@ -1,9 +1,33 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TypedDict
from pydantic import BaseModel
class CodeNodeProvider(BaseModel):
class VariableConfig(TypedDict):
variable: str
value_selector: Sequence[str | int]
class OutputConfig(TypedDict):
type: str
children: None
class CodeConfig(TypedDict):
variables: Sequence[VariableConfig]
code_language: str
code: str
outputs: Mapping[str, OutputConfig]
class DefaultConfig(TypedDict):
type: str
config: CodeConfig
class CodeNodeProvider(BaseModel, ABC):
@staticmethod
@abstractmethod
def get_language() -> str:
@ -22,11 +46,14 @@ class CodeNodeProvider(BaseModel):
pass
@classmethod
def get_default_config(cls) -> dict:
def get_default_config(cls) -> DefaultConfig:
return {
"type": "code",
"config": {
"variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
"variables": [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
],
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": {"result": {"type": "string", "children": None}},

View File

@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod
def transform_response(cls, response: str) -> dict:
def transform_response(cls, response: str):
"""
Transform response to dict
:param response: response

View File

@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
def get_default_code(cls) -> str:
return dedent(
"""
def main(arg1: str, arg2: str) -> dict:
def main(arg1: str, arg2: str):
return {
"result": arg1 + arg2,
}

View File

@ -0,0 +1,75 @@
"""
Credential utility functions for checking credential existence and policy compliance.
"""
from services.enterprise.plugin_manager_service import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
    """
    Check whether the given credential is still present in the database.

    :param credential_id: The credential ID to look up
    :param credential_type: The type of credential (MODEL or TOOL)
    :return: True if a matching row exists, False otherwise
    """
    # Imported lazily so the database layer is not required at module import time.
    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.provider import ProviderCredential, ProviderModelCredential
    from models.tools import BuiltinToolProvider

    with Session(db.engine) as session:
        if credential_type == PluginCredentialType.MODEL:
            # A model credential may live in either the pre-defined or the
            # custom-model table; probe both with a single UNION query.
            predefined = select(ProviderCredential.id).where(ProviderCredential.id == credential_id)
            custom = select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id)
            return session.scalar(predefined.union(custom)) is not None
        if credential_type == PluginCredentialType.TOOL:
            tool_stmt = select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id)
            return session.scalar(tool_stmt) is not None
    # Unknown credential types are treated as non-existent.
    return False
def check_credential_policy_compliance(
    credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
) -> None:
    """
    Check credential policy compliance for the given credential ID.

    :param credential_id: The credential ID to check
    :param provider: The provider name
    :param credential_type: The type of credential (MODEL or TOOL)
    :param check_existence: Whether to verify the credential exists in the database first
    :raises ValueError: If the credential is missing or the policy check fails
    """
    # Imported lazily to avoid circular imports with the service layer.
    from services.enterprise.plugin_manager_service import (
        CheckCredentialPolicyComplianceRequest,
        PluginManagerService,
    )
    from services.feature_service import FeatureService

    # Nothing to enforce when the plugin manager feature is off or no id was given.
    plugin_manager_enabled = FeatureService.get_system_features().plugin_manager.enabled
    if not plugin_manager_enabled or not credential_id:
        return

    # Optionally confirm the credential row still exists before asking the service.
    if check_existence and not is_credential_exists(credential_id, credential_type):
        raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.")

    # Delegate the actual policy decision to the plugin manager service.
    PluginManagerService.check_credential_policy_compliance(
        CheckCredentialPolicyComplianceRequest(
            dify_credential_id=credential_id,
            provider=provider,
            credential_type=credential_type,
        )
    )

View File

@ -11,9 +11,13 @@ def obfuscated_token(token: str) -> str:
return token[:6] + "*" * 12 + token[-2:]
def full_mask_token(token_length=20):
    """Return a fully-masked placeholder token of *token_length* '*' characters."""
    return "".join("*" for _ in range(token_length))
def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")

View File

@ -42,7 +42,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
except Exception:
pass
return result

View File

@ -1,12 +1,11 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
class ProviderCredentialsCacheType(StrEnum):
PROVIDER = "provider"
MODEL = "provider_model"
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
@ -14,9 +13,9 @@ class ProviderCredentialsCacheType(Enum):
class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
def get(self) -> dict | None:
"""
Get cached model provider credentials.
@ -34,7 +33,7 @@ class ProviderCredentialsCache:
else:
return None
def set(self, credentials: dict) -> None:
def set(self, credentials: dict):
"""
Cache model provider credentials.
@ -43,7 +42,7 @@ class ProviderCredentialsCache:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
def delete(self):
"""
Delete cached model provider credentials.

View File

@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
def load_single_subclass_from_source(
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
*, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
) -> type:
"""
Load a single subclass from the source

View File

@ -0,0 +1,42 @@
import logging
import re
from collections.abc import Sequence
from typing import Any
from core.tools.entities.tool_entities import CredentialType
logger = logging.getLogger(__name__)
def generate_provider_name(
    providers: Sequence[Any], credential_type: CredentialType, fallback_context: str = "provider"
) -> str:
    """
    Generate the next available display name for a credential provider.

    Uses the credential type's display name as the base pattern and picks the
    next free "<name> N" suffix from the existing providers' names.

    :param providers: existing providers; their ``name`` attributes are scanned
    :param credential_type: credential type supplying the base display name
    :param fallback_context: label used in the warning log if generation fails
    :return: the next incremental name, e.g. "API key 2"
    """
    try:
        # get_name() already returns a string; no f-string wrapper needed.
        return generate_incremental_name(
            [provider.name for provider in providers],
            credential_type.get_name(),
        )
    except Exception as e:
        # Best-effort: name generation must never block the caller, so log
        # and fall back to the first slot.
        logger.warning("Error generating next provider name for %r: %r", fallback_context, e)
        return f"{credential_type.get_name()} 1"
def generate_incremental_name(
    names: Sequence[str],
    default_pattern: str,
) -> str:
    """
    Produce the next name in the sequence "<pattern> 1", "<pattern> 2", ...

    Scans *names* for entries of the form "<default_pattern> <number>"
    (surrounding whitespace ignored, the pattern matched literally) and
    returns the pattern suffixed with the highest matched number plus one,
    or "<default_pattern> 1" when nothing matches.

    :param names: existing names to scan; falsy entries are skipped
    :param default_pattern: the base name to match and to build the result from
    :return: the next available incremental name
    """
    # Compile once; the pattern text is escaped so it matches literally.
    matcher = re.compile(rf"^{re.escape(default_pattern)}\s+(\d+)$")
    numbers = [int(match.group(1)) for name in names if name and (match := matcher.match(name.strip()))]
    # max(..., default=0) + 1 yields 1 when no existing name matched.
    return f"{default_pattern} {max(numbers, default=0) + 1}"

View File

@ -1,12 +1,14 @@
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from functools import lru_cache
from typing import TypeVar
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
from core.tools.utils.yaml_utils import load_yaml_file_cached
@lru_cache(maxsize=128)
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping from name to index from a YAML file
@ -14,12 +16,17 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
# FIXME(-LAN-): Cache position maps to prevent file descriptor exhaustion during high-load benchmarks
position_file_path = os.path.join(folder_path, file_name)
yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
try:
yaml_content = load_yaml_file_cached(file_path=position_file_path)
except Exception:
yaml_content = []
positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
return {name: index for index, name in enumerate(positions)}
@lru_cache(maxsize=128)
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for tools from name to index from a YAML file.
@ -35,20 +42,6 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -
)
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for providers from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
)
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
"""
Pin the items in the pin list to the beginning of the position map.
@ -72,11 +65,14 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map
T = TypeVar("T")
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
data: T,
name_func: Callable[[T], str],
) -> bool:
"""
Check if the object should be filtered out.
@ -103,9 +99,9 @@ def is_filtered(
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@ -122,9 +118,9 @@ def sort_by_position_map(
def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.
@ -134,4 +130,4 @@ def sort_to_dict_by_position_map(
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])
return OrderedDict((name_func(item), item) for item in sorted_items)

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, Optional
from typing import Any
from extensions.ext_redis import redis_client
@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> Optional[dict]:
def get(self) -> dict | None:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC):
return None
return None
def set(self, config: dict[str, Any]) -> None:
def set(self, config: dict[str, Any]):
"""Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config))
def delete(self) -> None:
def delete(self):
"""Delete cached provider credentials"""
redis_client.delete(self.cache_key)
@ -71,14 +71,14 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> Optional[dict]:
def get(self) -> dict | None:
"""Get cached provider credentials"""
return None
def set(self, config: dict[str, Any]) -> None:
def set(self, config: dict[str, Any]):
"""Cache provider credentials"""
pass
def delete(self) -> None:
def delete(self):
"""Delete cached provider credentials"""
pass

View File

@ -13,18 +13,18 @@ logger = logging.getLogger(__name__)
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(config_value).lower()
if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True
http_request_node_ssl_verify = True
elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False
http_request_node_ssl_verify = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True
http_request_node_ssl_verify = True
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
kwargs["ssl_verify"] = http_request_node_ssl_verify
ssl_verify = kwargs.pop("ssl_verify")

View File

@ -1,12 +1,11 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
class ToolParameterCacheType(StrEnum):
PARAMETER = "tool_parameter"
@ -15,11 +14,11 @@ class ToolParameterCache:
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
):
self.cache_key = (
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f":identity_id:{identity_id}"
)
def get(self) -> Optional[dict]:
def get(self) -> dict | None:
"""
Get cached model provider credentials.
@ -37,11 +36,11 @@ class ToolParameterCache:
else:
return None
def set(self, parameters: dict) -> None:
def set(self, parameters: dict):
"""Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None:
def delete(self):
"""
Delete cached model provider credentials.

View File

@ -1,7 +1,7 @@
import contextlib
import re
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
def is_valid_trace_id(trace_id: str) -> bool:
@ -13,7 +13,7 @@ def is_valid_trace_id(trace_id: str) -> bool:
return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id))
def get_external_trace_id(request: Any) -> Optional[str]:
def get_external_trace_id(request: Any) -> str | None:
"""
Retrieve the trace_id from the request.
@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]:
return None
def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
def extract_external_trace_id_from_args(args: Mapping[str, Any]):
"""
Extract 'external_trace_id' from args.
@ -61,7 +61,7 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
return {}
def get_trace_id_from_otel_context() -> Optional[str]:
def get_trace_id_from_otel_context() -> str | None:
"""
Retrieve the current trace ID from the active OpenTelemetry trace context.
Returns None if:
@ -88,7 +88,7 @@ def get_trace_id_from_otel_context() -> Optional[str]:
return None
def parse_traceparent_header(traceparent: str) -> Optional[str]:
def parse_traceparent_header(traceparent: str) -> str | None:
"""
Parse the `traceparent` header to extract the trace_id.