Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

This commit is contained in:
-LAN-
2025-09-09 17:11:46 +08:00
43 changed files with 945 additions and 310 deletions

View File

@ -42,6 +42,7 @@ from models.provider import (
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)
@ -129,14 +130,38 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
else:
credentials = None
current_credential_id = None
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials
current_credential_id = model_configuration.current_credential_id
break
if not credentials and self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
current_credential_id = self.custom_configuration.provider.current_credential_id
if current_credential_id:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=current_credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
else:
# no current credential id, check all available credentials
if self.custom_configuration.provider:
for credential_configuration in self.custom_configuration.provider.available_credentials:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=credential_configuration.credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
return credentials
@ -266,7 +291,6 @@ class ProviderConfiguration(BaseModel):
:param credential_id: if provided, return the specified credential
:return:
"""
if credential_id:
return self._get_specific_provider_credential(credential_id)
@ -739,6 +763,7 @@ class ProviderConfiguration(BaseModel):
current_credential_id = credential_record.id
current_credential_name = credential_record.credential_name
credentials = self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
@ -793,6 +818,7 @@ class ProviderConfiguration(BaseModel):
):
current_credential_id = model_configuration.current_credential_id
current_credential_name = model_configuration.current_credential_name
credentials = self.obfuscated_credentials(
credentials=model_configuration.credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas

View File

@ -145,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
name: str
credentials: dict
credential_source_type: str | None = None
credential_id: str | None = None
class ModelSettings(BaseModel):

View File

@ -0,0 +1,75 @@
"""
Credential utility functions for checking credential existence and policy compliance.
"""
from services.enterprise.plugin_manager_service import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
"""
Check if the credential still exists in the database.
:param credential_id: The credential ID to check
:param credential_type: The type of credential (MODEL or TOOL)
:return: True if credential exists, False otherwise
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.provider import ProviderCredential, ProviderModelCredential
from models.tools import BuiltinToolProvider
with Session(db.engine) as session:
if credential_type == PluginCredentialType.MODEL:
# Check both pre-defined and custom model credentials using a single UNION query
stmt = (
select(ProviderCredential.id)
.where(ProviderCredential.id == credential_id)
.union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id))
)
return session.scalar(stmt) is not None
if credential_type == PluginCredentialType.TOOL:
return (
session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id))
is not None
)
return False
def check_credential_policy_compliance(
credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
) -> None:
"""
Check credential policy compliance for the given credential ID.
:param credential_id: The credential ID to check
:param provider: The provider name
:param credential_type: The type of credential (MODEL or TOOL)
:param check_existence: Whether to check if credential exists in database first
:raises ValueError: If credential policy compliance check fails
"""
from services.enterprise.plugin_manager_service import (
CheckCredentialPolicyComplianceRequest,
PluginManagerService,
)
from services.feature_service import FeatureService
if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id:
return
# Check if credential exists in database first (if requested)
if check_existence:
if not is_credential_exists(credential_id, credential_type):
raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.")
# Check policy compliance
PluginManagerService.check_credential_policy_compliance(
CheckCredentialPolicyComplianceRequest(
dify_credential_id=credential_id,
provider=provider,
credential_type=credential_type,
)
)

View File

@ -23,6 +23,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)
@ -362,6 +363,23 @@ class ModelInstance:
else:
raise last_exception
# Additional policy compliance check as fallback (in case fetch_next didn't catch it)
try:
from core.helper.credential_utils import check_credential_policy_compliance
if lb_config.credential_id:
check_credential_policy_compliance(
credential_id=lb_config.credential_id,
provider=self.provider,
credential_type=PluginCredentialType.MODEL,
)
except Exception as e:
logger.warning(
"Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e)
)
self.load_balancing_manager.cooldown(lb_config, expire=60)
continue
try:
if "credentials" in kwargs:
del kwargs["credentials"]
@ -515,6 +533,24 @@ class LBModelManager:
continue
# Check policy compliance for the selected configuration
try:
from core.helper.credential_utils import check_credential_policy_compliance
if config.credential_id:
check_credential_policy_compliance(
credential_id=config.credential_id,
provider=self._provider,
credential_type=PluginCredentialType.MODEL,
)
except Exception as e:
logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e))
cooldown_load_balancing_configs.append(config)
if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
# all configs are in cooldown or failed policy compliance
return None
continue
if dify_config.DEBUG:
logger.info(
"""Model LB

View File

@ -1129,6 +1129,7 @@ class ProviderManager:
name=load_balancing_model_config.name,
credentials=provider_model_credentials,
credential_source_type=load_balancing_model_config.credential_source_type,
credential_id=load_balancing_model_config.credential_id,
)
)

View File

@ -1,8 +1,9 @@
import json
import logging
import uuid
from collections.abc import Callable
from functools import wraps
from typing import Any, Optional
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel):
return values
def ensure_client(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVector(BaseVector):
"""
Matrixone vector storage implementation.
@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector):
self.client.delete()
T = TypeVar("T", bound=MatrixoneVector)
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:

View File

@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
pass
class ToolCredentialPolicyViolationError(ValueError):
pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta

View File

@ -21,6 +21,7 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -44,16 +45,16 @@ from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tools_transform_service import ToolTransformService
@ -115,7 +116,6 @@ class ToolManager:
get the plugin provider
"""
# check if context is set
from core.plugin.impl.tool import PluginToolManager
try:
contexts.plugin_tool_providers.get()
@ -237,6 +237,16 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# check if the credential is allowed to be used
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=builtin_provider.id,
provider=provider_id,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
@ -509,7 +519,6 @@ class ToolManager:
"""
list all the plugin providers
"""
from core.plugin.impl.tool import PluginToolManager
manager = PluginToolManager()
provider_entities = manager.fetch_tool_providers(tenant_id)