mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
Merge branch 'main' into feat/memory-orchestration-be
# Conflicts: # api/core/app/apps/advanced_chat/app_runner.py # api/core/prompt/entities/advanced_prompt_entities.py # api/core/variables/segments.py
This commit is contained in:
@ -2,10 +2,10 @@ import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
@ -37,23 +37,15 @@ from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowType,
|
||||
)
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
from .chatflow_memory_service import ChatflowMemoryService
|
||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
from .workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
@ -89,20 +81,21 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
def is_workflow_exist(self, app_model: App) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
stmt = select(
|
||||
exists().where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.count()
|
||||
) > 0
|
||||
)
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
if workflow_id:
|
||||
return self.get_published_workflow_by_id(app_model, workflow_id)
|
||||
# fetch draft workflow by app_model
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
@ -117,8 +110,10 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
# fetch published workflow by workflow_id
|
||||
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
fetch published workflow by workflow_id
|
||||
"""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
@ -137,7 +132,7 @@ class WorkflowService:
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
def get_published_workflow(self, app_model: App) -> Workflow | None:
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
@ -202,7 +197,7 @@ class WorkflowService:
|
||||
app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
unique_hash: Optional[str],
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
@ -270,6 +265,12 @@ class WorkflowService:
|
||||
if not draft_workflow:
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# Validate credentials before publishing, for credential policy check
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
if FeatureService.get_system_features().plugin_manager.enabled:
|
||||
self._validate_workflow_credentials(draft_workflow)
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app_model.tenant_id,
|
||||
@ -294,6 +295,260 @@ class WorkflowService:
|
||||
# return new workflow
|
||||
return workflow
|
||||
|
||||
def _validate_workflow_credentials(self, workflow: Workflow) -> None:
|
||||
"""
|
||||
Validate all credentials in workflow nodes before publishing.
|
||||
|
||||
:param workflow: The workflow to validate
|
||||
:raises ValueError: If any credentials violate policy compliance
|
||||
"""
|
||||
graph_dict = workflow.graph_dict
|
||||
nodes = graph_dict.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
node_type = node_data.get("type")
|
||||
node_id = node.get("id", "unknown")
|
||||
|
||||
try:
|
||||
# Extract and validate credentials based on node type
|
||||
if node_type == "tool":
|
||||
credential_id = node_data.get("credential_id")
|
||||
provider = node_data.get("provider_id")
|
||||
if provider:
|
||||
if credential_id:
|
||||
# Check specific credential
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=credential_id,
|
||||
provider=provider,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
)
|
||||
else:
|
||||
# Check default workspace credential for this provider
|
||||
self._check_default_tool_credential(workflow.tenant_id, provider)
|
||||
|
||||
elif node_type == "agent":
|
||||
agent_params = node_data.get("agent_parameters", {})
|
||||
|
||||
model_config = agent_params.get("model", {}).get("value", {})
|
||||
if model_config.get("provider") and model_config.get("model"):
|
||||
self._validate_llm_model_config(
|
||||
workflow.tenant_id, model_config["provider"], model_config["model"]
|
||||
)
|
||||
|
||||
# Validate load balancing credentials for agent model if load balancing is enabled
|
||||
agent_model_node_data = {"model": model_config}
|
||||
self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)
|
||||
|
||||
# Validate agent tools
|
||||
tools = agent_params.get("tools", {}).get("value", [])
|
||||
for tool in tools:
|
||||
# Agent tools store provider in provider_name field
|
||||
provider = tool.get("provider_name")
|
||||
credential_id = tool.get("credential_id")
|
||||
if provider:
|
||||
if credential_id:
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
|
||||
else:
|
||||
self._check_default_tool_credential(workflow.tenant_id, provider)
|
||||
|
||||
elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
|
||||
model_config = node_data.get("model", {})
|
||||
provider = model_config.get("provider")
|
||||
model_name = model_config.get("name")
|
||||
|
||||
if provider and model_name:
|
||||
# Validate that the provider+model combination can fetch valid credentials
|
||||
self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
|
||||
# Validate load balancing credentials if load balancing is enabled
|
||||
self._validate_load_balancing_credentials(workflow, node_data, node_id)
|
||||
else:
|
||||
raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise e
|
||||
else:
|
||||
raise ValueError(f"Node {node_id} ({node_type}): {str(e)}")
|
||||
|
||||
def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
|
||||
"""
|
||||
Validate that an LLM model configuration can fetch valid credentials.
|
||||
|
||||
This method attempts to get the model instance and validates that:
|
||||
1. The provider exists and is configured
|
||||
2. The model exists in the provider
|
||||
3. Credentials can be fetched for the model
|
||||
4. The credentials pass policy compliance checks
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider name
|
||||
:param model_name: The model name
|
||||
:raises ValueError: If the model configuration is invalid or credentials fail policy checks
|
||||
"""
|
||||
try:
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
# Get model instance to validate provider+model combination
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
|
||||
# The ModelInstance constructor will automatically check credential policy compliance
|
||||
# via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
|
||||
# If it fails, an exception will be raised
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
|
||||
)
|
||||
|
||||
def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None:
|
||||
"""
|
||||
Check credential policy compliance for the default workspace credential of a tool provider.
|
||||
|
||||
This method finds the default credential for the given provider and validates it.
|
||||
Uses the same fallback logic as runtime to handle deauthorized credentials.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The tool provider name
|
||||
:raises ValueError: If no default credential exists or if it fails policy compliance
|
||||
"""
|
||||
try:
|
||||
from models.tools import BuiltinToolProvider
|
||||
|
||||
# Use the same fallback logic as runtime: get the first available credential
|
||||
# ordered by is_default DESC, created_at ASC (same as tool_manager.py)
|
||||
default_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not default_provider:
|
||||
raise ValueError("No default credential found")
|
||||
|
||||
# Check credential policy compliance using the default credential ID
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=default_provider.id,
|
||||
provider=provider,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
|
||||
|
||||
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
|
||||
"""
|
||||
Validate load balancing credentials for a workflow node.
|
||||
|
||||
:param workflow: The workflow being validated
|
||||
:param node_data: The node data containing model configuration
|
||||
:param node_id: The node ID for error reporting
|
||||
:raises ValueError: If load balancing credentials violate policy compliance
|
||||
"""
|
||||
# Extract model configuration
|
||||
model_config = node_data.get("model", {})
|
||||
provider = model_config.get("provider")
|
||||
model_name = model_config.get("name")
|
||||
|
||||
if not provider or not model_name:
|
||||
return # No model config to validate
|
||||
|
||||
# Check if this model has load balancing enabled
|
||||
if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
|
||||
# Get all load balancing configurations for this model
|
||||
load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
|
||||
# Validate each load balancing configuration
|
||||
try:
|
||||
for config in load_balancing_configs:
|
||||
if config.get("credential_id"):
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
config["credential_id"], provider, PluginCredentialType.MODEL
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")
|
||||
|
||||
def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
|
||||
"""
|
||||
Check if load balancing is enabled for a specific model.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider name
|
||||
:param model_name: The model name
|
||||
:return: True if load balancing is enabled, False otherwise
|
||||
"""
|
||||
try:
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
# Get provider configurations
|
||||
provider_manager = ProviderManager()
|
||||
provider_configurations = provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
|
||||
if not provider_configuration:
|
||||
return False
|
||||
|
||||
# Get provider model setting
|
||||
provider_model_setting = provider_configuration.get_provider_model_setting(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
return provider_model_setting is not None and provider_model_setting.load_balancing_enabled
|
||||
|
||||
except Exception:
|
||||
# If we can't determine the status, assume load balancing is not enabled
|
||||
return False
|
||||
|
||||
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
|
||||
"""
|
||||
Get all load balancing configurations for a model.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider name
|
||||
:param model_name: The model name
|
||||
:return: List of load balancing configuration dictionaries
|
||||
"""
|
||||
try:
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
_, configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
model_type="llm", # Load balancing is primarily used for LLM models
|
||||
config_from="predefined-model", # Check both predefined and custom models
|
||||
)
|
||||
|
||||
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
|
||||
)
|
||||
all_configs = configs + custom_configs
|
||||
|
||||
return [config for config in all_configs if config.get("credential_id")]
|
||||
|
||||
except Exception:
|
||||
# If we can't get the configurations, return empty list
|
||||
# This will prevent validation errors from breaking the workflow
|
||||
return []
|
||||
|
||||
def get_default_block_configs(self) -> list[dict]:
|
||||
"""
|
||||
Get default block configs
|
||||
@ -308,7 +563,7 @@ class WorkflowService:
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
@ -509,10 +764,10 @@ class WorkflowService:
|
||||
)
|
||||
error = node_run_result.error if not run_succeeded else None
|
||||
except WorkflowNodeRunFailedError as e:
|
||||
node = e._node
|
||||
node = e.node
|
||||
run_succeeded = False
|
||||
node_run_result = None
|
||||
error = e._error
|
||||
error = e.error
|
||||
|
||||
# Create a NodeExecution domain model
|
||||
node_execution = WorkflowNodeExecution(
|
||||
@ -568,7 +823,7 @@ class WorkflowService:
|
||||
# chatbot convert to workflow mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
|
||||
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
|
||||
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
|
||||
# convert to workflow
|
||||
@ -583,12 +838,12 @@ class WorkflowService:
|
||||
|
||||
return new_app
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
return WorkflowAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
@ -597,7 +852,7 @@ class WorkflowService:
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
) -> Optional[Workflow]:
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Update workflow attributes
|
||||
|
||||
|
||||
Reference in New Issue
Block a user