diff --git a/api/.importlinter b/api/.importlinter index 5c0a6e1288..8dffc3506b 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -43,7 +43,6 @@ forbidden_modules = extensions.ext_redis allow_indirect_imports = True ignore_imports = - dify_graph.nodes.agent.agent_node -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis @@ -90,9 +89,6 @@ forbidden_modules = core.trigger core.variables ignore_imports = - dify_graph.nodes.agent.agent_node -> core.model_manager - dify_graph.nodes.agent.agent_node -> core.provider_manager - dify_graph.nodes.agent.agent_node -> core.tools.tool_manager dify_graph.nodes.llm.llm_utils -> core.model_manager dify_graph.nodes.llm.protocols -> core.model_manager dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model @@ -100,8 +96,6 @@ ignore_imports = dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler dify_graph.nodes.tool.tool_node -> core.tools.tool_engine dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.agent.agent_node -> core.agent.entities - dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform @@ -110,12 +104,10 @@ ignore_imports = dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.agent.agent_node -> models.model dify_graph.nodes.llm.node -> core.helper.code_executor dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output dify_graph.nodes.llm.node -> core.model_manager - dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util @@ -126,15 +118,11 @@ ignore_imports = dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods dify_graph.nodes.llm.node -> models.dataset - dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer dify_graph.nodes.llm.file_saver -> core.tools.signature dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager dify_graph.nodes.tool.tool_node -> core.tools.errors - dify_graph.nodes.agent.agent_node -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.agent.agent_node -> models dify_graph.nodes.llm.node -> models.model - dify_graph.nodes.agent.agent_node -> services dify_graph.nodes.tool.tool_node -> services dify_graph.model_runtime.model_providers.__base.ai_model -> configs dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index dd982b6d7b..2025048e09 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,5 +1,4 @@ import json -from enum import StrEnum from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field @@ -11,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm from extensions.ext_database import db from fields.app_fields import app_server_fields from libs.login import current_account_with_tenant, login_required +from models.enums import AppMCPServerStatus from models.model import AppMCPServer DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -19,11 +19,6 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" app_server_model = console_ns.model("AppServer", app_server_fields) -class AppMCPServerStatus(StrEnum): - ACTIVE = "active" - INACTIVE = "inactive" - - class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") parameters: dict = Field(..., description="Server parameters configuration") @@ -117,9 +112,10 @@ class AppMCPServerController(Resource): server.parameters = json.dumps(payload.parameters, ensure_ascii=False) if payload.status: - if payload.status not in [status.value for status in AppMCPServerStatus]: + try: + server.status = AppMCPServerStatus(payload.status) + except ValueError: raise ValueError("Invalid status") - server.status = payload.status db.session.commit() return server diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 708df62642..0d8960c9bd 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -43,6 +43,7 @@ from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode +from models.account import AccountStatus, InvitationCodeStatus from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -215,7 +216,7 @@ class AccountInitApi(Resource): db.session.query(InvitationCode) .where( InvitationCode.code == args.invitation_code, - InvitationCode.status == "unused", + InvitationCode.status == InvitationCodeStatus.UNUSED, ) .first() ) @@ -223,7 +224,7 @@ class AccountInitApi(Resource): if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = "used" + invitation_code.status = InvitationCodeStatus.USED invitation_code.used_at = naive_utc_now() invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id @@ -231,7 +232,7 @@ class AccountInitApi(Resource): account.interface_language = args.interface_language account.timezone = args.timezone account.interface_theme = "light" - account.status = "active" + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 2f06f72f29..ee537367c7 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -5,6 +5,7 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource from pydantic import BaseModel, Field +from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden from configs import dify_config @@ -169,6 +170,20 @@ register_enum_models( ) +def _read_upload_content(file: FileStorage, max_size: int) -> bytes: + """ + Read the uploaded file and validate its actual size before delegating to the plugin service. + + FileStorage.content_length is not reliable for multipart test uploads and may be zero even when + content exists, so the controllers validate against the loaded bytes instead. + """ + content = file.read() + if len(content) > max_size: + raise ValueError("File size exceeds the maximum allowed size") + + return content + + @console_ns.route("/workspaces/current/plugin/debugging-key") class PluginDebuggingKeyApi(Resource): @setup_required @@ -284,12 +299,7 @@ class PluginUploadFromPkgApi(Resource): _, tenant_id = current_account_with_tenant() file = request.files["pkg"] - - # check file size - if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE: - raise ValueError("File size exceeds the maximum allowed size") - - content = file.read() + content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE) try: response = PluginService.upload_pkg(tenant_id, content) except PluginDaemonClientSideError as e: @@ -328,12 +338,7 @@ class PluginUploadFromBundleApi(Resource): _, tenant_id = current_account_with_tenant() file = request.files["bundle"] - - # check file size - if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE: - raise ValueError("File size exceeds the maximum allowed size") - - content = file.read() + content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE) try: response = PluginService.upload_bundle(tenant_id, content) except PluginDaemonClientSideError as e: diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 2bc6640807..9ddaaa315b 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -6,13 +6,13 @@ from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session from controllers.common.schema import register_schema_model -from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from dify_graph.variables.input_entities import VariableEntity from extensions.ext_database import db from libs import helper +from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index c6ecd5509b..9271ed10bd 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -6,6 +6,7 @@ from typing import Any from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit +from core.agent.errors import AgentMaxIterationError from core.agent.output_parser.cot_output_parser import CotAgentOutputParser from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent @@ -22,7 +23,6 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.nodes.agent.exc import AgentMaxIterationError from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/errors.py b/api/core/agent/errors.py new file mode 100644 index 0000000000..ed504d500a --- /dev/null +++ b/api/core/agent/errors.py @@ -0,0 +1,9 @@ +class AgentMaxIterationError(Exception): + """Raised when an agent runner exceeds the configured max iteration count.""" + + def __init__(self, max_iteration: int): + self.max_iteration = max_iteration + super().__init__( + f"Agent exceeded the maximum iteration limit of {max_iteration}. " + f"The agent was unable to complete the task within the allowed number of iterations." + ) diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3271fe319b..5e13a13b21 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -5,6 +5,7 @@ from copy import deepcopy from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform @@ -25,7 +26,6 @@ from dify_graph.model_runtime.entities import ( UserPromptMessage, ) from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.nodes.agent.exc import AgentMaxIterationError from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 3812fac28a..8986164fe7 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,7 +3,10 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from pydantic import ValidationError + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.agent_strategy import AgentStrategyInfo from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.entities.queue_entities import ( AppQueueEvent, @@ -30,6 +33,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_resolution import resolve_workflow_node_class from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter @@ -63,7 +67,6 @@ from dify_graph.graph_events import ( NodeRunSucceededEvent, ) from dify_graph.graph_events.graph import GraphRunAbortedEvent -from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool @@ -308,7 +311,7 @@ class WorkflowBasedAppRunner: # Get node class node_type = target_node_config["data"].type node_version = str(target_node_config["data"].version) - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) # Use the variable pool from graph_runtime_state instead of creating a new one variable_pool = graph_runtime_state.variable_pool @@ -336,6 +339,18 @@ class WorkflowBasedAppRunner: return graph, variable_pool + @staticmethod + def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None: + raw_agent_strategy = event.extras.get("agent_strategy") + if raw_agent_strategy is None: + return None + + try: + return AgentStrategyInfo.model_validate(raw_agent_strategy) + except ValidationError: + logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True) + return None + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): """ Handle event @@ -421,7 +436,7 @@ class WorkflowBasedAppRunner: start_at=event.start_at, in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, - agent_strategy=event.agent_strategy, + agent_strategy=self._build_agent_strategy_info(event), provider_type=event.provider_type, provider_id=event.provider_id, ) diff --git a/api/core/app/entities/__init__.py b/api/core/app/entities/__init__.py index e69de29bb2..8e41acee32 100644 --- a/api/core/app/entities/__init__.py +++ b/api/core/app/entities/__init__.py @@ -0,0 +1,3 @@ +from .agent_strategy import AgentStrategyInfo + +__all__ = ["AgentStrategyInfo"] diff --git a/api/core/app/entities/agent_strategy.py b/api/core/app/entities/agent_strategy.py new file mode 100644 index 0000000000..b063a12f4f --- /dev/null +++ b/api/core/app/entities/agent_strategy.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, ConfigDict + + +class AgentStrategyInfo(BaseModel): + name: str + icon: str | None = None + + model_config = ConfigDict(extra="forbid") diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d42df0d1bf..2d1508f0cb 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -5,8 +5,8 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field +from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities import AgentNodeStrategyInit from dify_graph.entities.pause_reason import PauseReason from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import WorkflowNodeExecutionMetadataKey @@ -314,7 +314,7 @@ class QueueNodeStartedEvent(AppQueueEvent): in_iteration_id: str | None = None in_loop_id: str | None = None start_at: datetime - agent_strategy: AgentNodeStrategyInit | None = None + agent_strategy: AgentStrategyInfo | None = None # FIXME(-LAN-): only for ToolNode, need to refactor provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index b58dae0ff2..46a8ab52f2 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -4,8 +4,8 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field +from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities import AgentNodeStrategyInit from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse): extras: dict[str, object] = Field(default_factory=dict) iteration_id: str | None = None loop_id: str | None = None - agent_strategy: AgentNodeStrategyInit | None = None + agent_strategy: AgentStrategyInfo | None = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index d0279349ca..b054409681 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -12,6 +12,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import CreatorUserRole _logger = logging.getLogger(__name__) @@ -38,7 +39,9 @@ class DatasetIndexToolCallbackHandler: source="app", source_app_id=self._app_id, created_by_role=( - "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + CreatorUserRole.ACCOUNT + if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER ), created_by=self._user_id, ) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 33782e7949..9ac753240b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -628,10 +628,10 @@ class TraceTask: if not message_data: return {} conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) - conversation_mode = db.session.scalars(conversation_mode_stmt).all() - if not conversation_mode or len(conversation_mode) == 0: + conversation_modes = db.session.scalars(conversation_mode_stmt).all() + if not conversation_modes or len(conversation_modes) == 0: return {} - conversation_mode = conversation_mode[0] + conversation_mode = conversation_modes[0] created_at = message_data.created_at inputs = message_data.message diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index f82c3a846b..c29a463bb6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -627,7 +627,7 @@ class ProviderManager: tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, quota_type=quota.quota_type, quota_limit=0, # type: ignore quota_used=0, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8243170c62..fcd3cceb59 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -83,6 +83,7 @@ from models.dataset import ( ) from models.dataset import Document as DatasetDocument from models.dataset import Document as DocumentModel +from models.enums import CreatorUserRole from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureService @@ -1009,7 +1010,7 @@ class DatasetRetrieval: content=json.dumps(contents), source="app", source_app_id=app_id, - created_by_role=user_from, + created_by_role=CreatorUserRole(user_from), created_by=user_id, ) dataset_queries.append(dataset_query) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 770df8b050..55e96515ac 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -146,7 +146,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # No sequence number generation needed anymore - db_model.type = domain_model.workflow_type + from models.workflow import WorkflowType as ModelWorkflowType + + db_model.type = ModelWorkflowType(domain_model.workflow_type.value) db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index d7f2a67c06..bc4e0eda71 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -22,6 +22,13 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.summary_index.summary_index import SummaryIndex from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.tools.tool_file_manager import ToolFileManager +from core.workflow.node_resolution import resolve_workflow_node_class +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from core.workflow.nodes.agent.plugin_strategy_adapter import ( + PluginAgentStrategyPresentationProvider, + PluginAgentStrategyResolver, +) +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY @@ -39,7 +46,6 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig from dify_graph.nodes.http_request import build_http_request_config from dify_graph.nodes.llm.entities import LLMNodeData from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData from dify_graph.nodes.template_transform.template_renderer import ( @@ -97,10 +103,7 @@ class DefaultWorkflowCodeExecutor: @final class DifyNodeFactory(NodeFactory): """ - Default implementation of NodeFactory that uses the traditional node mapping. - - This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING - and instantiating the appropriate node class. + Default implementation of NodeFactory that resolves node classes from the live registry. """ def __init__( @@ -143,6 +146,10 @@ class DifyNodeFactory(NodeFactory): ) self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._agent_strategy_resolver = PluginAgentStrategyResolver() + self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() + self._agent_runtime_support = AgentRuntimeSupport() + self._agent_message_transformer = AgentMessageTransformer() @staticmethod def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext: @@ -219,6 +226,12 @@ class DifyNodeFactory(NodeFactory): NodeType.TOOL: lambda: { "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), }, + NodeType.AGENT: lambda: { + "strategy_resolver": self._agent_strategy_resolver, + "presentation_provider": self._agent_strategy_presentation_provider, + "runtime_support": self._agent_runtime_support, + "message_transformer": self._agent_message_transformer, + }, } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( @@ -238,16 +251,7 @@ class DifyNodeFactory(NodeFactory): @staticmethod def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: - node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type) - if not node_mapping: - raise ValueError(f"No class mapping found for node type: {node_type}") - - latest_node_class = node_mapping.get(LATEST_VERSION) - matched_node_class = node_mapping.get(node_version) - node_class = matched_node_class or latest_node_class - if not node_class: - raise ValueError(f"No latest version class found for node type: {node_type}") - return node_class + return resolve_workflow_node_class(node_type=node_type, node_version=node_version) def _build_llm_compatible_node_init_kwargs( self, diff --git a/api/core/workflow/node_resolution.py b/api/core/workflow/node_resolution.py new file mode 100644 index 0000000000..b922c28165 --- /dev/null +++ b/api/core/workflow/node_resolution.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from collections.abc import Mapping +from importlib import import_module + +from dify_graph.enums import NodeType +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.node_mapping import LATEST_VERSION, get_node_type_classes_mapping + +_WORKFLOW_NODE_MODULES = ("core.workflow.nodes.agent",) +_workflow_nodes_registered = False + + +def ensure_workflow_nodes_registered() -> None: + """Import workflow-local node modules so they can register with `Node.__init_subclass__`.""" + global _workflow_nodes_registered + + if _workflow_nodes_registered: + return + + for module_name in _WORKFLOW_NODE_MODULES: + import_module(module_name) + + _workflow_nodes_registered = True + + +def get_workflow_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]: + ensure_workflow_nodes_registered() + return get_node_type_classes_mapping() + + +def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + node_mapping = get_workflow_node_type_classes_mapping().get(node_type) + if not node_mapping: + raise ValueError(f"No class mapping found for node type: {node_type}") + + latest_node_class = node_mapping.get(LATEST_VERSION) + matched_node_class = node_mapping.get(node_version) + node_class = matched_node_class or latest_node_class + if not node_class: + raise ValueError(f"No latest version class found for node type: {node_type}") + return node_class diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/agent/__init__.py b/api/core/workflow/nodes/agent/__init__.py new file mode 100644 index 0000000000..ba6c667194 --- /dev/null +++ b/api/core/workflow/nodes/agent/__init__.py @@ -0,0 +1,4 @@ +from .agent_node import AgentNode +from .entities import AgentNodeData + +__all__ = ["AgentNode", "AgentNodeData"] diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py new file mode 100644 index 0000000000..c1b423d69d --- /dev/null +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any + +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import NodeType, SystemVariableKey, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser + +from .entities import AgentNodeData +from .exceptions import ( + AgentInvocationError, + AgentMessageTransformError, +) +from .message_transformer import AgentMessageTransformer +from .runtime_support import AgentRuntimeSupport +from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver + +if TYPE_CHECKING: + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState + + +class AgentNode(Node[AgentNodeData]): + node_type = NodeType.AGENT + + _strategy_resolver: AgentStrategyResolver + _presentation_provider: AgentStrategyPresentationProvider + _runtime_support: AgentRuntimeSupport + _message_transformer: AgentMessageTransformer + + def __init__( + self, + id: str, + config: NodeConfigDict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + *, + strategy_resolver: AgentStrategyResolver, + presentation_provider: AgentStrategyPresentationProvider, + runtime_support: AgentRuntimeSupport, + message_transformer: AgentMessageTransformer, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._strategy_resolver = strategy_resolver + self._presentation_provider = presentation_provider + self._runtime_support = runtime_support + self._message_transformer = message_transformer + + @classmethod + def version(cls) -> str: + return "1" + + def populate_start_event(self, event) -> None: + dify_ctx = self.require_dify_context() + event.extras["agent_strategy"] = { + "name": self.node_data.agent_strategy_name, + "icon": self._presentation_provider.get_icon( + tenant_id=dify_ctx.tenant_id, + agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, + ), + } + + def _run(self) -> Generator[NodeEventBase, None, None]: + from core.plugin.impl.exc import PluginDaemonClientSideError + + dify_ctx = self.require_dify_context() + + try: + strategy = self._strategy_resolver.resolve( + tenant_id=dify_ctx.tenant_id, + agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, + agent_strategy_name=self.node_data.agent_strategy_name, + ) + except Exception as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error=f"Failed to get agent strategy: {str(e)}", + ), + ) + return + + agent_parameters = strategy.get_parameters() + + parameters = self._runtime_support.build_parameters( + agent_parameters=agent_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + strategy=strategy, + tenant_id=dify_ctx.tenant_id, + app_id=dify_ctx.app_id, + invoke_from=dify_ctx.invoke_from, + ) + parameters_for_log = self._runtime_support.build_parameters( + agent_parameters=agent_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + strategy=strategy, + tenant_id=dify_ctx.tenant_id, + app_id=dify_ctx.app_id, + invoke_from=dify_ctx.invoke_from, + for_log=True, + ) + credentials = self._runtime_support.build_credentials(parameters=parameters) + + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + + try: + message_stream = strategy.invoke( + params=parameters, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, + conversation_id=conversation_id.text if conversation_id else None, + credentials=credentials, + ) + except Exception as e: + error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + error=str(error), + ) + ) + return + + try: + yield from self._message_transformer.transform( + messages=message_stream, + tool_info={ + "icon": self._presentation_provider.get_icon( + tenant_id=dify_ctx.tenant_id, + agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, + ), + "agent_strategy": self.node_data.agent_strategy_name, + }, + parameters_for_log=parameters_for_log, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, + node_type=self.node_type, + node_id=self._node_id, + node_execution_id=self.id, + ) + except PluginDaemonClientSideError as e: + transform_error = AgentMessageTransformError( + f"Failed to transform agent message: {str(e)}", original_error=e + ) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + error=str(transform_error), + ) + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AgentNodeData, + ) -> Mapping[str, Sequence[str]]: + _ = graph_config # Explicitly mark as unused + result: dict[str, Any] = {} + typed_node_data = node_data + for parameter_name in typed_node_data.agent_parameters: + input = typed_node_data.agent_parameters[parameter_name] + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + + result = {node_id + "." + key: value for key, value in result.items()} + + return result diff --git a/api/dify_graph/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py similarity index 93% rename from api/dify_graph/nodes/agent/entities.py rename to api/core/workflow/nodes/agent/entities.py index f7b7af8fa4..59842862ef 100644 --- a/api/dify_graph/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -11,9 +11,9 @@ from dify_graph.enums import NodeType class AgentNodeData(BaseNodeData): type: NodeType = NodeType.AGENT - agent_strategy_provider_name: str # redundancy + agent_strategy_provider_name: str agent_strategy_name: str - agent_strategy_label: str # redundancy + agent_strategy_label: str memory: MemoryConfig | None = None # The version of the tool parameter. # If this value is None, it indicates this is a previous version diff --git a/api/dify_graph/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exceptions.py similarity index 90% rename from api/dify_graph/nodes/agent/exc.py rename to api/core/workflow/nodes/agent/exceptions.py index ba2c83d8a6..944f5f0b20 100644 --- a/api/dify_graph/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exceptions.py @@ -119,14 +119,3 @@ class AgentVariableTypeError(AgentNodeError): self.expected_type = expected_type self.actual_type = actual_type super().__init__(message) - - -class AgentMaxIterationError(AgentNodeError): - """Exception raised when the agent exceeds the maximum iteration limit.""" - - def __init__(self, max_iteration: int): - self.max_iteration = max_iteration - super().__init__( - f"Agent exceeded the maximum iteration limit of {max_iteration}. " - f"The agent was unable to complete the task within the allowed number of iterations." - ) diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py new file mode 100644 index 0000000000..317db14d3f --- /dev/null +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ( + AgentLogEvent, + NodeEventBase, + NodeRunResult, + StreamChunkEvent, + StreamCompletedEvent, +) +from dify_graph.variables.segments import ArrayFileSegment +from extensions.ext_database import db +from factories import file_factory +from models import ToolFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError + + +class AgentMessageTransformer: + def transform( + self, + *, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_type: NodeType, + node_id: str, + node_execution_id: str, + ) -> Generator[NodeEventBase, None, None]: + from core.plugin.impl.plugin import PluginInstaller + + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json_list: list[dict | list] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + llm_usage = LLMUsage.empty_usage() + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=message.message.text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if node_type == NodeType.AGENT: + if isinstance(message.message.json_object, dict): + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) + agent_execution_metadata = { + WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + else: + llm_usage = LLMUsage.empty_usage() + agent_execution_metadata = {} + if message.message.json_object: + json_list.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=stream_text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise AgentVariableTypeError( + "When 'stream' is True, 'variable_value' must be a string.", + variable_name=variable_name, + expected_type="str", + actual_type=type(variable_value).__name__, + ) + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta, dict) + if "file" not in message.meta: + raise AgentNodeError("File message is missing 'file' key in meta") + + if not isinstance(message.meta["file"], File): + raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + message_id=message.message.id, + node_execution_id=node_execution_id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=node_id, + ) + + for log in agent_logs: + if log.message_id == agent_log.message_id: + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + agent_logs.append(agent_log) + + yield agent_log + + json_output: list[dict[str, Any] | list[Any]] = [] + if agent_logs: + for log in agent_logs: + json_output.append( + { + "id": log.message_id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + ) + if json_list: + json_output.extend(json_list) + else: + json_output.append({"data": []}) + + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk="", + is_final=True, + ) + + for var_name in variables: + yield StreamChunkEvent( + selector=[node_id, var_name], + chunk="", + is_final=True, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "text": text, + "usage": jsonable_encoder(llm_usage), + "files": ArrayFileSegment(value=files), + "json": json_output, + **variables, + }, + metadata={ + **agent_execution_metadata, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + llm_usage=llm_usage, + ) + ) diff --git a/api/core/workflow/nodes/agent/plugin_strategy_adapter.py b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py new file mode 100644 index 0000000000..1fc427ad6c --- /dev/null +++ b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from factories.agent_factory import get_plugin_agent_strategy + +from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy + + +class PluginAgentStrategyResolver(AgentStrategyResolver): + def resolve( + self, + *, + tenant_id: str, + agent_strategy_provider_name: str, + agent_strategy_name: str, + ) -> ResolvedAgentStrategy: + return get_plugin_agent_strategy( + tenant_id=tenant_id, + agent_strategy_provider_name=agent_strategy_provider_name, + agent_strategy_name=agent_strategy_name, + ) + + +class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider): + def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: + from core.plugin.impl.plugin import PluginInstaller + + manager = PluginInstaller() + try: + plugins = manager.list_plugins(tenant_id) + except Exception: + return None + + try: + current_plugin = next( + plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name + ) + except StopIteration: + return None + + return current_plugin.declaration.icon diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py new file mode 100644 index 0000000000..2ff7c964b9 --- /dev/null +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import json +from collections.abc import Sequence +from typing import Any, cast + +from packaging.version import Version +from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.agent.entities import AgentToolEntity +from core.agent.plugin_entities import AgentStrategyParameter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.plugin.entities.request import InvokeCredentials +from core.provider_manager import ProviderManager +from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType +from core.tools.tool_manager import ToolManager +from dify_graph.enums import SystemVariableKey +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import StringSegment +from extensions.ext_database import db +from models.model import Conversation + +from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated +from .exceptions import AgentInputTypeError, AgentVariableNotFoundError +from .strategy_protocols import ResolvedAgentStrategy + + +class AgentRuntimeSupport: + def build_parameters( + self, + *, + agent_parameters: Sequence[AgentStrategyParameter], + variable_pool: VariablePool, + node_data: AgentNodeData, + strategy: ResolvedAgentStrategy, + tenant_id: str, + app_id: str, + invoke_from: Any, + for_log: bool = False, + ) -> dict[str, Any]: + agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} + + result: dict[str, Any] = {} + for parameter_name in node_data.agent_parameters: + parameter = agent_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + + agent_input = node_data.agent_parameters[parameter_name] + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore[arg-type] + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: + parameter_value = str(agent_input.value) + + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) + + value = parameter_value + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + value = [tool for tool in value if tool.get("enabled", False)] + value = self._filter_mcp_type_tool(strategy, value) + for tool in value: + if "schemas" in tool: + tool.pop("schemas") + parameters = tool.get("parameters", {}) + if all(isinstance(v, dict) for _, v in parameters.items()): + params = {} + for key, param in parameters.items(): + if param.get("auto", ParamsAutoGenerated.OPEN) in ( + ParamsAutoGenerated.CLOSE, + 0, + ): + value_param = param.get("value", {}) + if value_param and value_param.get("type", "") == "variable": + variable_selector = value_param.get("value") + if not variable_selector: + raise ValueError("Variable selector is missing for a variable-type parameter.") + + variable = variable_pool.get(variable_selector) + if variable is None: + raise AgentVariableNotFoundError(str(variable_selector)) + + params[key] = variable.value + else: + params[key] = value_param.get("value", "") if value_param is not None else None + else: + params[key] = None + parameters = params + tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} + tool["parameters"] = parameters + + if not for_log: + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + tool_value = [] + for tool in value: + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) + setting_params = tool.get("settings", {}) + parameters = tool.get("parameters", {}) + manual_input_params = [key for key, value in parameters.items() if value is not None] + + parameters = {**parameters, **setting_params} + entity = AgentToolEntity( + provider_id=tool.get("provider_name", ""), + provider_type=provider_type, + tool_name=tool.get("tool_name", ""), + tool_parameters=parameters, + plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), + ) + + extra = tool.get("extra", {}) + + runtime_variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version is not None: + runtime_variable_pool = variable_pool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id, + app_id, + entity, + invoke_from, + runtime_variable_pool, + ) + if tool_runtime.entity.description: + tool_runtime.entity.description.llm = ( + extra.get("description", "") or tool_runtime.entity.description.llm + ) + for tool_runtime_params in tool_runtime.entity.parameters: + tool_runtime_params.form = ( + ToolParameter.ToolParameterForm.FORM + if tool_runtime_params.name in manual_input_params + else tool_runtime_params.form + ) + manual_input_value = {} + if tool_runtime.entity.parameters: + manual_input_value = { + key: value for key, value in parameters.items() if key in manual_input_params + } + runtime_parameters = { + **tool_runtime.runtime.runtime_parameters, + **manual_input_value, + } + tool_value.append( + { + **tool_runtime.entity.model_dump(mode="json"), + "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), + "provider_type": provider_type.value, + } + ) + value = tool_value + if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: + value = cast(dict[str, Any], value) + model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + history_prompt_messages = [] + if node_data.memory: + memory = self.fetch_memory( + variable_pool=variable_pool, + app_id=app_id, + model_instance=model_instance, + ) + if memory: + prompt_messages = memory.get_history_prompt_messages( + message_limit=node_data.memory.window.size or None + ) + history_prompt_messages = [ + prompt_message.model_dump(mode="json") for prompt_message in prompt_messages + ] + value["history_prompt_messages"] = history_prompt_messages + if model_schema: + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None + result[parameter_name] = value + + return result + + def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials: + credentials = InvokeCredentials() + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if not tool.get("credential_id"): + continue + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + except ValidationError: + continue + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + return credentials + + def fetch_memory( + self, + *, + variable_pool: VariablePool, + app_id: str, + model_instance: ModelInstance, + ) -> TokenBufferMemory | None: + conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + if not isinstance(conversation_id_variable, StringSegment): + return None + conversation_id = conversation_id_variable.value + + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) + if not conversation: + return None + + return TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id=tenant_id, + provider=value.get("provider", ""), + model_type=ModelType.LLM, + ) + model_name = value.get("model", "") + model_credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=ModelType.LLM, + model=model_name, + ) + provider_name = provider_model_bundle.configuration.provider.provider + model_type_instance = provider_model_bundle.model_type_instance + model_instance = ModelManager().get_model_instance( + tenant_id=tenant_id, + provider=provider_name, + model_type=ModelType(value.get("model_type", "")), + model=model_name, + ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + return model_instance, model_schema + + @staticmethod + def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features[:]: + try: + AgentOldVersionModelFeatures(feature.value) + except ValueError: + model_schema.features.remove(feature) + return model_schema + + @staticmethod + def _filter_mcp_type_tool( + strategy: ResolvedAgentStrategy, + tools: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + meta_version = strategy.meta_version + if meta_version and Version(meta_version) > Version("0.0.1"): + return tools + return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] diff --git a/api/core/workflow/nodes/agent/strategy_protocols.py b/api/core/workflow/nodes/agent/strategy_protocols.py new file mode 100644 index 0000000000..643d916d15 --- /dev/null +++ b/api/core/workflow/nodes/agent/strategy_protocols.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Generator, Sequence +from typing import Any, Protocol + +from core.agent.plugin_entities import AgentStrategyParameter +from core.plugin.entities.request import InvokeCredentials +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class ResolvedAgentStrategy(Protocol): + meta_version: str | None + + def get_parameters(self) -> Sequence[AgentStrategyParameter]: ... + + def invoke( + self, + *, + params: dict[str, Any], + user_id: str, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: ... + + +class AgentStrategyResolver(Protocol): + def resolve( + self, + *, + tenant_id: str, + agent_strategy_provider_name: str, + agent_strategy_name: str, + ) -> ResolvedAgentStrategy: ... + + +class AgentStrategyPresentationProvider(Protocol): + def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ... diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index fef01049c5..01b309bf54 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_resolution import resolve_workflow_node_class from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter @@ -23,7 +24,6 @@ from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from dify_graph.nodes import NodeType from dify_graph.nodes.base.node import Node -from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool @@ -343,7 +343,7 @@ class WorkflowEntry: if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: raise ValueError(f"Node type {node_type} not supported") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"] + node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1") if not node_cls: raise ValueError(f"Node class not found for node type {node_type}") diff --git a/api/dify_graph/entities/__init__.py b/api/dify_graph/entities/__init__.py index e73c38c1d3..ef7789c49c 100644 --- a/api/dify_graph/entities/__init__.py +++ b/api/dify_graph/entities/__init__.py @@ -1,11 +1,9 @@ -from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution from .workflow_start_reason import WorkflowStartReason __all__ = [ - "AgentNodeStrategyInit", "GraphInitParams", "WorkflowExecution", "WorkflowNodeExecution", diff --git a/api/dify_graph/entities/agent.py b/api/dify_graph/entities/agent.py deleted file mode 100644 index 2b4d6db76f..0000000000 --- a/api/dify_graph/entities/agent.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel - - -class AgentNodeStrategyInit(BaseModel): - """Agent node strategy initialization data.""" - - name: str - icon: str | None = None diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py index 21ddf80b64..8552254627 100644 --- a/api/dify_graph/graph_events/node.py +++ b/api/dify_graph/graph_events/node.py @@ -4,7 +4,6 @@ from datetime import datetime from pydantic import Field from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities import AgentNodeStrategyInit from dify_graph.entities.pause_reason import PauseReason from .base import GraphNodeEventBase @@ -13,8 +12,8 @@ from .base import GraphNodeEventBase class NodeRunStartedEvent(GraphNodeEventBase): node_title: str predecessor_node_id: str | None = None - agent_strategy: AgentNodeStrategyInit | None = None start_at: datetime = Field(..., description="node start time") + extras: dict[str, object] = Field(default_factory=dict) # FIXME(-LAN-): only for ToolNode provider_type: str = "" diff --git a/api/dify_graph/nodes/agent/__init__.py b/api/dify_graph/nodes/agent/__init__.py deleted file mode 100644 index 95e7cf895b..0000000000 --- a/api/dify_graph/nodes/agent/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .agent_node import AgentNode - -__all__ = ["AgentNode"] diff --git a/api/dify_graph/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py deleted file mode 100644 index d501217454..0000000000 --- a/api/dify_graph/nodes/agent/agent_node.py +++ /dev/null @@ -1,761 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from packaging.version import Version -from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.agent.entities import AgentToolEntity -from core.agent.plugin_entities import AgentStrategyParameter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager -from core.provider_manager import ProviderManager -from core.tools.entities.tool_entities import ( - ToolIdentity, - ToolInvokeMessage, - ToolParameter, - ToolProviderType, -) -from core.tools.tool_manager import ToolManager -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.enums import ( - NodeType, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( - AgentLogEvent, - NodeEventBase, - NodeRunResult, - StreamChunkEvent, - StreamCompletedEvent, -) -from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ArrayFileSegment, StringSegment -from extensions.ext_database import db -from factories import file_factory -from factories.agent_factory import get_plugin_agent_strategy -from models import ToolFile -from models.model import Conversation -from services.tools.builtin_tools_manage_service import BuiltinToolManageService - -from .exc import ( - AgentInputTypeError, - AgentInvocationError, - AgentMessageTransformError, - AgentNodeError, - AgentVariableNotFoundError, - AgentVariableTypeError, - ToolFileNotFoundError, -) - -if TYPE_CHECKING: - from core.agent.strategy.plugin import PluginAgentStrategy - from core.plugin.entities.request import InvokeCredentials - - -class AgentNode(Node[AgentNodeData]): - """ - Agent Node - """ - - node_type = NodeType.AGENT - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator[NodeEventBase, None, None]: - from core.plugin.impl.exc import PluginDaemonClientSideError - - dify_ctx = self.require_dify_context() - - try: - strategy = get_plugin_agent_strategy( - tenant_id=dify_ctx.tenant_id, - agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, - agent_strategy_name=self.node_data.agent_strategy_name, - ) - except Exception as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - error=f"Failed to get agent strategy: {str(e)}", - ), - ) - return - - agent_parameters = strategy.get_parameters() - - # get parameters - parameters = self._generate_agent_parameters( - agent_parameters=agent_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - strategy=strategy, - ) - parameters_for_log = self._generate_agent_parameters( - agent_parameters=agent_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - for_log=True, - strategy=strategy, - ) - credentials = self._generate_credentials(parameters=parameters) - - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - - try: - message_stream = strategy.invoke( - params=parameters, - user_id=dify_ctx.user_id, - app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, - credentials=credentials, - ) - except Exception as e: - error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - error=str(error), - ) - ) - return - - try: - yield from self._transform_message( - messages=message_stream, - tool_info={ - "icon": self.agent_strategy_icon, - "agent_strategy": self.node_data.agent_strategy_name, - }, - parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - node_type=self.node_type, - node_id=self._node_id, - node_execution_id=self.id, - ) - except PluginDaemonClientSideError as e: - transform_error = AgentMessageTransformError( - f"Failed to transform agent message: {str(e)}", original_error=e - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - error=str(transform_error), - ) - ) - - def _generate_agent_parameters( - self, - *, - agent_parameters: Sequence[AgentStrategyParameter], - variable_pool: VariablePool, - node_data: AgentNodeData, - for_log: bool = False, - strategy: PluginAgentStrategy, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - agent_parameters (Sequence[AgentParameter]): The list of agent parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (AgentNodeData): The data associated with the agent node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. - - """ - agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} - - result: dict[str, Any] = {} - for parameter_name in node_data.agent_parameters: - parameter = agent_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - agent_input = node_data.agent_parameters[parameter_name] - match agent_input.type: - case "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - case "mixed" | "constant": - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: - parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - case _: - raise AgentInputTypeError(agent_input.type) - value = parameter_value - if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) - value = [tool for tool in value if tool.get("enabled", False)] - value = self._filter_mcp_type_tool(strategy, value) - for tool in value: - if "schemas" in tool: - tool.pop("schemas") - parameters = tool.get("parameters", {}) - if all(isinstance(v, dict) for _, v in parameters.items()): - params = {} - for key, param in parameters.items(): - if param.get("auto", ParamsAutoGenerated.OPEN) in ( - ParamsAutoGenerated.CLOSE, - 0, - ): - value_param = param.get("value", {}) - if value_param and value_param.get("type", "") == "variable": - variable_selector = value_param.get("value") - if not variable_selector: - raise ValueError("Variable selector is missing for a variable-type parameter.") - - variable = variable_pool.get(variable_selector) - if variable is None: - raise AgentVariableNotFoundError(str(variable_selector)) - - params[key] = variable.value - else: - params[key] = value_param.get("value", "") if value_param is not None else None - else: - params[key] = None - parameters = params - tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} - tool["parameters"] = parameters - - if not for_log: - if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) - tool_value = [] - for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) - setting_params = tool.get("settings", {}) - parameters = tool.get("parameters", {}) - manual_input_params = [key for key, value in parameters.items() if value is not None] - - parameters = {**parameters, **setting_params} - entity = AgentToolEntity( - provider_id=tool.get("provider_name", ""), - provider_type=provider_type, - tool_name=tool.get("tool_name", ""), - tool_parameters=parameters, - plugin_unique_identifier=tool.get("plugin_unique_identifier", None), - credential_id=tool.get("credential_id", None), - ) - - extra = tool.get("extra", {}) - - # This is an issue that caused problems before. - # Logically, we shouldn't use the node_data.version field for judgment - # But for backward compatibility with historical data - # this version field judgment is still preserved here. - runtime_variable_pool: VariablePool | None = None - if node_data.version != "1" or node_data.tool_node_version is not None: - runtime_variable_pool = variable_pool - dify_ctx = self.require_dify_context() - tool_runtime = ToolManager.get_agent_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - entity, - dify_ctx.invoke_from, - runtime_variable_pool, - ) - if tool_runtime.entity.description: - tool_runtime.entity.description.llm = ( - extra.get("description", "") or tool_runtime.entity.description.llm - ) - for tool_runtime_params in tool_runtime.entity.parameters: - tool_runtime_params.form = ( - ToolParameter.ToolParameterForm.FORM - if tool_runtime_params.name in manual_input_params - else tool_runtime_params.form - ) - manual_input_value = {} - if tool_runtime.entity.parameters: - manual_input_value = { - key: value for key, value in parameters.items() if key in manual_input_params - } - runtime_parameters = { - **tool_runtime.runtime.runtime_parameters, - **manual_input_value, - } - tool_value.append( - { - **tool_runtime.entity.model_dump(mode="json"), - "runtime_parameters": runtime_parameters, - "credential_id": tool.get("credential_id", None), - "provider_type": provider_type.value, - } - ) - value = tool_value - if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: - value = cast(dict[str, Any], value) - model_instance, model_schema = self._fetch_model(value) - # memory config - history_prompt_messages = [] - if node_data.memory: - memory = self._fetch_memory(model_instance) - if memory: - prompt_messages = memory.get_history_prompt_messages( - message_limit=node_data.memory.window.size or None - ) - history_prompt_messages = [ - prompt_message.model_dump(mode="json") for prompt_message in prompt_messages - ] - value["history_prompt_messages"] = history_prompt_messages - if model_schema: - # remove structured output feature to support old version agent plugin - model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) - value["entity"] = model_schema.model_dump(mode="json") - else: - value["entity"] = None - result[parameter_name] = value - - return result - - def _generate_credentials( - self, - parameters: dict[str, Any], - ) -> InvokeCredentials: - """ - Generate credentials based on the given agent parameters. - """ - from core.plugin.entities.request import InvokeCredentials - - credentials = InvokeCredentials() - - # generate credentials for tools selector - credentials.tool_credentials = {} - for tool in parameters.get("tools", []): - if tool.get("credential_id"): - try: - identity = ToolIdentity.model_validate(tool.get("identity", {})) - credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) - except ValidationError: - continue - return credentials - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AgentNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - result: dict[str, Any] = {} - typed_node_data = node_data - for parameter_name in typed_node_data.agent_parameters: - input = typed_node_data.agent_parameters[parameter_name] - match input.type: - case "mixed" | "constant": - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - case "variable": - result[parameter_name] = input.value - - result = {node_id + "." + key: value for key, value in result.items()} - - return result - - @property - def agent_strategy_icon(self) -> str | None: - """ - Get agent strategy icon - :return: - """ - from core.plugin.impl.plugin import PluginInstaller - - manager = PluginInstaller() - dify_ctx = self.require_dify_context() - plugins = manager.list_plugins(dify_ctx.tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name - ) - icon = current_plugin.declaration.icon - except StopIteration: - icon = None - return icon - - def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: - # get conversation id - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - if not isinstance(conversation_id_variable, StringSegment): - return None - conversation_id = conversation_id_variable.value - - dify_ctx = self.require_dify_context() - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Conversation).where( - Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id - ) - conversation = session.scalar(stmt) - - if not conversation: - return None - - memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) - - return memory - - def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - dify_ctx = self.require_dify_context() - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM - ) - model_name = value.get("model", "") - model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, model=model_name - ) - provider_name = provider_model_bundle.configuration.provider.provider - model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( - tenant_id=dify_ctx.tenant_id, - provider=provider_name, - model_type=ModelType(value.get("model_type", "")), - model=model_name, - ) - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - return model_instance, model_schema - - def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: - if model_schema.features: - for feature in model_schema.features[:]: # Create a copy to safely modify during iteration - try: - AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value - except ValueError: - model_schema.features.remove(feature) - return model_schema - - def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Filter MCP type tool - :param strategy: plugin agent strategy - :param tool: tool - :return: filtered tool dict - """ - meta_version = strategy.meta_version - if meta_version and Version(meta_version) > Version("0.0.1"): - return tools - else: - return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] - - def _transform_message( - self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, - node_type: NodeType, - node_id: str, - node_execution_id: str, - ) -> Generator[NodeEventBase, None, None]: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json_list: list[dict | list] = [] - - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - llm_usage = LLMUsage.empty_usage() - variables: dict[str, Any] = {} - - for message in message_stream: - if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileNotFoundError(tool_file_id) - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - assert message.meta - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileNotFoundError(tool_file_id) - - mapping = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - if node_type == NodeType.AGENT: - if isinstance(message.message.json_object, dict): - msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) - agent_execution_metadata = { - WorkflowNodeExecutionMetadataKey(key): value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } - else: - msg_metadata = {} - llm_usage = LLMUsage.empty_usage() - agent_execution_metadata = {} - if message.message.json_object: - json_list.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise AgentVariableTypeError( - "When 'stream' is True, 'variable_value' must be a string.", - variable_name=variable_name, - expected_type="str", - actual_type=type(variable_value).__name__, - ) - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield StreamChunkEvent( - selector=[node_id, variable_name], - chunk=variable_value, - is_final=False, - ) - else: - variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: - assert message.meta is not None - assert isinstance(message.meta, dict) - # Validate that meta contains a 'file' key - if "file" not in message.meta: - raise AgentNodeError("File message is missing 'file' key in meta") - - # Validate that the file is an instance of File - if not isinstance(message.meta["file"], File): - raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") - files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) - if message.message.metadata: - icon = tool_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - - dict_metadata["icon"] = icon - dict_metadata["icon_dark"] = icon_dark - message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - message_id=message.message.id, - node_execution_id=node_execution_id, - parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=node_id, - ) - - # check if the agent log is already in the list - for log in agent_logs: - if log.message_id == agent_log.message_id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) - - yield agent_log - - # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any] | list[Any]] = [] - - # Step 1: append each agent log as its own dict. - if agent_logs: - for log in agent_logs: - json_output.append( - { - "id": log.message_id, - "parent_id": log.parent_id, - "error": log.error, - "status": log.status, - "data": log.data, - "label": log.label, - "metadata": log.metadata, - "node_id": log.node_id, - } - ) - # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] - if json_list: - json_output.extend(json_list) - else: - json_output.append({"data": []}) - - # Send final chunk events for all streamed outputs - # Final chunk for text stream - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk="", - is_final=True, - ) - - # Final chunks for any streamed variables - for var_name in variables: - yield StreamChunkEvent( - selector=[node_id, var_name], - chunk="", - is_final=True, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "text": text, - "usage": jsonable_encoder(llm_usage), - "files": ArrayFileSegment(value=files), - "json": json_output, - **variables, - }, - metadata={ - **agent_execution_metadata, - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, - }, - inputs=parameters_for_log, - llm_usage=llm_usage, - ) - ) diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index d7702c8cb4..2044b09333 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -11,7 +11,7 @@ from types import MappingProxyType from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin from uuid import uuid4 -from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams +from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY @@ -349,6 +349,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def populate_start_event(self, event: NodeRunStartedEvent) -> None: + """Allow subclasses to enrich the started event without cross-node imports in the base class.""" + _ = event + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -362,39 +366,10 @@ class Node(Generic[NodeDataT]): in_iteration_id=None, start_at=self._start_at, ) - - # === FIXME(-LAN-): Needs to refactor. - from dify_graph.nodes.tool.tool_node import ToolNode - - if isinstance(self, ToolNode): - start_event.provider_id = getattr(self.node_data, "provider_id", "") - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from dify_graph.nodes.datasource.datasource_node import DatasourceNode - - if isinstance(self, DatasourceNode): - plugin_id = getattr(self.node_data, "plugin_id", "") - provider_name = getattr(self.node_data, "provider_name", "") - - start_event.provider_id = f"{plugin_id}/{provider_name}" - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode - - if isinstance(self, TriggerEventNode): - start_event.provider_id = getattr(self.node_data, "provider_id", "") - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from dify_graph.nodes.agent.agent_node import AgentNode - from dify_graph.nodes.agent.entities import AgentNodeData - - if isinstance(self, AgentNode): - start_event.agent_strategy = AgentNodeStrategyInit( - name=cast(AgentNodeData, self.node_data).agent_strategy_name, - icon=self.agent_strategy_icon, - ) - - # === + try: + self.populate_start_event(start_event) + except Exception: + logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) yield start_event try: @@ -513,10 +488,8 @@ class Node(Generic[NodeDataT]): @abstractmethod def version(cls) -> str: """`node_version` returns the version of current node type.""" - # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. - # - # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` - # in `api/dify_graph/nodes/__init__.py`. + # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so + # `Node.get_node_type_classes_mapping()` can resolve numeric versions and `latest`. raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod @@ -524,7 +497,9 @@ class Node(Generic[NodeDataT]): """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. Import all modules under dify_graph.nodes so subclasses register themselves on import. - Then we return a readonly view of the registry to avoid accidental mutation. + Callers that rely on workflow-local nodes defined outside `dify_graph.nodes` must import + those modules before invoking this method so they can register through `__init_subclass__`. + We then return a readonly view of the registry to avoid accidental mutation. """ # Import all node modules to ensure they are loaded (thus registered) import dify_graph.nodes as _nodes_pkg diff --git a/api/dify_graph/nodes/datasource/datasource_node.py b/api/dify_graph/nodes/datasource/datasource_node.py index b0d3e1a24e..62dcb2924f 100644 --- a/api/dify_graph/nodes/datasource/datasource_node.py +++ b/api/dify_graph/nodes/datasource/datasource_node.py @@ -48,6 +48,10 @@ class DatasourceNode(Node[DatasourceNodeData]): ) self.datasource_manager = datasource_manager + def populate_start_event(self, event) -> None: + event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}" + event.provider_type = self.node_data.provider_type + def _run(self) -> Generator: """ Run the datasource node diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index bc36f635e9..1d626f4bd6 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -486,14 +486,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # variable selector to variable mapping try: # Get node class - from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + from dify_graph.nodes.node_mapping import get_node_type_classes_mapping typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) node_type = typed_sub_node_config["data"].type - if node_type not in NODE_TYPE_CLASSES_MAPPING: + node_mapping = get_node_type_classes_mapping() + if node_type not in node_mapping: continue node_version = str(typed_sub_node_config["data"].version) - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_cls = node_mapping[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=graph_config, config=typed_sub_node_config diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 27b78eaa3f..1a8774f445 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -316,14 +316,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # variable selector to variable mapping try: # Get node class - from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + from dify_graph.nodes.node_mapping import get_node_type_classes_mapping typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) node_type = typed_sub_node_config["data"].type - if node_type not in NODE_TYPE_CLASSES_MAPPING: + node_mapping = get_node_type_classes_mapping() + if node_type not in node_mapping: continue node_version = str(typed_sub_node_config["data"].version) - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_cls = node_mapping[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=graph_config, config=typed_sub_node_config diff --git a/api/dify_graph/nodes/node_mapping.py b/api/dify_graph/nodes/node_mapping.py index 8e5405f1aa..e0f5524a04 100644 --- a/api/dify_graph/nodes/node_mapping.py +++ b/api/dify_graph/nodes/node_mapping.py @@ -5,5 +5,24 @@ from dify_graph.nodes.base.node import Node LATEST_VERSION = "latest" -# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes -NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + +def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]: + """Return the live node registry after importing all `dify_graph.nodes` modules.""" + return Node.get_node_type_classes_mapping() + + +def resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + node_mapping = get_node_type_classes_mapping().get(node_type) + if not node_mapping: + raise ValueError(f"No class mapping found for node type: {node_type}") + + latest_node_class = node_mapping.get(LATEST_VERSION) + matched_node_class = node_mapping.get(node_version) + node_class = matched_node_class or latest_node_class + if not node_class: + raise ValueError(f"No latest version class found for node type: {node_type}") + return node_class + + +# Snapshot kept for compatibility with older tests; production paths should use the live helpers. +NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping() diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index 44d0ca885d..ec7386981e 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -65,6 +65,10 @@ class ToolNode(Node[ToolNodeData]): def version(cls) -> str: return "1" + def populate_start_event(self, event) -> None: + event.provider_id = self.node_data.provider_id + event.provider_type = self.node_data.provider_type + def _run(self) -> Generator[NodeEventBase, None, None]: """ Run the tool node diff --git a/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py b/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py index b4f1116f7e..536ba96dec 100644 --- a/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py +++ b/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py @@ -32,6 +32,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]): def version(cls) -> str: return "1" + def populate_start_event(self, event) -> None: + event.provider_id = self.node_data.provider_id + def _run(self) -> NodeRunResult: """ Run the plugin trigger node. diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 7ee4638e77..a94d75ec76 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -17,7 +17,8 @@ from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from models.workflow import WorkflowNodeExecutionModel +from models.enums import CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository logger = logging.getLogger(__name__) @@ -47,12 +48,28 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.triggered_from = data.get("triggered_from") or "" + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowNodeExecutionTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to WORKFLOW_RUN", triggered_from_val) + model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" model.status = data.get("status") or "running" # Default status if missing model.title = data.get("title") or "" - model.created_by_role = data.get("created_by_role") or "" + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.index = safe_int(data.get("index", 0)) diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 14382ed876..bdfc81bd1c 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -22,12 +22,13 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker +from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun, WorkflowType from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.types import ( AverageInteractionStats, @@ -59,11 +60,37 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.type = data.get("type") or "" - model.triggered_from = data.get("triggered_from") or "" + type_val = data.get("type") + try: + model.type = WorkflowType(str(type_val)) if type_val else WorkflowType.WORKFLOW + except ValueError: + logger.warning("Invalid type value: %s, falling back to WORKFLOW", type_val) + model.type = WorkflowType.WORKFLOW + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowRunTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowRunTriggeredFrom.APP_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to APP_RUN", triggered_from_val) + model.triggered_from = WorkflowRunTriggeredFrom.APP_RUN model.version = data.get("version") or "" - model.status = data.get("status") or "running" # Default status if missing - model.created_by_role = data.get("created_by_role") or "" + status_val = data.get("status") + try: + model.status = WorkflowExecutionStatus(str(status_val)) if status_val else WorkflowExecutionStatus.RUNNING + except ValueError: + logger.warning("Invalid status value: %s, falling back to RUNNING", status_val) + model.status = WorkflowExecutionStatus.RUNNING + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.total_tokens = safe_int(data.get("total_tokens", 0)) diff --git a/api/models/account.py b/api/models/account.py index f7a9c20026..1a43c9ca17 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,12 +8,12 @@ from uuid import uuid4 import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column, validates +from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID class TenantAccountRole(enum.StrEnum): @@ -104,7 +104,9 @@ class Account(UserMixin, TypeBase): last_active_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False ) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'"), default="active") + status: Mapped[AccountStatus] = mapped_column( + EnumText(AccountStatus, length=16), server_default=sa.text("'active'"), default=AccountStatus.ACTIVE + ) initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -116,12 +118,6 @@ class Account(UserMixin, TypeBase): role: TenantAccountRole | None = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False) - @validates("status") - def _normalize_status(self, _key: str, value: str | AccountStatus) -> str: - if isinstance(value, AccountStatus): - return value.value - return value - @property def is_password_set(self): return self.password is not None @@ -177,8 +173,7 @@ class Account(UserMixin, TypeBase): return self.role def get_status(self) -> AccountStatus: - status_str = self.status - return AccountStatus(status_str) + return self.status @classmethod def get_by_openid(cls, provider: str, open_id: str): @@ -249,7 +244,9 @@ class Tenant(TypeBase): name: Mapped[str] = mapped_column(String(255)) encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None) plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic") - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'"), default="normal") + status: Mapped[TenantStatus] = mapped_column( + EnumText(TenantStatus, length=255), server_default=sa.text("'normal'"), default=TenantStatus.NORMAL + ) custom_config: Mapped[str | None] = mapped_column(LongText, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -291,7 +288,9 @@ class TenantAccountJoin(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False) - role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal") + role: Mapped[TenantAccountRole] = mapped_column( + EnumText(TenantAccountRole, length=16), server_default="normal", default=TenantAccountRole.NORMAL + ) invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -324,6 +323,11 @@ class AccountIntegrate(TypeBase): ) +class InvitationCodeStatus(enum.StrEnum): + UNUSED = "unused" + USED = "used" + + class InvitationCode(TypeBase): __tablename__ = "invitation_codes" __table_args__ = ( @@ -335,7 +339,11 @@ class InvitationCode(TypeBase): id: Mapped[int] = mapped_column(sa.Integer, init=False) batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'"), default="unused") + status: Mapped[InvitationCodeStatus] = mapped_column( + EnumText(InvitationCodeStatus, length=16), + server_default=sa.text("'unused'"), + default=InvitationCodeStatus.UNUSED, + ) used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None) used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None) used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None) @@ -367,10 +375,13 @@ class TenantPluginPermission(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) install_permission: Mapped[InstallPermission] = mapped_column( - String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE + EnumText(InstallPermission, length=16), + nullable=False, + server_default="everyone", + default=InstallPermission.EVERYONE, ) debug_permission: Mapped[DebugPermission] = mapped_column( - String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY + EnumText(DebugPermission, length=16), nullable=False, server_default="noone", default=DebugPermission.NOBODY ) @@ -396,10 +407,13 @@ class TenantPluginAutoUpgradeStrategy(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) strategy_setting: Mapped[StrategySetting] = mapped_column( - String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY + EnumText(StrategySetting, length=16), + nullable=False, + server_default="fix_only", + default=StrategySetting.FIX_ONLY, ) upgrade_mode: Mapped[UpgradeMode] = mapped_column( - String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE + EnumText(UpgradeMode, length=16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE ) exclude_plugins: Mapped[list[str]] = mapped_column(sa.JSON, nullable=False, default_factory=list) include_plugins: Mapped[list[str]] = mapped_column(sa.JSON, nullable=False, default_factory=list) diff --git a/api/models/dataset.py b/api/models/dataset.py index 4ef39fcde1..b3fa11a58c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -30,8 +30,9 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode, from .account import Account from .base import Base, TypeBase from .engine import db +from .enums import CreatorUserRole from .model import App, Tag, TagBinding, UploadFile -from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index +from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index logger = logging.getLogger(__name__) @@ -59,7 +60,11 @@ class Dataset(Base): name: Mapped[str] = mapped_column(String(255)) description = mapped_column(LongText, nullable=True) provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'")) - permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'")) + permission: Mapped[DatasetPermissionEnum] = mapped_column( + EnumText(DatasetPermissionEnum, length=255), + server_default=sa.text("'only_me'"), + default=DatasetPermissionEnum.ONLY_ME, + ) data_source_type = mapped_column(String(255)) indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(LongText, nullable=True) @@ -1003,7 +1008,7 @@ class DatasetQuery(TypeBase): content: Mapped[str] = mapped_column(LongText, nullable=False) source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False diff --git a/api/models/enums.py b/api/models/enums.py index ed6236209f..66e3e4b332 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -72,3 +72,23 @@ class AppTriggerType(StrEnum): # for backward compatibility UNKNOWN = "unknown" + + +class AppStatus(StrEnum): + """App Status Enum""" + + NORMAL = "normal" + + +class AppMCPServerStatus(StrEnum): + """AppMCPServer Status Enum""" + + NORMAL = "normal" + ACTIVE = "active" + INACTIVE = "inactive" + + +class ConversationStatus(StrEnum): + """Conversation Status Enum""" + + NORMAL = "normal" diff --git a/api/models/model.py b/api/models/model.py index ed0614c195..2e747df2c7 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -29,9 +29,9 @@ from libs.uuid_utils import uuidv7 from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db -from .enums import CreatorUserRole +from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus from .provider_ids import GenericProviderID -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from .workflow import Workflow @@ -337,13 +337,15 @@ class App(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) - mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255)) icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'")) + status: Mapped[AppStatus] = mapped_column( + EnumText(AppStatus, length=255), server_default=sa.text("'normal'"), default=AppStatus.NORMAL + ) enable_site: Mapped[bool] = mapped_column(sa.Boolean) enable_api: Mapped[bool] = mapped_column(sa.Boolean) api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) @@ -1000,14 +1002,16 @@ class Conversation(Base): model_provider = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(LongText) model_id = mapped_column(String(255), nullable=True) - mode: Mapped[str] = mapped_column(String(255)) + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) name: Mapped[str] = mapped_column(String(255), nullable=False) summary = mapped_column(LongText) _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) introduction = mapped_column(LongText) system_instruction = mapped_column(LongText) system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) - status: Mapped[str] = mapped_column(String(255), nullable=False) + status: Mapped[ConversationStatus] = mapped_column( + EnumText(ConversationStatus, length=255), nullable=False, default=ConversationStatus.NORMAL + ) # The `invoke_from` records how the conversation is created. # @@ -1351,7 +1355,12 @@ class Message(Base): provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7)) currency: Mapped[str] = mapped_column(String(255), nullable=False) - status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[MessageStatus] = mapped_column( + EnumText(MessageStatus, length=255), + nullable=False, + server_default=sa.text("'normal'"), + default=MessageStatus.NORMAL, + ) error: Mapped[str | None] = mapped_column(LongText) message_metadata: Mapped[str | None] = mapped_column(LongText) invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) @@ -1364,7 +1373,7 @@ class Message(Base): ) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) - app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True) + app_mode: Mapped[AppMode | None] = mapped_column(EnumText(AppMode, length=255), nullable=True) @property def inputs(self) -> dict[str, Any]: @@ -1766,8 +1775,10 @@ class MessageFile(TypeBase): ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) - transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[FileTransferMethod] = mapped_column( + EnumText(FileTransferMethod, length=255), nullable=False + ) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None) url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) @@ -1976,7 +1987,9 @@ class AppMCPServer(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[str] = mapped_column(String(255), nullable=False) server_code: Mapped[str] = mapped_column(String(255), nullable=False) - status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[AppMCPServerStatus] = mapped_column( + EnumText(AppMCPServerStatus, length=255), nullable=False, server_default=sa.text("'normal'") + ) parameters: Mapped[str] = mapped_column(LongText, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -2015,7 +2028,7 @@ class Site(Base): id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - icon_type = mapped_column(String(255), nullable=True) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255), nullable=True) icon = mapped_column(String(255)) icon_background = mapped_column(String(255)) description = mapped_column(LongText) @@ -2030,7 +2043,9 @@ class Site(Base): customize_domain = mapped_column(String(255)) customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[AppStatus] = mapped_column( + EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL + ) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -2110,7 +2125,12 @@ class UploadFile(Base): # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. # Its value is derived from the `CreatorUserRole` enumeration. - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'account'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), + nullable=False, + server_default=sa.text("'account'"), + default=CreatorUserRole.ACCOUNT, + ) # The `created_by` field stores the ID of the entity that created this upload file. # @@ -2163,7 +2183,7 @@ class UploadFile(Base): self.size = size self.extension = extension self.mime_type = mime_type - self.created_by_role = created_by_role.value + self.created_by_role = created_by_role self.created_by = created_by self.created_at = created_at self.used = used @@ -2226,7 +2246,7 @@ class MessageAgentThought(TypeBase): ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) diff --git a/api/models/provider.py b/api/models/provider.py index 6175a3ae88..18a0fe92c8 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID class ProviderType(StrEnum): @@ -69,8 +69,8 @@ class Provider(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) - provider_type: Mapped[str] = mapped_column( - String(40), nullable=False, server_default=text("'custom'"), default="custom" + provider_type: Mapped[ProviderType] = mapped_column( + EnumText(ProviderType, length=40), nullable=False, server_default=text("'custom'"), default=ProviderType.CUSTOM ) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False) last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False) diff --git a/api/models/trigger.py b/api/models/trigger.py index 209345eb84..43d7fc5b24 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -227,7 +227,7 @@ class WorkflowTriggerLog(TypeBase): queue_name: Mapped[str] = mapped_column(String(100), nullable=False) celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(String(255), nullable=False) retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) diff --git a/api/models/web.py b/api/models/web.py index 5f6a7b40bf..a1cc11c375 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -2,13 +2,14 @@ from datetime import datetime from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, String, func +from sqlalchemy import DateTime, func from sqlalchemy.orm import Mapped, mapped_column from .base import TypeBase from .engine import db +from .enums import CreatorUserRole from .model import Message -from .types import StringUUID +from .types import EnumText, StringUUID class SavedMessage(TypeBase): @@ -24,7 +25,9 @@ class SavedMessage(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'") + ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, @@ -50,8 +53,8 @@ class PinnedConversation(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) - created_by_role: Mapped[str] = mapped_column( - String(255), + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'"), ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 21b899eeda..8c62292079 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -53,7 +53,7 @@ from libs import helper from .account import Account from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db -from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType +from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) @@ -141,7 +141,7 @@ class Workflow(Base): # bug id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255), nullable=False) version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(String(255), default="", server_default="") marked_comment: Mapped[str] = mapped_column(String(255), default="", server_default="") @@ -188,7 +188,7 @@ class Workflow(Base): # bug workflow.id = str(uuid4()) workflow.tenant_id = tenant_id workflow.app_id = app_id - workflow.type = type + workflow.type = WorkflowType(type) workflow.version = version workflow.graph = graph workflow.features = features @@ -608,8 +608,8 @@ class WorkflowRun(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(String(255)) - triggered_from: Mapped[str] = mapped_column(String(255)) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255)) + triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(EnumText(WorkflowRunTriggeredFrom, length=255)) version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) @@ -830,7 +830,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[WorkflowNodeExecutionTriggeredFrom] = mapped_column( + EnumText(WorkflowNodeExecutionTriggeredFrom, length=255) + ) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(sa.Integer) predecessor_node_id: Mapped[str | None] = mapped_column(String(255)) @@ -846,7 +848,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) execution_metadata: Mapped[str | None] = mapped_column(LongText) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(String(255)) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -1130,7 +1132,7 @@ class WorkflowAppLog(TypeBase): workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from: Mapped[str] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1204,7 +1206,7 @@ class WorkflowArchiveLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) @@ -1213,7 +1215,9 @@ class WorkflowArchiveLog(TypeBase): run_version: Mapped[str] = mapped_column(String(255), nullable=False) run_status: Mapped[str] = mapped_column(String(255), nullable=False) - run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( + EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False + ) run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) diff --git a/api/pyproject.toml b/api/pyproject.toml index b3202b49c4..64df4d1e77 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -65,7 +65,7 @@ dependencies = [ "pydantic~=2.12.5", "pydantic-extra-types~=2.11.0", "pydantic-settings~=2.13.1", - "pyjwt~=2.11.0", + "pyjwt~=2.12.0", "pypdfium2==5.2.0", "python-docx~=1.2.0", "python-dotenv==1.0.1", diff --git a/api/services/account_service.py b/api/services/account_service.py index f0eac2a522..bd520f54cf 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1089,9 +1089,9 @@ class TenantService: ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: - ta.role = role + ta.role = TenantAccountRole(role) else: - ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole(role)) db.session.add(ta) db.session.commit() @@ -1319,10 +1319,10 @@ class TenantService: db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() ) if current_owner_join: - current_owner_join.role = "admin" + current_owner_join.role = TenantAccountRole.ADMIN # Update the role of the target member - target_member_join.role = new_role + target_member_join.role = TenantAccountRole(new_role) db.session.commit() @staticmethod diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 06f4ccb90e..49ca273442 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -429,17 +429,18 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") + resolved_icon_type: IconType if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]: - icon_type = icon_type_value + resolved_icon_type = IconType(icon_type_value) else: - icon_type = IconType.EMOJI + resolved_icon_type = IconType.EMOJI icon = icon or str(app_data.get("icon", "")) if app: # Update existing app app.name = name or app_data.get("name", app.name) app.description = description or app_data.get("description", app.description) - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id @@ -452,10 +453,10 @@ class AppDslService: app = App() app.id = str(uuid4()) app.tenant_id = account.current_tenant_id - app.mode = app_mode.value + app.mode = app_mode app.name = name or app_data.get("name", "") app.description = description or app_data.get("description", "") - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF") app.enable_site = True @@ -549,7 +550,7 @@ class AppDslService: "kind": "app", "app": { "name": app_model.name, - "mode": app_model.mode, + "mode": app_model.mode.value if isinstance(app_model.mode, AppMode) else app_model.mode, "icon": app_model.icon if app_model.icon_type == "image" else "🤖", "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, "description": app_model.description, diff --git a/api/services/app_service.py b/api/services/app_service.py index aba8954f1a..b5e893c5b5 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -19,7 +19,7 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account -from models.model import App, AppMode, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig, IconType, Site from models.tools import ApiToolProvider from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -254,7 +254,7 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = args["icon_type"] + app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3a7d483a9d..c527c71d7b 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -254,7 +254,7 @@ class DatasetService: dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model_name if embedding_model else None dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None - dataset.permission = permission or DatasetPermissionEnum.ONLY_ME + dataset.permission = DatasetPermissionEnum(permission) if permission else DatasetPermissionEnum.ONLY_ME dataset.provider = provider if summary_index_setting is not None: dataset.summary_index_setting = summary_index_setting diff --git a/api/services/file_service.py b/api/services/file_service.py index e08b78bf4c..ecb30faaa8 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -58,8 +58,9 @@ class FileService: # get file extension extension = os.path.splitext(filename)[1].lstrip(".").lower() - # check if filename contains invalid characters - if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]): + # Only reject path separators here. The original filename is stored as metadata, + # while the storage key is UUID-based. + if any(c in filename for c in ["/", "\\"]): raise ValueError("Filename contains invalid characters") if len(filename) > 200: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index c00c76a826..d85b290534 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -13,6 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery +from models.enums import CreatorUserRole logger = logging.getLogger(__name__) @@ -98,7 +99,7 @@ class HitTestingService: content=json.dumps(dataset_queries), source="hit_testing", source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) db.session.add(dataset_query) @@ -138,7 +139,7 @@ class HitTestingService: content=query, source="hit_testing", source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index ce745a4679..b9a565ec17 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,6 +36,7 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, @@ -48,7 +49,6 @@ from dify_graph.graph_events.base import GraphNodeEventBase from dify_graph.node_events.base import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from dify_graph.runtime import VariablePool from dify_graph.system_variable import SystemVariable @@ -381,7 +381,7 @@ class RagPipelineService: """ # return default block config default_block_configs: list[dict[str, Any]] = [] - for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items(): + for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None if node_type is NodeType.HTTP_REQUEST: @@ -410,12 +410,13 @@ class RagPipelineService: :return: """ node_type_enum = NodeType(node_type) + node_mapping = get_workflow_node_type_classes_mapping() # return default block config - if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + if node_type_enum not in node_mapping: return None - node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + node_class = node_mapping[node_type_enum][LATEST_VERSION] final_filters = dict(filters) if filters else {} if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in final_filters: final_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py index d4a6e87585..64dad7ba52 100644 --- a/api/services/retention/workflow_run/restore_archived_workflow_run.py +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -358,21 +358,19 @@ class WorkflowRunRestore: self, model: type[DeclarativeBase] | Any, ) -> tuple[set[str], set[str], set[str]]: - columns = list(model.__table__.columns) + table = model.__table__ + columns = list(table.columns) + autoincrement_column = getattr(table, "autoincrement_column", None) + + def has_insert_default(column: Any) -> bool: + # SQLAlchemy may set column.autoincrement to "auto" on non-PK columns. + # Only treat the resolved autoincrement column as DB-generated. + return column.default is not None or column.server_default is not None or column is autoincrement_column + column_names = {column.key for column in columns} - required_columns = { - column.key - for column in columns - if not column.nullable - and column.default is None - and column.server_default is None - and not column.autoincrement - } + required_columns = {column.key for column in columns if not column.nullable and not has_insert_default(column)} non_nullable_with_default = { - column.key - for column in columns - if not column.nullable - and (column.default is not None or column.server_default is not None or column.autoincrement) + column.key for column in columns if not column.nullable and has_insert_default(column) } return column_names, required_columns, non_nullable_with_default diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 4dd6c8107b..d0f4f27968 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -3,6 +3,7 @@ from typing import Union from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import SavedMessage from services.message_service import MessageService @@ -54,7 +55,7 @@ class SavedMessageService: saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 560aec2330..e028e3e5e3 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import PinnedConversation from services.conversation_service import ConversationService @@ -84,7 +85,7 @@ class WebConversationService: pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 0153046acc..3acbc93678 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -24,7 +24,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, IconType from models.workflow import Workflow, WorkflowType @@ -72,7 +72,7 @@ class WorkflowConverter: new_app.tenant_id = app_model.tenant_id new_app.name = name or app_model.name + "(workflow)" new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW - new_app.icon_type = icon_type or app_model.icon_type + new_app.icon_type = IconType(icon_type) if icon_type else app_model.icon_type new_app.icon = icon or app_model.icon new_app.icon_background = icon_background or app_model.icon_background new_app.enable_site = app_model.enable_site diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2549cdf180..5b24c356c2 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -14,6 +14,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution from dify_graph.entities.graph_config import NodeConfigDict @@ -34,7 +35,6 @@ from dify_graph.nodes.human_input.entities import ( ) from dify_graph.nodes.human_input.enums import HumanInputFormKind from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from dify_graph.nodes.start.entities import StartNodeData from dify_graph.repositories.human_input_form_repository import FormCreateParams from dify_graph.runtime import GraphRuntimeState, VariablePool @@ -619,7 +619,7 @@ class WorkflowService: """ # return default block config default_block_configs: list[Mapping[str, object]] = [] - for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items(): + for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None if node_type is NodeType.HTTP_REQUEST: @@ -650,12 +650,13 @@ class WorkflowService: :return: """ node_type_enum = NodeType(node_type) + node_mapping = get_workflow_node_type_classes_mapping() # return default block config - if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + if node_type_enum not in node_mapping: return {} - node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + node_class = node_mapping[node_type_enum][LATEST_VERSION] resolved_filters = dict(filters) if filters else {} if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in resolved_filters: resolved_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index d06b8c980b..e7f4e37c75 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -164,7 +164,7 @@ def _record_trigger_failure_log( elapsed_time=0.0, total_tokens=0, total_steps=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, created_at=now, finished_at=now, @@ -179,7 +179,7 @@ def _record_trigger_failure_log( workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, ) session.add(workflow_app_log) @@ -212,7 +212,7 @@ def _record_trigger_failure_log( error=error_message, queue_name=queue_name, retry_count=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, triggered_at=now, finished_at=now, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index db8721e90b..f41118e592 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -94,13 +94,15 @@ def _create_workflow_run_from_execution( workflow_run.tenant_id = tenant_id workflow_run.app_id = app_id workflow_run.workflow_id = execution.workflow_id - workflow_run.type = execution.workflow_type.value - workflow_run.triggered_from = triggered_from.value + from models.workflow import WorkflowType as ModelWorkflowType + + workflow_run.type = ModelWorkflowType(execution.workflow_type.value) + workflow_run.triggered_from = triggered_from workflow_run.version = execution.workflow_version json_converter = WorkflowRuntimeTypeConverter() workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph)) workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) @@ -108,7 +110,7 @@ def _create_workflow_run_from_execution( workflow_run.elapsed_time = execution.elapsed_time workflow_run.total_tokens = execution.total_tokens workflow_run.total_steps = execution.total_steps - workflow_run.created_by_role = creator_user_role.value + workflow_run.created_by_role = creator_user_role workflow_run.created_by = creator_user_id workflow_run.created_at = execution.started_at workflow_run.finished_at = execution.finished_at @@ -121,7 +123,7 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo Update a WorkflowRun database model from a WorkflowExecution domain entity. """ json_converter = WorkflowRuntimeTypeConverter() - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 3f607dc55e..eaafbf99e3 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -98,7 +98,7 @@ def _create_node_execution_from_domain( node_execution.tenant_id = tenant_id node_execution.app_id = app_id node_execution.workflow_id = execution.workflow_id - node_execution.triggered_from = triggered_from.value + node_execution.triggered_from = triggered_from node_execution.workflow_run_id = execution.workflow_execution_id node_execution.index = execution.index node_execution.predecessor_node_id = execution.predecessor_node_id @@ -128,7 +128,7 @@ def _create_node_execution_from_domain( node_execution.status = execution.status.value node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time - node_execution.created_by_role = creator_user_role.value + node_execution.created_by_role = creator_user_role node_execution.created_by = creator_user_id node_execution.created_at = execution.created_at node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index 498ac56d5d..afb6938baa 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -165,7 +165,7 @@ class TestChatMessageApiPermissions: agent_thoughts=[], message_files=[], message_metadata_dict={}, - status="success", + status="normal", error="", parent_message_id=None, ) diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 23cb56d2a5..8a4fb8eda4 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,6 +1,6 @@ import time import uuid -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager @@ -87,17 +87,20 @@ def test_tool_variable_invoke(): } ) - ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"}) + with patch.object( + ToolParameterConfigurationManager, + "decrypt_tool_parameters", + return_value={"format": "%Y-%m-%d %H:%M:%S"}, + ): + node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") - node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") - - # execute node - result = node._run() - for item in result: - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs is not None - assert item.node_run_result.outputs.get("text") is not None + # execute node + result = node._run() + for item in result: + if isinstance(item, StreamCompletedEvent): + assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.node_run_result.outputs is not None + assert item.node_run_result.outputs.get("text") is not None def test_tool_mixed_invoke(): @@ -121,12 +124,15 @@ def test_tool_mixed_invoke(): } ) - ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"}) - - # execute node - result = node._run() - for item in result: - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs is not None - assert item.node_run_result.outputs.get("text") is not None + with patch.object( + ToolParameterConfigurationManager, + "decrypt_tool_parameters", + return_value={"format": "%Y-%m-%d %H:%M:%S"}, + ): + # execute node + result = node._run() + for item in result: + if isinstance(item, StreamCompletedEvent): + assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.node_run_result.outputs is not None + assert item.node_run_result.outputs.get("text") is not None diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index cdf390b327..a60159c66a 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -18,7 +18,7 @@ from faker import Faker from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus @dataclass @@ -47,7 +47,7 @@ class TestTenantIsolatedTaskQueueIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -55,7 +55,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Create tenant tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -101,7 +101,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Create second tenant tenant2 = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant2) db_session_with_containers.commit() @@ -410,7 +410,7 @@ class TestTenantIsolatedTaskQueueCompatibility: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -418,7 +418,7 @@ class TestTenantIsolatedTaskQueueCompatibility: # Create tenant tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 9354a3ac35..cc9596d15f 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -3331,7 +3331,7 @@ class TestRegisterService: TenantService.create_tenant_member(tenant, account, role="normal") # Change tenant status to non-normal - tenant.status = "suspended" + tenant.status = "archive" db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5155d50b0e..5b1a4790f5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -2,6 +2,7 @@ import uuid from unittest.mock import ANY, MagicMock, patch import pytest +import sqlalchemy as sa from faker import Faker from sqlalchemy.orm import Session @@ -492,20 +493,20 @@ class TestAppGenerateService: ) # Manually set invalid mode after creation + # With EnumText, invalid values are rejected at the DB level during flush, + # raising StatementError wrapping ValueError app.mode = "invalid_mode" # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} - # Execute the method under test and expect ValueError - with pytest.raises(ValueError) as exc_info: + # Execute the method under test and expect either ValueError (direct) or + # StatementError (from EnumText validation during autoflush) + with pytest.raises((ValueError, sa.exc.StatementError)): AppGenerateService.generate( app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True ) - # Verify error message - assert "Invalid app mode" in str(exc_info.value) - def test_generate_with_workflow_id_format_error( self, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 6712fe8454..50f5b7a8c0 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -263,6 +263,27 @@ class TestFileService: user=account, ) + def test_upload_file_allows_regular_punctuation_in_filename( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): + """ + Test file upload allows punctuation that is safe when stored as metadata. + """ + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = 'candidate?resume for "dify"|v2:.txt' + content = b"test content" + mimetype = "text/plain" + + upload_file = FileService(engine).upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + assert upload_file.name == filename + def test_upload_file_filename_too_long( self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index cc403ef5a2..dd743d46c2 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -163,7 +163,7 @@ class TestSavedMessageService: answer_unit_price=0.002, total_price=0.003, currency="USD", - status="success", + status="normal", ) db_session_with_containers.add(message) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index bfb23bac68..d8b43efeba 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -62,7 +62,7 @@ class TestWorkflowService: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() @@ -1090,20 +1090,19 @@ class TestWorkflowService: This test ensures that the service correctly handles feature validation for unsupported app modes, preventing invalid operations. + With EnumText, invalid values are rejected at the DB level during flush, + raising StatementError wrapping ValueError. """ # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) app.mode = "invalid_mode" # Invalid mode - db_session_with_containers.commit() + # Act & Assert - EnumText validation rejects invalid values at DB flush + import sqlalchemy as sa - workflow_service = WorkflowService() - features = {"test": "value"} - - # Act & Assert - with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): - workflow_service.validate_features_structure(app_model=app, features=features) + with pytest.raises((ValueError, sa.exc.StatementError)): + db_session_with_containers.commit() def test_update_workflow_success(self, db_session_with_containers: Session): """ diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 8eb881258a..41d9fc8a29 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -110,7 +110,7 @@ class TestCleanDatasetTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index bc0ed3bd2b..69ed5b632d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -48,7 +48,7 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active") + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal") tenant.id = fake.uuid4() tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 8f47b48ae2..6f7d2c28b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -65,7 +65,7 @@ class TestDisableSegmentsFromIndexTask: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.tenant_id tenant.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 3cdec70df7..c0ddc27286 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -118,7 +118,7 @@ class TestSendEmailCodeLoginMailTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index c3a6522e6d..6b5c304884 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -48,7 +48,7 @@ def make_message(): msg.query = "hello" msg.re_sign_file_url_answer = "" msg.user_feedback = MagicMock(rating=None) - msg.status = "success" + msg.status = "normal" msg.error = None return msg diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index f6db55db5b..eb19243225 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -200,10 +200,13 @@ class TestPluginUploadFromPkgApi: app.test_request_context("/", data=data, content_type="multipart/form-data"), patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), + patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock, ): with pytest.raises(ValueError): method(api) + upload_pkg_mock.assert_not_called() + class TestPluginInstallFromPkgApi: def test_install_from_pkg(self, app): @@ -444,10 +447,13 @@ class TestPluginUploadFromBundleApi: ), patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), + patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock, ): with pytest.raises(ValueError): method(api) + upload_bundle_mock.assert_not_called() + class TestPluginInstallFromGithubApi: def test_success(self, app): diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 1c096bfbcf..2bb425cdba 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -137,7 +137,7 @@ def test_message_list_mapping(app: Flask) -> None: {"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"}, message_file_obj, ], - status="success", + status="normal", error=None, message_metadata_dict={"meta": "value"}, extra_contents=[ diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index 9518c61202..f6d1edbaf0 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -5,8 +5,8 @@ import pytest from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit +from core.agent.errors import AgentMaxIterationError from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.nodes.agent.exc import AgentMaxIterationError class DummyRunner(CotAgentRunner): diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index 8843a8d505..299c9b31d2 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock import pytest +from core.agent.errors import AgentMaxIterationError from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueMessageFileEvent @@ -14,7 +15,6 @@ from dify_graph.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.nodes.agent.exc import AgentMaxIterationError # ============================== # Dummy Helper Classes diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index d8afd3b10a..108b740344 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -105,10 +105,10 @@ class TestWorkflowBasedAppRunner: from core.app.apps import workflow_app_runner - monkeypatch.setitem( - workflow_app_runner.NODE_TYPE_CLASSES_MAPPING, - NodeType.START, - {"1": _NodeCls}, + monkeypatch.setattr( + workflow_app_runner, + "resolve_workflow_node_class", + lambda **_kwargs: _NodeCls, ) monkeypatch.setattr( "core.app.apps.workflow_app_runner.load_into_variable_pool", diff --git a/api/tests/unit_tests/core/datasource/test_file_upload.py b/api/tests/unit_tests/core/datasource/test_file_upload.py index ad86190e00..63b86e64fc 100644 --- a/api/tests/unit_tests/core/datasource/test_file_upload.py +++ b/api/tests/unit_tests/core/datasource/test_file_upload.py @@ -35,7 +35,7 @@ TEST COVERAGE OVERVIEW: - Tests hash consistency and determinism 6. Invalid Filename Handling (TestInvalidFilenameHandling) - - Validates rejection of filenames with invalid characters (/, \\, :, *, ?, ", <, >, |) + - Validates rejection of filenames with path separators (/, \\) - Tests filename length truncation (max 200 characters) - Prevents path traversal attacks - Handles edge cases like empty filenames @@ -535,30 +535,23 @@ class TestInvalidFilenameHandling: @pytest.mark.parametrize( "invalid_char", - ["/", "\\", ":", "*", "?", '"', "<", ">", "|"], + ["/", "\\"], ) def test_filename_contains_invalid_characters(self, invalid_char): """Test detection of invalid characters in filename. - Security-critical test that validates rejection of dangerous filename characters. + Security-critical test that validates rejection of path separators. These characters are blocked because they: - / and \\ : Directory separators, could enable path traversal - - : : Drive letter separator on Windows, reserved character - - * and ? : Wildcards, could cause issues in file operations - - " : Quote character, could break command-line operations - - < and > : Redirection operators, command injection risk - - | : Pipe operator, command injection risk Blocking these characters prevents: - Path traversal attacks (../../etc/passwd) - - Command injection - - File system corruption - - Cross-platform compatibility issues + - ZIP entry traversal issues + - Ambiguous path handling """ # Arrange - Create filename with invalid character filename = f"test{invalid_char}file.txt" - # Define complete list of invalid characters - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act - Check if filename contains any invalid character has_invalid_char = any(c in filename for c in invalid_chars) @@ -570,7 +563,7 @@ class TestInvalidFilenameHandling: """Test that valid filenames pass validation.""" # Arrange filename = "valid_file-name_123.txt" - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act has_invalid_char = any(c in filename for c in invalid_chars) @@ -578,6 +571,16 @@ class TestInvalidFilenameHandling: # Assert assert has_invalid_char is False + @pytest.mark.parametrize("safe_char", [":", "*", "?", '"', "<", ">", "|"]) + def test_filename_allows_safe_metadata_characters(self, safe_char): + """Test that non-separator punctuation remains allowed in filenames.""" + filename = f"candidate{safe_char}resume.txt" + invalid_chars = ["/", "\\"] + + has_invalid_char = any(c in filename for c in invalid_chars) + + assert has_invalid_char is False + def test_extremely_long_filename_truncation(self): """Test handling of extremely long filenames.""" # Arrange @@ -904,7 +907,7 @@ class TestFilenameValidation: """Test that filenames with spaces are handled correctly.""" # Arrange filename = "my document with spaces.pdf" - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act - Check for invalid characters has_invalid = any(c in filename for c in invalid_chars) @@ -921,7 +924,7 @@ class TestFilenameValidation: "مستند.txt", # Arabic "ファイル.jpg", # Japanese ] - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act & Assert - Unicode should be allowed for filename in unicode_filenames: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b90c4935af..de3ccc4518 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -3730,7 +3730,7 @@ class TestDatasetRetrievalAdditionalHelpers: attachment_ids=None, dataset_ids=["d1"], app_id="a1", - user_from="web", + user_from="account", user_id="u1", ) mock_session.add_all.assert_not_called() @@ -3740,7 +3740,7 @@ class TestDatasetRetrievalAdditionalHelpers: attachment_ids=["f1"], dataset_ids=["d1", "d2"], app_id="a1", - user_from="web", + user_from="account", user_id="u1", ) mock_session.add_all.assert_called() diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py index 4d59affb99..5ceaa08893 100644 --- a/api/tests/unit_tests/core/tools/utils/test_configuration.py +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -5,6 +5,7 @@ from typing import Any from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.tool_parameter_cache import ToolParameterCache from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject @@ -112,37 +113,38 @@ def test_encrypt_tool_parameters(): def test_decrypt_tool_parameters_cache_hit_and_miss(): manager = _build_manager() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: - cache = cache_cls.return_value - cache.get.return_value = {"secret": "cached"} + with ( + patch.object(ToolParameterCache, "get", return_value={"secret": "cached"}), + patch.object(ToolParameterCache, "set") as mock_set, + ): assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"} - cache.set.assert_not_called() + mock_set.assert_not_called() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: - cache = cache_cls.return_value - cache.get.return_value = None - with patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"): - decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"}) - - assert decrypted["secret"] == "dec" - cache.set.assert_called_once() + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch.object(ToolParameterCache, "set") as mock_set, + patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"}) + assert decrypted["secret"] == "dec" + mock_set.assert_called_once() def test_delete_tool_parameters_cache(): manager = _build_manager() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: + with patch.object(ToolParameterCache, "delete") as mock_delete: manager.delete_tool_parameters_cache() - cache_cls.return_value.delete.assert_called_once() + mock_delete.assert_called_once() def test_configuration_manager_decrypt_suppresses_errors(): manager = _build_manager() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: - cache = cache_cls.return_value - cache.get.return_value = None - with patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")): - decrypted = manager.decrypt_tool_parameters({"secret": "enc"}) + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc"}) # decryption failure is suppressed, original value is retained. assert decrypted["secret"] == "enc" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 34e714a227..9e3574266c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,10 +11,10 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance +from core.workflow.nodes.agent import AgentNode from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.agent import AgentNode from dify_graph.nodes.code import CodeNode from dify_graph.nodes.document_extractor import DocumentExtractorNode from dify_graph.nodes.http_request import HttpRequestNode @@ -79,6 +79,14 @@ class MockNodeMixin: if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + if isinstance(self, AgentNode): + presentation_provider = MagicMock() + presentation_provider.get_icon.return_value = None + kwargs.setdefault("strategy_resolver", MagicMock()) + kwargs.setdefault("presentation_provider", presentation_provider) + kwargs.setdefault("runtime_support", MagicMock()) + kwargs.setdefault("message_transformer", MagicMock()) + super().__init__( id=id, config=config, diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 4a5f561c22..934e29546c 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -260,7 +260,11 @@ class TestDifyNodeFactoryCreateNode: factory.create_node({"id": "node-id", "data": {"type": "missing"}}) def test_rejects_missing_class_mapping(self, monkeypatch, factory): - monkeypatch.setattr(node_factory, "NODE_TYPE_CLASSES_MAPPING", {}) + monkeypatch.setattr( + node_factory, + "resolve_workflow_node_class", + MagicMock(side_effect=ValueError("No class mapping found for node type: start")), + ) with pytest.raises(ValueError, match="No class mapping found for node type: start"): factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}}) @@ -268,8 +272,8 @@ class TestDifyNodeFactoryCreateNode: def test_rejects_missing_latest_class(self, monkeypatch, factory): monkeypatch.setattr( node_factory, - "NODE_TYPE_CLASSES_MAPPING", - {NodeType.START: {node_factory.LATEST_VERSION: None}}, + "resolve_workflow_node_class", + MagicMock(side_effect=ValueError("No latest version class found for node type: start")), ) with pytest.raises(ValueError, match="No latest version class found for node type: start"): @@ -281,13 +285,8 @@ class TestDifyNodeFactoryCreateNode: matched_node_class = MagicMock(return_value=matched_node) monkeypatch.setattr( node_factory, - "NODE_TYPE_CLASSES_MAPPING", - { - NodeType.START: { - node_factory.LATEST_VERSION: latest_node_class, - "9": matched_node_class, - } - }, + "resolve_workflow_node_class", + MagicMock(return_value=matched_node_class), ) result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) @@ -306,8 +305,8 @@ class TestDifyNodeFactoryCreateNode: latest_node_class = MagicMock(return_value=latest_node) monkeypatch.setattr( node_factory, - "NODE_TYPE_CLASSES_MAPPING", - {NodeType.START: {node_factory.LATEST_VERSION: latest_node_class}}, + "resolve_workflow_node_class", + MagicMock(return_value=latest_node_class), ) result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) @@ -338,8 +337,8 @@ class TestDifyNodeFactoryCreateNode: constructor = MagicMock(name=constructor_name, return_value=created_node) monkeypatch.setattr( node_factory, - "NODE_TYPE_CLASSES_MAPPING", - {node_type: {node_factory.LATEST_VERSION: constructor}}, + "resolve_workflow_node_class", + MagicMock(return_value=constructor), ) if constructor_name == "HumanInputNode": @@ -411,8 +410,8 @@ class TestDifyNodeFactoryCreateNode: constructor = MagicMock(name=constructor_name, return_value=created_node) monkeypatch.setattr( node_factory, - "NODE_TYPE_CLASSES_MAPPING", - {node_type: {node_factory.LATEST_VERSION: constructor}}, + "resolve_workflow_node_class", + MagicMock(return_value=constructor), ) llm_init_kwargs = { "credentials_provider": sentinel.credentials_provider, diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index fe211fb76a..68e42894fc 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -400,8 +400,8 @@ class TestWorkflowEntryHelpers: def test_run_free_node_rejects_missing_node_class(self, monkeypatch): monkeypatch.setattr( workflow_entry, - "NODE_TYPE_CLASSES_MAPPING", - {NodeType.PARAMETER_EXTRACTOR: {"1": None}}, + "resolve_workflow_node_class", + MagicMock(return_value=None), ) with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"): @@ -432,8 +432,8 @@ class TestWorkflowEntryHelpers: dify_node_factory.create_node.return_value = FakeNode() monkeypatch.setattr( workflow_entry, - "NODE_TYPE_CLASSES_MAPPING", - {NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}}, + "resolve_workflow_node_class", + MagicMock(return_value=FakeNodeClass), ) with ( @@ -518,8 +518,8 @@ class TestWorkflowEntryHelpers: dify_node_factory.create_node.return_value = FakeNode() monkeypatch.setattr( workflow_entry, - "NODE_TYPE_CLASSES_MAPPING", - {NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}}, + "resolve_workflow_node_class", + MagicMock(return_value=FakeNodeClass), ) with ( diff --git a/api/tests/unit_tests/models/test_account_models.py b/api/tests/unit_tests/models/test_account_models.py index cc311d447f..1726fc2e8b 100644 --- a/api/tests/unit_tests/models/test_account_models.py +++ b/api/tests/unit_tests/models/test_account_models.py @@ -98,7 +98,7 @@ class TestAccountModelValidation: ) # Assert - assert account.status == "active" + assert account.status == AccountStatus.ACTIVE def test_account_get_status_method(self): """Test the get_status method returns AccountStatus enum.""" @@ -106,7 +106,7 @@ class TestAccountModelValidation: account = Account( name="Test User", email="test@example.com", - status="pending", + status=AccountStatus.PENDING, ) # Act diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py index 6097bcbd61..4bfdba87a0 100644 --- a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -13,6 +13,7 @@ from datetime import datetime from unittest.mock import Mock, create_autospec, patch import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table from libs.archive_storage import ArchiveStorageNotConfiguredError from models.trigger import WorkflowTriggerLog @@ -127,10 +128,41 @@ class WorkflowRunRestoreTestDataFactory: if tables_data is None: tables_data = { - "workflow_runs": [{"id": "run-123", "tenant_id": "tenant-123"}], + "workflow_runs": [ + { + "id": "run-123", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + } + ], "workflow_app_logs": [ - {"id": "log-1", "workflow_run_id": "run-123"}, - {"id": "log-2", "workflow_run_id": "run-123"}, + { + "id": "log-1", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "workflow_run_id": "run-123", + "created_from": "app", + "created_by_role": "account", + "created_by": "user-123", + }, + { + "id": "log-2", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "workflow_run_id": "run-123", + "created_from": "app", + "created_by_role": "account", + "created_by": "user-123", + }, ], } @@ -406,14 +438,48 @@ class TestGetModelColumnInfo: assert "created_by" in column_names assert "status" in column_names - # WorkflowRun model has no required columns (all have defaults or are auto-generated) - assert required_columns == set() + # Columns without defaults should be required for restore inserts. + assert { + "tenant_id", + "app_id", + "workflow_id", + "type", + "triggered_from", + "version", + "status", + "created_by_role", + "created_by", + }.issubset(required_columns) + assert "id" not in required_columns + assert "created_at" not in required_columns # Check columns with defaults or server defaults assert "id" in non_nullable_with_default assert "created_at" in non_nullable_with_default assert "elapsed_time" in non_nullable_with_default assert "total_tokens" in non_nullable_with_default + assert "tenant_id" not in non_nullable_with_default + + def test_non_pk_auto_autoincrement_column_is_still_required(self): + """`autoincrement='auto'` should not mark non-PK columns as defaulted.""" + restore = WorkflowRunRestore() + + test_table = Table( + "test_autoincrement", + MetaData(), + Column("id", Integer, primary_key=True, autoincrement=True), + Column("required_field", String(255), nullable=False), + Column("defaulted_field", String(255), nullable=False, default="x"), + ) + + class MockModel: + __table__ = test_table + + _, required_columns, non_nullable_with_default = restore._get_model_column_info(MockModel) + + assert required_columns == {"required_field"} + assert "id" in non_nullable_with_default + assert "defaulted_field" in non_nullable_with_default # --------------------------------------------------------------------------- @@ -465,7 +531,32 @@ class TestRestoreTableRecords: mock_stmt.on_conflict_do_nothing.return_value = mock_stmt mock_pg_insert.return_value = mock_stmt - records = [{"id": "test1", "tenant_id": "tenant-123"}, {"id": "test2", "tenant_id": "tenant-123"}] + records = [ + { + "id": "test1", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + }, + { + "id": "test2", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + }, + ] result = restore._restore_table_records(mock_session, "workflow_runs", records, schema_version="1.0") @@ -477,8 +568,7 @@ class TestRestoreTableRecords: restore = WorkflowRunRestore() mock_session = Mock() - # Since WorkflowRun has no required columns, we need to test with a different model - # Let's test with a mock model that has required columns + # Use a dedicated mock model to isolate required-column validation behavior. mock_model = Mock() # Mock a required column @@ -965,6 +1055,13 @@ class TestIntegration: "id": "run-123", "tenant_id": "tenant-123", "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", "created_at": "2024-01-01T12:00:00", } ], diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 8820a1acc0..5ce0e6f140 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -1001,12 +1001,12 @@ class TestWorkflowService: Used by the UI to populate the node palette and provide sensible defaults when users add new nodes to their workflow. """ - with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping: + with patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping: # Mock node class with default config mock_node_class = MagicMock() mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}} - mock_mapping.items.return_value = [(NodeType.LLM, {"latest": mock_node_class})] + mock_mapping.return_value = {NodeType.LLM: {"latest": mock_node_class}} with patch("services.workflow_service.LATEST_VERSION", "latest"): result = workflow_service.get_default_block_configs() @@ -1025,7 +1025,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1036,10 +1036,10 @@ class TestWorkflowService: mock_http_node_class.get_default_config.return_value = {"type": "http-request", "config": {}} mock_llm_node_class = MagicMock() mock_llm_node_class.get_default_config.return_value = {"type": "llm", "config": {}} - mock_mapping.items.return_value = [ - (NodeType.HTTP_REQUEST, {"latest": mock_http_node_class}), - (NodeType.LLM, {"latest": mock_llm_node_class}), - ] + mock_mapping.return_value = { + NodeType.HTTP_REQUEST: {"latest": mock_http_node_class}, + NodeType.LLM: {"latest": mock_llm_node_class}, + } result = workflow_service.get_default_block_configs() @@ -1060,7 +1060,7 @@ class TestWorkflowService: This includes default values for all required and optional parameters. """ with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), ): # Mock node class with default config @@ -1069,8 +1069,7 @@ class TestWorkflowService: mock_node_class.get_default_config.return_value = mock_config # Create a mock mapping that includes NodeType.LLM - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {NodeType.LLM: {"latest": mock_node_class}} result = workflow_service.get_default_block_config(NodeType.LLM.value) @@ -1079,9 +1078,8 @@ class TestWorkflowService: def test_get_default_block_config_invalid_node_type(self, workflow_service): """Test get_default_block_config returns empty dict for invalid node type.""" - with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping: - # Mock mapping to not contain the node type - mock_mapping.__contains__.return_value = False + with patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping: + mock_mapping.return_value = {} # Use a valid NodeType but one that's not in the mapping result = workflow_service.get_default_block_config(NodeType.LLM.value) @@ -1100,7 +1098,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1110,8 +1108,7 @@ class TestWorkflowService: mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {NodeType.HTTP_REQUEST: {"latest": mock_node_class}} result = workflow_service.get_default_block_config(NodeType.HTTP_REQUEST.value) @@ -1132,15 +1129,14 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch("services.workflow_service.build_http_request_config") as mock_build_config, ): mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {NodeType.HTTP_REQUEST: {"latest": mock_node_class}} result = workflow_service.get_default_block_config( NodeType.HTTP_REQUEST.value, @@ -1155,8 +1151,8 @@ class TestWorkflowService: def test_get_default_block_config_http_request_malformed_config_raises_value_error(self, workflow_service): with ( patch( - "services.workflow_service.NODE_TYPE_CLASSES_MAPPING", - {NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}}, + "services.workflow_service.get_workflow_node_type_classes_mapping", + return_value={NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}}, ), patch("services.workflow_service.LATEST_VERSION", "latest"), ): diff --git a/api/uv.lock b/api/uv.lock index 1d03d5a360..555a980d97 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1698,7 +1698,7 @@ requires-dist = [ { name = "pydantic", specifier = "~=2.12.5" }, { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, - { name = "pyjwt", specifier = "~=2.11.0" }, + { name = "pyjwt", specifier = "~=2.12.0" }, { name = "pypdfium2", specifier = "==5.2.0" }, { name = "python-docx", specifier = "~=1.2.0" }, { name = "python-dotenv", specifier = "==1.0.1" }, @@ -2051,11 +2051,11 @@ wheels = [ [[package]] name = "fickling" -version = "0.1.9" +version = "0.1.10" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/25/bd/ca7127df0201596b0b30f9ab3d36e565bb9d6f8f4da1560758b817e81b65/fickling-0.1.9.tar.gz", hash = "sha256:bb518c2fd833555183bc46b6903bb4022f3ae0436a69c3fb149cfc75eebaac33", size = 336940, upload-time = "2026-03-03T23:32:19.449Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9f/06/1818b8f52267599e54041349c553d5894e17ec8a539a246eb3f9eaf05629/fickling-0.1.10.tar.gz", hash = "sha256:8c8b76abd29936f1a5932e4087b8c8becb2d7ab1cf08549e63519ebcb2f71644", size = 338062, upload-time = "2026-03-13T16:34:29.287Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/92/49/c597bad508c74917901432b41ae5a8f036839a7fb8d0d29a89765f5d3643/fickling-0.1.9-py3-none-any.whl", hash = "sha256:ccc3ce3b84733406ade2fe749717f6e428047335157c6431eefd3e7e970a06d1", size = 52786, upload-time = "2026-03-03T23:32:17.533Z" }, + { url = "https://files.pythonhosted.org/packages/05/86/620960dff970da5311f05e25fc045dac8495557d51030e5a0827084b18fd/fickling-0.1.10-py3-none-any.whl", hash = "sha256:962c35c38ece1b3632fc119c0f4cb1eebc02dc6d65bfd93a1803afd42ca91d25", size = 52853, upload-time = "2026-03-13T16:34:27.821Z" }, ] [[package]] @@ -5089,11 +5089,11 @@ wheels = [ [[package]] name = "pyjwt" -version = "2.11.0" +version = "2.12.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/10/e8192be5f38f3e8e7e046716de4cae33d56fd5ae08927a823bb916be36c1/pyjwt-2.12.0.tar.gz", hash = "sha256:2f62390b667cd8257de560b850bb5a883102a388829274147f1d724453f8fb02", size = 102511, upload-time = "2026-03-12T17:15:30.831Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, + { url = "https://files.pythonhosted.org/packages/15/70/70f895f404d363d291dcf62c12c85fdd47619ad9674ac0f53364d035925a/pyjwt-2.12.0-py3-none-any.whl", hash = "sha256:9bb459d1bdd0387967d287f5656bf7ec2b9a26645d1961628cda1764e087fd6e", size = 29700, upload-time = "2026-03-12T17:15:29.257Z" }, ] [package.optional-dependencies] diff --git a/web/__tests__/share/text-generation-index-flow.test.tsx b/web/__tests__/share/text-generation-index-flow.test.tsx new file mode 100644 index 0000000000..3292474bec --- /dev/null +++ b/web/__tests__/share/text-generation-index-flow.test.tsx @@ -0,0 +1,235 @@ +import type { AccessMode } from '@/models/access-control' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import * as React from 'react' +import TextGeneration from '@/app/components/share/text-generation' + +const useSearchParamsMock = vi.fn(() => new URLSearchParams()) + +vi.mock('next/navigation', () => ({ + useSearchParams: () => useSearchParamsMock(), +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: vi.fn(() => 'pc'), + MediaType: { pc: 'pc', pad: 'pad', mobile: 'mobile' }, +})) + +vi.mock('@/hooks/use-app-favicon', () => ({ + useAppFavicon: vi.fn(), +})) + +vi.mock('@/hooks/use-document-title', () => ({ + default: vi.fn(), +})) + +vi.mock('@/i18n-config/client', () => ({ + changeLanguage: vi.fn(() => Promise.resolve()), +})) + +vi.mock('@/app/components/share/text-generation/run-once', () => ({ + default: ({ + inputs, + onInputsChange, + onSend, + runControl, + }: { + inputs: Record + onInputsChange: (inputs: Record) => void + onSend: () => void + runControl?: { isStopping: boolean } | null + }) => ( +
+ {String(inputs.name ?? '')} + + + {runControl ? 'stop-ready' : 'idle'} +
+ ), +})) + +vi.mock('@/app/components/share/text-generation/run-batch', () => ({ + default: ({ onSend }: { onSend: (data: string[][]) => void }) => ( + + ), +})) + +vi.mock('@/app/components/app/text-generate/saved-items', () => ({ + default: ({ list }: { list: { id: string }[] }) =>
{list.length}
, +})) + +vi.mock('@/app/components/share/text-generation/menu-dropdown', () => ({ + default: () =>
, +})) + +vi.mock('@/app/components/share/text-generation/result', () => { + const MockResult = ({ + isCallBatchAPI, + onRunControlChange, + onRunStart, + taskId, + }: { + isCallBatchAPI: boolean + onRunControlChange?: (control: { onStop: () => void, isStopping: boolean } | null) => void + onRunStart: () => void + taskId?: number + }) => { + const runControlRef = React.useRef(false) + + React.useEffect(() => { + onRunStart() + }, [onRunStart]) + + React.useEffect(() => { + if (!isCallBatchAPI && !runControlRef.current) { + runControlRef.current = true + onRunControlChange?.({ onStop: vi.fn(), isStopping: false }) + } + }, [isCallBatchAPI, onRunControlChange]) + + return
+ } + + return { + default: MockResult, + } +}) + +const fetchSavedMessageMock = vi.fn() + +vi.mock('@/service/share', async () => { + const actual = await vi.importActual('@/service/share') + return { + ...actual, + fetchSavedMessage: (...args: Parameters) => fetchSavedMessageMock(...args), + removeMessage: vi.fn(), + saveMessage: vi.fn(), + } +}) + +const mockSystemFeatures = { + branding: { + enabled: false, + workspace_logo: null, + }, +} + +const mockWebAppState = { + appInfo: { + app_id: 'app-123', + site: { + title: 'Text Generation', + description: 'Share description', + default_language: 'en-US', + icon_type: 'emoji', + icon: 'robot', + icon_background: '#fff', + icon_url: '', + }, + custom_config: { + remove_webapp_brand: false, + replace_webapp_logo: '', + }, + }, + appParams: { + user_input_form: [ + { + 'text-input': { + label: 'Name', + variable: 'name', + required: true, + max_length: 48, + default: '', + hide: false, + }, + }, + ], + more_like_this: { + enabled: true, + }, + file_upload: { + enabled: false, + number_limits: 2, + detail: 'low', + allowed_upload_methods: ['local_file'], + }, + text_to_speech: { + enabled: true, + }, + system_parameters: { + image_file_size_limit: 10, + }, + }, + webAppAccessMode: 'public' as AccessMode, +} + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: typeof mockSystemFeatures }) => unknown) => + selector({ systemFeatures: mockSystemFeatures }), +})) + +vi.mock('@/context/web-app-context', () => ({ + useWebAppStore: (selector: (state: typeof mockWebAppState) => unknown) => selector(mockWebAppState), +})) + +describe('TextGeneration', () => { + beforeEach(() => { + vi.clearAllMocks() + useSearchParamsMock.mockReturnValue(new URLSearchParams()) + fetchSavedMessageMock.mockResolvedValue({ + data: [{ id: 'saved-1' }, { id: 'saved-2' }], + }) + }) + + it('should switch between create, batch, and saved tabs after app state loads', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('run-once-mock')).toBeInTheDocument() + }) + expect(screen.getByTestId('run-once-input-name')).toHaveTextContent('') + + fireEvent.click(screen.getByRole('button', { name: 'change-inputs' })) + await waitFor(() => { + expect(screen.getByTestId('run-once-input-name')).toHaveTextContent('Gamma') + }) + + fireEvent.click(screen.getByTestId('tab-header-item-batch')) + expect(screen.getByRole('button', { name: 'run-batch' })).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('tab-header-item-saved')) + expect(screen.getByTestId('saved-items-mock')).toHaveTextContent('2') + + fireEvent.click(screen.getByTestId('tab-header-item-create')) + expect(screen.getByTestId('run-once-mock')).toBeInTheDocument() + }) + + it('should wire single-run stop control and clear it when batch execution starts', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('run-once-mock')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByRole('button', { name: 'run-once' })) + await waitFor(() => { + expect(screen.getByText('stop-ready')).toBeInTheDocument() + }) + expect(screen.getByTestId('result-single')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('tab-header-item-batch')) + fireEvent.click(screen.getByRole('button', { name: 'run-batch' })) + await waitFor(() => { + expect(screen.getByText('idle')).toBeInTheDocument() + }) + expect(screen.getByTestId('result-task-1')).toBeInTheDocument() + expect(screen.getByTestId('result-task-2')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index d98e02ad57..b849b4f015 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -6,7 +6,7 @@ import type { ModalContextState } from '@/context/modal-context' import type { ProviderContextState } from '@/context/provider-context' import type { AppDetailResponse } from '@/models/app' import type { AppSSO } from '@/types/app' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { Plan } from '@/app/components/billing/type' import { baseProviderContextValue } from '@/context/provider-context' import { AppModeEnum } from '@/types/app' @@ -131,6 +131,10 @@ describe('SettingsModal', () => { }) }) + afterEach(() => { + vi.useRealTimers() + }) + it('should render the modal and expose the expanded settings section', async () => { renderSettingsModal() expect(screen.getByText('appOverview.overview.appInfo.settings.title')).toBeInTheDocument() @@ -212,4 +216,54 @@ describe('SettingsModal', () => { })) expect(mockOnClose).toHaveBeenCalled() }) + + it('should clear the delayed hide-more timer when the modal unmounts after closing', () => { + vi.useFakeTimers() + const clearTimeoutSpy = vi.spyOn(globalThis, 'clearTimeout') + const { unmount } = renderSettingsModal() + + fireEvent.click(screen.getByText('appOverview.overview.appInfo.settings.more.entry')) + fireEvent.click(screen.getByText('common.operation.cancel')) + unmount() + + expect(clearTimeoutSpy).toHaveBeenCalled() + vi.runAllTimers() + }) + + it('should replace the pending hide-more timer and clear the ref after the timeout completes', async () => { + const hideCallbacks: Array<() => void> = [] + const originalSetTimeout = globalThis.setTimeout + const setTimeoutSpy = vi.spyOn(globalThis, 'setTimeout').mockImplementation((( + callback: TimerHandler, + delay?: number, + ...args: unknown[] + ) => { + if (delay === 200) { + hideCallbacks.push(() => { + if (typeof callback === 'function') + callback(...args) + }) + return hideCallbacks.length as unknown as ReturnType + } + + return originalSetTimeout(callback, delay, ...args) + }) as unknown as typeof setTimeout) + const clearTimeoutSpy = vi.spyOn(globalThis, 'clearTimeout') + renderSettingsModal() + + act(() => { + fireEvent.click(screen.getByText('common.operation.cancel')) + fireEvent.click(screen.getByText('common.operation.cancel')) + }) + + expect(clearTimeoutSpy).toHaveBeenCalled() + expect(hideCallbacks.length).toBeGreaterThanOrEqual(2) + + act(() => { + hideCallbacks.at(-1)?.() + }) + + setTimeoutSpy.mockRestore() + clearTimeoutSpy.mockRestore() + }) }) diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 20461dda7e..f7c9e309ab 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -6,7 +6,7 @@ import type { AppIconType, AppSSO, Language } from '@/types/app' import { RiArrowRightSLine, RiCloseLine } from '@remixicon/react' import Link from 'next/link' import * as React from 'react' -import { useCallback, useEffect, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { Trans, useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import AppIcon from '@/app/components/base/app-icon' @@ -99,6 +99,7 @@ const SettingsModal: FC = ({ const [language, setLanguage] = useState(default_language) const [saveLoading, setSaveLoading] = useState(false) const { t } = useTranslation() + const hideMoreTimerRef = useRef | null>(null) const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [appIcon, setAppIcon] = useState( @@ -137,10 +138,22 @@ const SettingsModal: FC = ({ : { type: 'emoji', icon, background: icon_background! }) }, [appInfo, chat_color_theme, chat_color_theme_inverted, copyright, custom_disclaimer, default_language, description, icon, icon_background, icon_type, icon_url, privacy_policy, show_workflow_steps, title, use_icon_as_answer_icon]) + useEffect(() => { + return () => { + if (hideMoreTimerRef.current) { + clearTimeout(hideMoreTimerRef.current) + hideMoreTimerRef.current = null + } + } + }, []) + const onHide = () => { onClose() - setTimeout(() => { + if (hideMoreTimerRef.current) + clearTimeout(hideMoreTimerRef.current) + hideMoreTimerRef.current = setTimeout(() => { setIsShowMore(false) + hideMoreTimerRef.current = null }, 200) } diff --git a/web/app/components/share/text-generation/__tests__/text-generation-result-panel.spec.tsx b/web/app/components/share/text-generation/__tests__/text-generation-result-panel.spec.tsx new file mode 100644 index 0000000000..60c109feec --- /dev/null +++ b/web/app/components/share/text-generation/__tests__/text-generation-result-panel.spec.tsx @@ -0,0 +1,190 @@ +import type { PromptConfig } from '@/models/debug' +import type { SiteInfo } from '@/models/share' +import type { VisionSettings } from '@/types/app' +import { fireEvent, render, screen } from '@testing-library/react' +import { AppSourceType } from '@/service/share' +import { Resolution, TransferMethod } from '@/types/app' +import TextGenerationResultPanel from '../text-generation-result-panel' +import { TaskStatus } from '../types' + +const resPropsSpy = vi.fn() +const resDownloadPropsSpy = vi.fn() + +vi.mock('@/app/components/share/text-generation/result', () => ({ + default: (props: Record) => { + resPropsSpy(props) + return
+ }, +})) + +vi.mock('@/app/components/share/text-generation/run-batch/res-download', () => ({ + default: (props: Record) => { + resDownloadPropsSpy(props) + return
+ }, +})) + +const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + { key: 'name', name: 'Name', type: 'string', required: true }, + ], +} + +const siteInfo: SiteInfo = { + title: 'Text Generation', + description: 'Share description', + icon_type: 'emoji', + icon: 'robot', +} + +const visionConfig: VisionSettings = { + enabled: false, + number_limits: 2, + detail: Resolution.low, + transfer_methods: [TransferMethod.local_file], +} + +const batchTasks = [ + { + id: 1, + status: TaskStatus.completed, + params: { inputs: { name: 'Alpha' } }, + }, + { + id: 2, + status: TaskStatus.failed, + params: { inputs: { name: 'Beta' } }, + }, +] + +const baseProps = { + allFailedTaskList: [], + allSuccessTaskList: [], + allTaskList: batchTasks, + appId: 'app-123', + appSourceType: AppSourceType.webApp, + completionFiles: [], + controlRetry: 88, + controlSend: 77, + controlStopResponding: 66, + exportRes: [{ 'Name': 'Alpha', 'share.generation.completionResult': 'Done' }], + handleCompleted: vi.fn(), + handleRetryAllFailedTask: vi.fn(), + handleSaveMessage: vi.fn(async () => {}), + inputs: { name: 'Alice' }, + isCallBatchAPI: false, + isPC: true, + isShowResultPanel: true, + isWorkflow: false, + moreLikeThisEnabled: true, + noPendingTask: true, + onHideResultPanel: vi.fn(), + onRunControlChange: vi.fn(), + onRunStart: vi.fn(), + onShowResultPanel: vi.fn(), + promptConfig, + resultExisted: true, + showTaskList: batchTasks, + siteInfo, + textToSpeechEnabled: true, + visionConfig, +} + +describe('TextGenerationResultPanel', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render a single result in run-once mode and pass non-batch props', () => { + render() + + expect(screen.getByTestId('res-single')).toBeInTheDocument() + expect(resPropsSpy).toHaveBeenCalledWith(expect.objectContaining({ + appId: 'app-123', + appSourceType: AppSourceType.webApp, + completionFiles: [], + controlSend: 77, + controlStopResponding: 66, + hideInlineStopButton: true, + inputs: { name: 'Alice' }, + isCallBatchAPI: false, + moreLikeThisEnabled: true, + taskId: undefined, + })) + expect(screen.queryByTestId('res-download-mock')).not.toBeInTheDocument() + }) + + it('should render batch results, download entry, loading area, and retry banner', () => { + const handleRetryAllFailedTask = vi.fn() + + render( + , + ) + + expect(screen.getByTestId('res-1')).toBeInTheDocument() + expect(screen.getByTestId('res-2')).toBeInTheDocument() + expect(resPropsSpy).toHaveBeenNthCalledWith(1, expect.objectContaining({ + inputs: { name: 'Alpha' }, + isError: false, + controlRetry: 0, + taskId: 1, + onRunControlChange: undefined, + })) + expect(resPropsSpy).toHaveBeenNthCalledWith(2, expect.objectContaining({ + inputs: { name: 'Beta' }, + isError: true, + controlRetry: 88, + taskId: 2, + })) + expect(screen.getByText('share.generation.executions:{"num":2}')).toBeInTheDocument() + expect(screen.getByTestId('res-download-mock')).toBeInTheDocument() + expect(resDownloadPropsSpy).toHaveBeenCalledWith(expect.objectContaining({ + isMobile: false, + values: baseProps.exportRes, + })) + expect(screen.getByText('share.generation.batchFailed.info:{"num":1}')).toBeInTheDocument() + expect(screen.getByText('share.generation.batchFailed.retry')).toBeInTheDocument() + expect(screen.getByRole('status', { name: 'appApi.loading' })).toBeInTheDocument() + + fireEvent.click(screen.getByText('share.generation.batchFailed.retry')) + expect(handleRetryAllFailedTask).toHaveBeenCalledTimes(1) + }) + + it('should toggle mobile result panel handle between show and hide actions', () => { + const onHideResultPanel = vi.fn() + const onShowResultPanel = vi.fn() + const { rerender } = render( + , + ) + + fireEvent.click(document.querySelector('.cursor-grab') as HTMLElement) + expect(onHideResultPanel).toHaveBeenCalledTimes(1) + + rerender( + , + ) + + fireEvent.click(document.querySelector('.cursor-grab') as HTMLElement) + expect(onShowResultPanel).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/share/text-generation/__tests__/text-generation-sidebar.spec.tsx b/web/app/components/share/text-generation/__tests__/text-generation-sidebar.spec.tsx new file mode 100644 index 0000000000..6ff46d94c7 --- /dev/null +++ b/web/app/components/share/text-generation/__tests__/text-generation-sidebar.spec.tsx @@ -0,0 +1,261 @@ +import type { ComponentProps } from 'react' +import type { PromptConfig, SavedMessage } from '@/models/debug' +import type { SiteInfo } from '@/models/share' +import type { VisionSettings } from '@/types/app' +import { fireEvent, render, screen } from '@testing-library/react' +import { AccessMode } from '@/models/access-control' +import { Resolution, TransferMethod } from '@/types/app' +import { defaultSystemFeatures } from '@/types/feature' +import TextGenerationSidebar from '../text-generation-sidebar' + +const runOncePropsSpy = vi.fn() +const runBatchPropsSpy = vi.fn() +const savedItemsPropsSpy = vi.fn() + +vi.mock('@/app/components/share/text-generation/run-once', () => ({ + default: (props: Record) => { + runOncePropsSpy(props) + return
+ }, +})) + +vi.mock('@/app/components/share/text-generation/run-batch', () => ({ + default: (props: Record) => { + runBatchPropsSpy(props) + return
+ }, +})) + +vi.mock('@/app/components/app/text-generate/saved-items', () => ({ + default: (props: { onStartCreateContent: () => void, list: Array<{ id: string }> }) => { + savedItemsPropsSpy(props) + return ( +
+ {props.list.length} + +
+ ) + }, +})) + +vi.mock('@/app/components/share/text-generation/menu-dropdown', () => ({ + default: () =>
, +})) + +const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + { key: 'name', name: 'Name', type: 'string', required: true }, + ], +} + +const savedMessages: SavedMessage[] = [ + { id: 'saved-1', answer: 'Answer 1' }, + { id: 'saved-2', answer: 'Answer 2' }, +] + +const siteInfo: SiteInfo = { + title: 'Text Generation', + description: 'Share description', + icon_type: 'emoji', + icon: 'robot', + icon_background: '#fff', + icon_url: '', +} + +const visionConfig: VisionSettings = { + enabled: false, + number_limits: 2, + detail: Resolution.low, + transfer_methods: [TransferMethod.local_file], +} + +const baseProps: ComponentProps = { + accessMode: AccessMode.PUBLIC, + allTasksRun: true, + currentTab: 'create', + customConfig: { + remove_webapp_brand: false, + replace_webapp_logo: '', + }, + inputs: { name: 'Alice' }, + inputsRef: { current: { name: 'Alice' } }, + isInstalledApp: false, + isPC: true, + isWorkflow: false, + onBatchSend: vi.fn(), + onInputsChange: vi.fn(), + onRemoveSavedMessage: vi.fn(async () => {}), + onRunOnceSend: vi.fn(), + onTabChange: vi.fn(), + onVisionFilesChange: vi.fn(), + promptConfig, + resultExisted: false, + runControl: null, + savedMessages, + siteInfo, + systemFeatures: defaultSystemFeatures, + textToSpeechConfig: { enabled: true }, + visionConfig, +} + +const renderSidebar = (overrides: Partial = {}) => { + return render() +} + +describe('TextGenerationSidebar', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render create tab content and pass orchestration props to RunOnce', () => { + renderSidebar() + + expect(screen.getByText('Text Generation')).toBeInTheDocument() + expect(screen.getByText('Share description')).toBeInTheDocument() + expect(screen.getByTestId('run-once-mock')).toBeInTheDocument() + expect(runOncePropsSpy).toHaveBeenCalledWith(expect.objectContaining({ + inputs: { name: 'Alice' }, + promptConfig, + runControl: null, + visionConfig, + })) + expect(screen.queryByTestId('saved-items-mock')).not.toBeInTheDocument() + }) + + it('should render batch tab and hide saved tab for workflow apps', () => { + renderSidebar({ + currentTab: 'batch', + isWorkflow: true, + }) + + expect(screen.getByTestId('run-batch-mock')).toBeInTheDocument() + expect(runBatchPropsSpy).toHaveBeenCalledWith(expect.objectContaining({ + vars: promptConfig.prompt_variables, + isAllFinished: true, + })) + expect(screen.queryByTestId('tab-header-item-saved')).not.toBeInTheDocument() + }) + + it('should render saved items and allow switching back to create tab', () => { + const onTabChange = vi.fn() + + renderSidebar({ + currentTab: 'saved', + onTabChange, + }) + + expect(screen.getByTestId('saved-items-mock')).toBeInTheDocument() + expect(savedItemsPropsSpy).toHaveBeenCalledWith(expect.objectContaining({ + list: baseProps.savedMessages, + isShowTextToSpeech: true, + })) + + fireEvent.click(screen.getByRole('button', { name: 'back-to-create' })) + expect(onTabChange).toHaveBeenCalledWith('create') + }) + + it('should prefer workspace branding and hide powered-by block when branding is removed', () => { + const { rerender } = renderSidebar({ + systemFeatures: { + ...defaultSystemFeatures, + branding: { + ...defaultSystemFeatures.branding, + enabled: true, + workspace_logo: 'https://example.com/workspace-logo.png', + }, + }, + }) + + const brandingLogo = screen.getByRole('img', { name: 'logo' }) + expect(brandingLogo).toHaveAttribute('src', 'https://example.com/workspace-logo.png') + + rerender( + , + ) + + expect(screen.queryByText('share.chat.poweredBy')).not.toBeInTheDocument() + }) + + it('should render mobile installed-app layout without saved badge when no saved messages exist', () => { + const { container } = renderSidebar({ + accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + isInstalledApp: true, + isPC: false, + resultExisted: false, + savedMessages: [], + siteInfo: { + ...siteInfo, + description: '', + icon_background: '', + }, + }) + + const root = container.firstElementChild as HTMLElement + const header = root.children[0] as HTMLElement + const body = root.children[1] as HTMLElement + + expect(root).toHaveClass('rounded-l-2xl') + expect(root).not.toHaveClass('h-[calc(100%_-_64px)]') + expect(header).toHaveClass('p-4', 'pb-0') + expect(body).toHaveClass('px-4') + expect(screen.queryByText('Share description')).not.toBeInTheDocument() + }) + + it('should render mobile saved tab with compact spacing and no text-to-speech flag', () => { + const { container } = renderSidebar({ + accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + currentTab: 'saved', + isPC: false, + resultExisted: true, + textToSpeechConfig: null, + }) + + const root = container.firstElementChild as HTMLElement + const body = root.children[1] as HTMLElement + const footer = root.children[2] as HTMLElement + + expect(root).toHaveClass('h-[calc(100%_-_64px)]') + expect(body).toHaveClass('px-4') + expect(footer).toHaveClass('px-4', 'rounded-b-2xl') + expect(savedItemsPropsSpy).toHaveBeenCalledWith(expect.objectContaining({ + className: expect.stringContaining('mt-4'), + isShowTextToSpeech: undefined, + })) + }) + + it('should round the mobile panel body and hide branding when the webapp brand is removed', () => { + const { container } = renderSidebar({ + isPC: false, + resultExisted: true, + customConfig: { + remove_webapp_brand: true, + replace_webapp_logo: '', + }, + }) + + const root = container.firstElementChild as HTMLElement + const body = root.children[1] as HTMLElement + + expect(body).toHaveClass('rounded-b-2xl') + expect(screen.queryByText('share.chat.poweredBy')).not.toBeInTheDocument() + }) + + it('should render the custom webapp logo when workspace branding is unavailable', () => { + renderSidebar({ + customConfig: { + remove_webapp_brand: false, + replace_webapp_logo: 'https://example.com/custom-logo.png', + }, + }) + + const brandingLogo = screen.getByRole('img', { name: 'logo' }) + expect(brandingLogo).toHaveAttribute('src', 'https://example.com/custom-logo.png') + }) +}) diff --git a/web/app/components/share/text-generation/hooks/__tests__/use-text-generation-app-state.spec.ts b/web/app/components/share/text-generation/hooks/__tests__/use-text-generation-app-state.spec.ts new file mode 100644 index 0000000000..3339dce403 --- /dev/null +++ b/web/app/components/share/text-generation/hooks/__tests__/use-text-generation-app-state.spec.ts @@ -0,0 +1,298 @@ +import { act, renderHook, waitFor } from '@testing-library/react' +import { AppSourceType } from '@/service/share' +import { useTextGenerationAppState } from '../use-text-generation-app-state' + +const { + changeLanguageMock, + fetchSavedMessageMock, + notifyMock, + removeMessageMock, + saveMessageMock, + useAppFaviconMock, + useDocumentTitleMock, +} = vi.hoisted(() => ({ + changeLanguageMock: vi.fn(() => Promise.resolve()), + fetchSavedMessageMock: vi.fn(), + notifyMock: vi.fn(), + removeMessageMock: vi.fn(), + saveMessageMock: vi.fn(), + useAppFaviconMock: vi.fn(), + useDocumentTitleMock: vi.fn(), +})) + +vi.mock('@/app/components/base/toast', () => ({ + default: { + notify: notifyMock, + }, +})) + +vi.mock('@/hooks/use-app-favicon', () => ({ + useAppFavicon: useAppFaviconMock, +})) + +vi.mock('@/hooks/use-document-title', () => ({ + default: useDocumentTitleMock, +})) + +vi.mock('@/i18n-config/client', () => ({ + changeLanguage: changeLanguageMock, +})) + +vi.mock('@/service/share', async () => { + const actual = await vi.importActual('@/service/share') + return { + ...actual, + fetchSavedMessage: (...args: Parameters) => fetchSavedMessageMock(...args), + removeMessage: (...args: Parameters) => removeMessageMock(...args), + saveMessage: (...args: Parameters) => saveMessageMock(...args), + } +}) + +const mockSystemFeatures = { + branding: { + enabled: false, + workspace_logo: null, + }, +} + +const defaultAppInfo = { + app_id: 'app-123', + site: { + title: 'Share title', + description: 'Share description', + default_language: 'en-US', + icon_type: 'emoji', + icon: 'robot', + icon_background: '#fff', + icon_url: '', + }, + custom_config: { + remove_webapp_brand: false, + replace_webapp_logo: '', + }, +} + +type MockAppInfo = Omit & { + custom_config: typeof defaultAppInfo.custom_config | null +} + +const defaultAppParams = { + user_input_form: [ + { + 'text-input': { + label: 'Name', + variable: 'name', + required: true, + max_length: 48, + default: 'Alice', + hide: false, + }, + }, + { + checkbox: { + label: 'Enabled', + variable: 'enabled', + required: false, + default: true, + hide: false, + }, + }, + ], + more_like_this: { + enabled: true, + }, + file_upload: { + enabled: true, + number_limits: 2, + detail: 'low', + allowed_upload_methods: ['local_file'], + }, + text_to_speech: { + enabled: true, + }, + system_parameters: { + image_file_size_limit: 10, + }, +} + +type MockWebAppState = { + appInfo: MockAppInfo | null + appParams: typeof defaultAppParams | null + webAppAccessMode: string +} + +const mockWebAppState: MockWebAppState = { + appInfo: defaultAppInfo, + appParams: defaultAppParams, + webAppAccessMode: 'public', +} + +const resetMockWebAppState = () => { + mockWebAppState.appInfo = { + ...defaultAppInfo, + site: { + ...defaultAppInfo.site, + }, + custom_config: { + ...defaultAppInfo.custom_config, + }, + } + mockWebAppState.appParams = { + ...defaultAppParams, + user_input_form: [...defaultAppParams.user_input_form], + more_like_this: { + enabled: true, + }, + file_upload: { + ...defaultAppParams.file_upload, + allowed_upload_methods: [...defaultAppParams.file_upload.allowed_upload_methods], + }, + text_to_speech: { + ...defaultAppParams.text_to_speech, + }, + system_parameters: { + image_file_size_limit: 10, + }, + } + mockWebAppState.webAppAccessMode = 'public' +} + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: typeof mockSystemFeatures }) => unknown) => + selector({ systemFeatures: mockSystemFeatures }), +})) + +vi.mock('@/context/web-app-context', () => ({ + useWebAppStore: (selector: (state: typeof mockWebAppState) => unknown) => selector(mockWebAppState), +})) + +describe('useTextGenerationAppState', () => { + beforeEach(() => { + vi.clearAllMocks() + resetMockWebAppState() + fetchSavedMessageMock.mockResolvedValue({ + data: [{ id: 'saved-1' }], + }) + removeMessageMock.mockResolvedValue(undefined) + saveMessageMock.mockResolvedValue(undefined) + }) + + it('should initialize app state and fetch saved messages for non-workflow web apps', async () => { + const { result } = renderHook(() => useTextGenerationAppState({ + isInstalledApp: false, + isWorkflow: false, + })) + + await waitFor(() => { + expect(result.current.appId).toBe('app-123') + expect(result.current.promptConfig?.prompt_variables.map(item => item.name)).toEqual(['Name', 'Enabled']) + expect(result.current.savedMessages).toEqual([{ id: 'saved-1' }]) + }) + + expect(result.current.appSourceType).toBe(AppSourceType.webApp) + expect(result.current.siteInfo?.title).toBe('Share title') + expect(result.current.visionConfig.transfer_methods).toEqual(['local_file']) + expect(result.current.visionConfig.image_file_size_limit).toBe(10) + expect(changeLanguageMock).toHaveBeenCalledWith('en-US') + expect(fetchSavedMessageMock).toHaveBeenCalledWith(AppSourceType.webApp, 'app-123') + expect(useDocumentTitleMock).toHaveBeenCalledWith('Share title') + expect(useAppFaviconMock).toHaveBeenCalledWith(expect.objectContaining({ + enable: true, + icon: 'robot', + })) + }) + + it('should no-op save actions before the app id is initialized', async () => { + mockWebAppState.appInfo = null + mockWebAppState.appParams = null + + const { result } = renderHook(() => useTextGenerationAppState({ + isInstalledApp: false, + isWorkflow: false, + })) + + await act(async () => { + await result.current.fetchSavedMessages('') + await result.current.handleSaveMessage('message-1') + await result.current.handleRemoveSavedMessage('message-1') + }) + + expect(result.current.appId).toBe('') + expect(fetchSavedMessageMock).not.toHaveBeenCalled() + expect(saveMessageMock).not.toHaveBeenCalled() + expect(removeMessageMock).not.toHaveBeenCalled() + expect(notifyMock).not.toHaveBeenCalled() + }) + + it('should fallback to null custom config when the share metadata omits it', async () => { + mockWebAppState.appInfo = { + ...defaultAppInfo, + custom_config: null, + } + + const { result } = renderHook(() => useTextGenerationAppState({ + isInstalledApp: false, + isWorkflow: false, + })) + + await waitFor(() => { + expect(result.current.appId).toBe('app-123') + expect(result.current.customConfig).toBeNull() + }) + }) + + it('should save and remove messages then refresh saved messages', async () => { + const { result } = renderHook(() => useTextGenerationAppState({ + isInstalledApp: false, + isWorkflow: false, + })) + + await waitFor(() => { + expect(result.current.appId).toBe('app-123') + }) + + fetchSavedMessageMock.mockClear() + + await act(async () => { + await result.current.handleSaveMessage('message-1') + }) + + expect(saveMessageMock).toHaveBeenCalledWith('message-1', AppSourceType.webApp, 'app-123') + expect(fetchSavedMessageMock).toHaveBeenCalledWith(AppSourceType.webApp, 'app-123') + expect(notifyMock).toHaveBeenCalledWith({ + type: 'success', + message: 'common.api.saved', + }) + + fetchSavedMessageMock.mockClear() + notifyMock.mockClear() + + await act(async () => { + await result.current.handleRemoveSavedMessage('message-1') + }) + + expect(removeMessageMock).toHaveBeenCalledWith('message-1', AppSourceType.webApp, 'app-123') + expect(fetchSavedMessageMock).toHaveBeenCalledWith(AppSourceType.webApp, 'app-123') + expect(notifyMock).toHaveBeenCalledWith({ + type: 'success', + message: 'common.api.remove', + }) + }) + + it('should skip saved message fetching for workflows and disable favicon for installed apps', async () => { + const { result } = renderHook(() => useTextGenerationAppState({ + isInstalledApp: true, + isWorkflow: true, + })) + + await waitFor(() => { + expect(result.current.appId).toBe('app-123') + }) + + expect(result.current.appSourceType).toBe(AppSourceType.installedApp) + expect(fetchSavedMessageMock).not.toHaveBeenCalled() + expect(useAppFaviconMock).toHaveBeenCalledWith(expect.objectContaining({ + enable: false, + })) + }) +}) diff --git a/web/app/components/share/text-generation/hooks/__tests__/use-text-generation-batch.spec.ts b/web/app/components/share/text-generation/hooks/__tests__/use-text-generation-batch.spec.ts new file mode 100644 index 0000000000..3dab88c578 --- /dev/null +++ b/web/app/components/share/text-generation/hooks/__tests__/use-text-generation-batch.spec.ts @@ -0,0 +1,314 @@ +import type { PromptConfig, PromptVariable } from '@/models/debug' +import { act, renderHook } from '@testing-library/react' +import { BATCH_CONCURRENCY } from '@/config' +import { TaskStatus } from '../../types' +import { useTextGenerationBatch } from '../use-text-generation-batch' + +const createVariable = (overrides: Partial): PromptVariable => ({ + key: 'input', + name: 'Input', + type: 'string', + required: true, + ...overrides, +}) + +const createPromptConfig = (): PromptConfig => ({ + prompt_template: 'template', + prompt_variables: [ + createVariable({ key: 'name', name: 'Name', type: 'string', required: true }), + createVariable({ key: 'score', name: 'Score', type: 'number', required: false }), + ], +}) + +const createTranslator = () => vi.fn((key: string) => key) + +const renderBatchHook = (promptConfig: PromptConfig = createPromptConfig()) => { + const notify = vi.fn() + const onStart = vi.fn() + const t = createTranslator() + + const hook = renderHook(() => useTextGenerationBatch({ + promptConfig, + notify, + t, + })) + + return { + ...hook, + notify, + onStart, + t, + } +} + +describe('useTextGenerationBatch', () => { + it('should initialize the first batch group when csv content is valid', () => { + const { result, onStart } = renderBatchHook() + const csvData = [ + ['Name', 'Score'], + ...Array.from({ length: BATCH_CONCURRENCY + 1 }, (_, index) => [`Item ${index + 1}`, '']), + ] + + let isStarted = false + act(() => { + isStarted = result.current.handleRunBatch(csvData, { onStart }) + }) + + expect(isStarted).toBe(true) + expect(onStart).toHaveBeenCalledTimes(1) + expect(result.current.isCallBatchAPI).toBe(true) + expect(result.current.allTaskList).toHaveLength(BATCH_CONCURRENCY + 1) + expect(result.current.allTaskList.slice(0, BATCH_CONCURRENCY).every(task => task.status === TaskStatus.running)).toBe(true) + expect(result.current.allTaskList.at(-1)?.status).toBe(TaskStatus.pending) + expect(result.current.allTaskList[0]?.params.inputs).toEqual({ + name: 'Item 1', + score: undefined, + }) + }) + + it('should reject csv data when the header does not match prompt variables', () => { + const { result, notify, onStart } = renderBatchHook() + + let isStarted = true + act(() => { + isStarted = result.current.handleRunBatch([ + ['Prompt', 'Score'], + ['Hello', '1'], + ], { onStart }) + }) + + expect(isStarted).toBe(false) + expect(onStart).not.toHaveBeenCalled() + expect(notify).toHaveBeenCalledWith({ + type: 'error', + message: 'generation.errorMsg.fileStructNotMatch', + }) + expect(result.current.allTaskList).toEqual([]) + }) + + it('should reject empty batch inputs and rows without executable payload', () => { + const { result, notify, onStart } = renderBatchHook() + + let isStarted = true + act(() => { + isStarted = result.current.handleRunBatch([], { onStart }) + }) + + expect(isStarted).toBe(false) + expect(notify).toHaveBeenLastCalledWith({ + type: 'error', + message: 'generation.errorMsg.empty', + }) + + notify.mockClear() + + act(() => { + isStarted = result.current.handleRunBatch([ + ['Name', 'Score'], + ], { onStart }) + }) + + expect(isStarted).toBe(false) + expect(notify).toHaveBeenLastCalledWith({ + type: 'error', + message: 'generation.errorMsg.atLeastOne', + }) + + notify.mockClear() + + act(() => { + isStarted = result.current.handleRunBatch([ + ['Name', 'Score'], + ['', ''], + ], { onStart }) + }) + + expect(isStarted).toBe(false) + expect(notify).toHaveBeenLastCalledWith({ + type: 'error', + message: 'generation.errorMsg.atLeastOne', + }) + }) + + it('should reject csv data when empty rows appear in the middle of the payload', () => { + const { result, notify, onStart } = renderBatchHook() + + let isStarted = true + act(() => { + isStarted = result.current.handleRunBatch([ + ['Name', 'Score'], + ['Alice', '1'], + ['', ''], + ['Bob', '2'], + ['', ''], + ['', ''], + ], { onStart }) + }) + + expect(isStarted).toBe(false) + expect(notify).toHaveBeenCalledWith({ + type: 'error', + message: 'generation.errorMsg.emptyLine', + }) + }) + + it('should reject rows with missing required values', () => { + const { result, notify, onStart } = renderBatchHook() + + let isStarted = true + act(() => { + isStarted = result.current.handleRunBatch([ + ['Name', 'Score'], + ['', '1'], + ], { onStart }) + }) + + expect(isStarted).toBe(false) + expect(notify).toHaveBeenCalledWith({ + type: 'error', + message: 'generation.errorMsg.invalidLine', + }) + }) + + it('should reject rows that exceed the configured max length', () => { + const { result, notify, onStart } = renderBatchHook({ + prompt_template: 'template', + prompt_variables: [ + createVariable({ key: 'name', name: 'Name', type: 'string', required: true, max_length: 3 }), + createVariable({ key: 'score', name: 'Score', type: 'number', required: false }), + ], + }) + + let isStarted = true + act(() => { + isStarted = result.current.handleRunBatch([ + ['Name', 'Score'], + ['Alice', '1'], + ], { onStart }) + }) + + expect(isStarted).toBe(false) + expect(notify).toHaveBeenCalledWith({ + type: 'error', + message: 'generation.errorMsg.moreThanMaxLengthLine', + }) + }) + + it('should promote pending tasks after the current batch group completes', () => { + const { result } = renderBatchHook() + const csvData = [ + ['Name', 'Score'], + ...Array.from({ length: BATCH_CONCURRENCY + 1 }, (_, index) => [`Item ${index + 1}`, `${index + 1}`]), + ] + + act(() => { + result.current.handleRunBatch(csvData, { onStart: vi.fn() }) + }) + + act(() => { + Array.from({ length: BATCH_CONCURRENCY }).forEach((_, index) => { + result.current.handleCompleted(`Result ${index + 1}`, index + 1, true) + }) + }) + + expect(result.current.allTaskList.at(-1)?.status).toBe(TaskStatus.running) + expect(result.current.exportRes.at(0)).toEqual({ + 'Name': 'Item 1', + 'Score': '1', + 'generation.completionResult': 'Result 1', + }) + }) + + it('should block starting a new batch while previous tasks are still running', () => { + const { result, notify, onStart } = renderBatchHook() + const csvData = [ + ['Name', 'Score'], + ...Array.from({ length: BATCH_CONCURRENCY + 1 }, (_, index) => [`Item ${index + 1}`, `${index + 1}`]), + ] + + act(() => { + result.current.handleRunBatch(csvData, { onStart }) + }) + + notify.mockClear() + + let isStarted = true + act(() => { + isStarted = result.current.handleRunBatch(csvData, { onStart }) + }) + + expect(isStarted).toBe(false) + expect(onStart).toHaveBeenCalledTimes(1) + expect(notify).toHaveBeenCalledWith({ + type: 'info', + message: 'errorMessage.waitForBatchResponse', + }) + }) + + it('should ignore completion updates without a task id', () => { + const { result } = renderBatchHook() + + act(() => { + result.current.handleRunBatch([ + ['Name', 'Score'], + ['Alice', '1'], + ], { onStart: vi.fn() }) + }) + + const taskSnapshot = result.current.allTaskList + + act(() => { + result.current.handleCompleted('ignored') + }) + + expect(result.current.allTaskList).toEqual(taskSnapshot) + }) + + it('should expose failed tasks, retry signals, and reset state after batch failures', () => { + const { result } = renderBatchHook() + + act(() => { + result.current.handleRunBatch([ + ['Name', 'Score'], + ['Alice', ''], + ], { onStart: vi.fn() }) + }) + + act(() => { + result.current.handleCompleted({ answer: 'failed' } as unknown as string, 1, false) + }) + + expect(result.current.allFailedTaskList).toEqual([ + expect.objectContaining({ + id: 1, + status: TaskStatus.failed, + }), + ]) + expect(result.current.allTasksFinished).toBe(false) + expect(result.current.allTasksRun).toBe(true) + expect(result.current.noPendingTask).toBe(true) + expect(result.current.exportRes).toEqual([ + { + 'Name': 'Alice', + 'Score': '', + 'generation.completionResult': JSON.stringify({ answer: 'failed' }), + }, + ]) + + act(() => { + result.current.handleRetryAllFailedTask() + }) + + expect(result.current.controlRetry).toBeGreaterThan(0) + + act(() => { + result.current.resetBatchExecution() + }) + + expect(result.current.allTaskList).toEqual([]) + expect(result.current.allFailedTaskList).toEqual([]) + expect(result.current.showTaskList).toEqual([]) + expect(result.current.exportRes).toEqual([]) + expect(result.current.noPendingTask).toBe(true) + }) +}) diff --git a/web/app/components/share/text-generation/hooks/use-text-generation-app-state.ts b/web/app/components/share/text-generation/hooks/use-text-generation-app-state.ts new file mode 100644 index 0000000000..83fca4d0d6 --- /dev/null +++ b/web/app/components/share/text-generation/hooks/use-text-generation-app-state.ts @@ -0,0 +1,158 @@ +import type { TextGenerationCustomConfig } from '../types' +import type { + MoreLikeThisConfig, + PromptConfig, + SavedMessage, + TextToSpeechConfig, +} from '@/models/debug' +import type { SiteInfo } from '@/models/share' +import type { VisionSettings } from '@/types/app' +import { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { useGlobalPublicStore } from '@/context/global-public-context' +import { useWebAppStore } from '@/context/web-app-context' +import { useAppFavicon } from '@/hooks/use-app-favicon' +import useDocumentTitle from '@/hooks/use-document-title' +import { changeLanguage } from '@/i18n-config/client' +import { AppSourceType, fetchSavedMessage as doFetchSavedMessage, removeMessage, saveMessage } from '@/service/share' +import { Resolution, TransferMethod } from '@/types/app' +import { userInputsFormToPromptVariables } from '@/utils/model-config' + +type UseTextGenerationAppStateOptions = { + isInstalledApp: boolean + isWorkflow: boolean +} + +type ShareAppParams = { + user_input_form: Parameters[0] + more_like_this: MoreLikeThisConfig | null + file_upload: VisionSettings & { + allowed_file_upload_methods?: TransferMethod[] + allowed_upload_methods?: TransferMethod[] + } + text_to_speech: TextToSpeechConfig | null + system_parameters?: Record & { + image_file_size_limit?: number + } +} + +export const useTextGenerationAppState = ({ + isInstalledApp, + isWorkflow, +}: UseTextGenerationAppStateOptions) => { + const { notify } = Toast + const { t } = useTranslation() + const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp + const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) + const appData = useWebAppStore(s => s.appInfo) + const appParams = useWebAppStore(s => s.appParams) + const accessMode = useWebAppStore(s => s.webAppAccessMode) + + const [appId, setAppId] = useState('') + const [siteInfo, setSiteInfo] = useState(null) + const [customConfig, setCustomConfig] = useState(null) + const [promptConfig, setPromptConfig] = useState(null) + const [moreLikeThisConfig, setMoreLikeThisConfig] = useState(null) + const [textToSpeechConfig, setTextToSpeechConfig] = useState(null) + const [savedMessages, setSavedMessages] = useState([]) + const [visionConfig, setVisionConfig] = useState({ + enabled: false, + number_limits: 2, + detail: Resolution.low, + transfer_methods: [TransferMethod.local_file], + }) + + const fetchSavedMessages = useCallback(async (targetAppId = appId) => { + if (!targetAppId) + return + const res = await doFetchSavedMessage(appSourceType, targetAppId) as { data: SavedMessage[] } + setSavedMessages(res.data) + }, [appId, appSourceType]) + + const handleSaveMessage = useCallback(async (messageId: string) => { + if (!appId) + return + await saveMessage(messageId, appSourceType, appId) + notify({ type: 'success', message: t('api.saved', { ns: 'common' }) }) + await fetchSavedMessages(appId) + }, [appId, appSourceType, fetchSavedMessages, notify, t]) + + const handleRemoveSavedMessage = useCallback(async (messageId: string) => { + if (!appId) + return + await removeMessage(messageId, appSourceType, appId) + notify({ type: 'success', message: t('api.remove', { ns: 'common' }) }) + await fetchSavedMessages(appId) + }, [appId, appSourceType, fetchSavedMessages, notify, t]) + + useEffect(() => { + let cancelled = false + + const initialize = async () => { + if (!appData || !appParams) + return + + const { app_id: nextAppId, site, custom_config } = appData + + setAppId(nextAppId) + setSiteInfo(site as SiteInfo) + setCustomConfig((custom_config || null) as TextGenerationCustomConfig | null) + await changeLanguage(site.default_language) + + const { user_input_form, more_like_this, file_upload, text_to_speech } = appParams as unknown as ShareAppParams + if (cancelled) + return + + setVisionConfig({ + ...file_upload, + transfer_methods: file_upload?.allowed_file_upload_methods || file_upload?.allowed_upload_methods, + image_file_size_limit: appParams?.system_parameters.image_file_size_limit, + fileUploadConfig: appParams?.system_parameters, + } as VisionSettings) + setPromptConfig({ + prompt_template: '', + prompt_variables: userInputsFormToPromptVariables(user_input_form), + } as PromptConfig) + setMoreLikeThisConfig(more_like_this) + setTextToSpeechConfig(text_to_speech) + + if (!isWorkflow) + await fetchSavedMessages(nextAppId) + } + + void initialize() + + return () => { + cancelled = true + } + }, [appData, appParams, fetchSavedMessages, isWorkflow]) + + useDocumentTitle(siteInfo?.title || t('generation.title', { ns: 'share' })) + + useAppFavicon({ + enable: !isInstalledApp, + icon_type: siteInfo?.icon_type, + icon: siteInfo?.icon, + icon_background: siteInfo?.icon_background, + icon_url: siteInfo?.icon_url, + }) + + return { + accessMode, + appId, + appSourceType, + customConfig, + fetchSavedMessages, + handleRemoveSavedMessage, + handleSaveMessage, + moreLikeThisConfig, + promptConfig, + savedMessages, + siteInfo, + systemFeatures, + textToSpeechConfig, + visionConfig, + setVisionConfig, + } +} diff --git a/web/app/components/share/text-generation/hooks/use-text-generation-batch.ts b/web/app/components/share/text-generation/hooks/use-text-generation-batch.ts new file mode 100644 index 0000000000..522d2e4681 --- /dev/null +++ b/web/app/components/share/text-generation/hooks/use-text-generation-batch.ts @@ -0,0 +1,270 @@ +import type { Task } from '../types' +import type { PromptConfig } from '@/models/debug' +import { useCallback, useMemo, useRef, useState } from 'react' +import { BATCH_CONCURRENCY } from '@/config' +import { TaskStatus } from '../types' + +type BatchNotify = (payload: { type: 'error' | 'info', message: string }) => void +type BatchTranslate = (key: string, options?: Record) => string + +type UseTextGenerationBatchOptions = { + promptConfig: PromptConfig | null + notify: BatchNotify + t: BatchTranslate +} + +type RunBatchCallbacks = { + onStart: () => void +} + +const GROUP_SIZE = BATCH_CONCURRENCY + +export const useTextGenerationBatch = ({ + promptConfig, + notify, + t, +}: UseTextGenerationBatchOptions) => { + const [isCallBatchAPI, setIsCallBatchAPI] = useState(false) + const [controlRetry, setControlRetry] = useState(0) + const [allTaskList, setAllTaskList] = useState([]) + const [batchCompletionMap, setBatchCompletionMap] = useState>({}) + const allTaskListRef = useRef([]) + const currGroupNumRef = useRef(0) + const batchCompletionResRef = useRef>({}) + + const updateAllTaskList = useCallback((taskList: Task[]) => { + setAllTaskList(taskList) + allTaskListRef.current = taskList + }, []) + + const updateBatchCompletionRes = useCallback((res: Record) => { + batchCompletionResRef.current = res + setBatchCompletionMap(res) + }, []) + + const resetBatchExecution = useCallback(() => { + updateAllTaskList([]) + updateBatchCompletionRes({}) + currGroupNumRef.current = 0 + }, [updateAllTaskList, updateBatchCompletionRes]) + + const checkBatchInputs = useCallback((data: string[][]) => { + if (!data || data.length === 0) { + notify({ type: 'error', message: t('generation.errorMsg.empty', { ns: 'share' }) }) + return false + } + + const promptVariables = promptConfig?.prompt_variables ?? [] + const headerData = data[0] + let isMapVarName = true + promptVariables.forEach((item, index) => { + if (!isMapVarName) + return + + if (item.name !== headerData[index]) + isMapVarName = false + }) + + if (!isMapVarName) { + notify({ type: 'error', message: t('generation.errorMsg.fileStructNotMatch', { ns: 'share' }) }) + return false + } + + let payloadData = data.slice(1) + if (payloadData.length === 0) { + notify({ type: 'error', message: t('generation.errorMsg.atLeastOne', { ns: 'share' }) }) + return false + } + + const emptyLineIndexes = payloadData + .filter(item => item.every(value => value === '')) + .map(item => payloadData.indexOf(item)) + if (emptyLineIndexes.length > 0) { + let hasMiddleEmptyLine = false + let startIndex = emptyLineIndexes[0] - 1 + emptyLineIndexes.forEach((index) => { + if (hasMiddleEmptyLine) + return + if (startIndex + 1 !== index) { + hasMiddleEmptyLine = true + return + } + startIndex += 1 + }) + + if (hasMiddleEmptyLine) { + notify({ type: 'error', message: t('generation.errorMsg.emptyLine', { ns: 'share', rowIndex: startIndex + 2 }) }) + return false + } + } + + payloadData = payloadData.filter(item => !item.every(value => value === '')) + if (payloadData.length === 0) { + notify({ type: 'error', message: t('generation.errorMsg.atLeastOne', { ns: 'share' }) }) + return false + } + + let errorRowIndex = 0 + let requiredVarName = '' + let tooLongVarName = '' + let maxLength = 0 + + for (const [index, item] of payloadData.entries()) { + for (const [varIndex, varItem] of promptVariables.entries()) { + const value = item[varIndex] ?? '' + + if (varItem.type === 'string' && varItem.max_length && value.length > varItem.max_length) { + tooLongVarName = varItem.name + maxLength = varItem.max_length + errorRowIndex = index + 1 + break + } + + if (varItem.required && value.trim() === '') { + requiredVarName = varItem.name + errorRowIndex = index + 1 + break + } + } + + if (errorRowIndex !== 0) + break + } + + if (errorRowIndex !== 0) { + if (requiredVarName) { + notify({ + type: 'error', + message: t('generation.errorMsg.invalidLine', { ns: 'share', rowIndex: errorRowIndex + 1, varName: requiredVarName }), + }) + } + + if (tooLongVarName) { + notify({ + type: 'error', + message: t('generation.errorMsg.moreThanMaxLengthLine', { + ns: 'share', + rowIndex: errorRowIndex + 1, + varName: tooLongVarName, + maxLength, + }), + }) + } + + return false + } + + return true + }, [notify, promptConfig, t]) + + const handleRunBatch = useCallback((data: string[][], { onStart }: RunBatchCallbacks) => { + if (!checkBatchInputs(data)) + return false + + const latestTaskList = allTaskListRef.current + const allTasksFinished = latestTaskList.every(task => task.status === TaskStatus.completed) + if (!allTasksFinished && latestTaskList.length > 0) { + notify({ type: 'info', message: t('errorMessage.waitForBatchResponse', { ns: 'appDebug' }) }) + return false + } + + const payloadData = data.filter(item => !item.every(value => value === '')).slice(1) + const promptVariables = promptConfig?.prompt_variables ?? [] + const nextTaskList: Task[] = payloadData.map((item, index) => { + const inputs: Record = {} + promptVariables.forEach((variable, varIndex) => { + const input = item[varIndex] + inputs[variable.key] = input + if (!input) + inputs[variable.key] = variable.type === 'string' || variable.type === 'paragraph' ? '' : undefined + }) + + return { + id: index + 1, + status: index < GROUP_SIZE ? TaskStatus.running : TaskStatus.pending, + params: { inputs }, + } + }) + + setIsCallBatchAPI(true) + updateAllTaskList(nextTaskList) + updateBatchCompletionRes({}) + currGroupNumRef.current = 0 + onStart() + return true + }, [checkBatchInputs, notify, promptConfig, t, updateAllTaskList, updateBatchCompletionRes]) + + const handleCompleted = useCallback((completionRes: string, taskId?: number, isSuccess?: boolean) => { + if (!taskId) + return + + const latestTaskList = allTaskListRef.current + const latestBatchCompletionRes = batchCompletionResRef.current + const pendingTaskList = latestTaskList.filter(task => task.status === TaskStatus.pending) + const runTasksCount = 1 + latestTaskList.filter(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)).length + const shouldStartNextGroup = currGroupNumRef.current !== runTasksCount + && pendingTaskList.length > 0 + && (runTasksCount % GROUP_SIZE === 0 || (latestTaskList.length - runTasksCount < GROUP_SIZE)) + + if (shouldStartNextGroup) + currGroupNumRef.current = runTasksCount + + const nextPendingTaskIds = shouldStartNextGroup ? pendingTaskList.slice(0, GROUP_SIZE).map(item => item.id) : [] + updateAllTaskList(latestTaskList.map((task) => { + if (task.id === taskId) + return { ...task, status: isSuccess ? TaskStatus.completed : TaskStatus.failed } + if (shouldStartNextGroup && nextPendingTaskIds.includes(task.id)) + return { ...task, status: TaskStatus.running } + return task + })) + updateBatchCompletionRes({ + ...latestBatchCompletionRes, + [taskId]: completionRes, + }) + }, [updateAllTaskList, updateBatchCompletionRes]) + + const handleRetryAllFailedTask = useCallback(() => { + setControlRetry(Date.now()) + }, []) + + const pendingTaskList = allTaskList.filter(task => task.status === TaskStatus.pending) + const showTaskList = allTaskList.filter(task => task.status !== TaskStatus.pending) + const allSuccessTaskList = allTaskList.filter(task => task.status === TaskStatus.completed) + const allFailedTaskList = allTaskList.filter(task => task.status === TaskStatus.failed) + const allTasksFinished = allTaskList.every(task => task.status === TaskStatus.completed) + const allTasksRun = allTaskList.every(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)) + + const exportRes = useMemo(() => { + return allTaskList.map((task) => { + const result: Record = {} + promptConfig?.prompt_variables.forEach((variable) => { + result[variable.name] = String(task.params.inputs[variable.key] ?? '') + }) + + let completionValue = batchCompletionMap[String(task.id)] + if (typeof completionValue === 'object') + completionValue = JSON.stringify(completionValue) + + result[t('generation.completionResult', { ns: 'share' })] = completionValue + return result + }) + }, [allTaskList, batchCompletionMap, promptConfig, t]) + + return { + allFailedTaskList, + allSuccessTaskList, + allTaskList, + allTasksFinished, + allTasksRun, + controlRetry, + exportRes, + handleCompleted, + handleRetryAllFailedTask, + handleRunBatch, + isCallBatchAPI, + noPendingTask: pendingTaskList.length === 0, + resetBatchExecution, + setIsCallBatchAPI, + showTaskList, + } +} diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index 90a2fb9277..779358bfc6 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -1,65 +1,20 @@ 'use client' import type { FC } from 'react' -import type { - MoreLikeThisConfig, - PromptConfig, - SavedMessage, - TextToSpeechConfig, -} from '@/models/debug' +import type { InputValueTypes, TextGenerationRunControl } from './types' import type { InstalledApp } from '@/models/explore' -import type { SiteInfo } from '@/models/share' -import type { VisionFile, VisionSettings } from '@/types/app' -import { - RiBookmark3Line, - RiErrorWarningFill, -} from '@remixicon/react' +import type { VisionFile } from '@/types/app' import { useBoolean } from 'ahooks' import { useSearchParams } from 'next/navigation' -import * as React from 'react' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import SavedItems from '@/app/components/app/text-generate/saved-items' -import AppIcon from '@/app/components/base/app-icon' -import Badge from '@/app/components/base/badge' import Loading from '@/app/components/base/loading' -import DifyLogo from '@/app/components/base/logo/dify-logo' import Toast from '@/app/components/base/toast' -import Res from '@/app/components/share/text-generation/result' -import RunOnce from '@/app/components/share/text-generation/run-once' -import { appDefaultIconBackground, BATCH_CONCURRENCY } from '@/config' -import { useGlobalPublicStore } from '@/context/global-public-context' -import { useWebAppStore } from '@/context/web-app-context' -import { useAppFavicon } from '@/hooks/use-app-favicon' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -import useDocumentTitle from '@/hooks/use-document-title' -import { changeLanguage } from '@/i18n-config/client' -import { AccessMode } from '@/models/access-control' -import { AppSourceType, fetchSavedMessage as doFetchSavedMessage, removeMessage, saveMessage } from '@/service/share' -import { Resolution, TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' -import { userInputsFormToPromptVariables } from '@/utils/model-config' -import TabHeader from '../../base/tab-header' -import MenuDropdown from './menu-dropdown' -import RunBatch from './run-batch' -import ResDownload from './run-batch/res-download' - -const GROUP_SIZE = BATCH_CONCURRENCY // to avoid RPM(Request per minute) limit. The group task finished then the next group. -enum TaskStatus { - pending = 'pending', - running = 'running', - completed = 'completed', - failed = 'failed', -} - -type TaskParam = { - inputs: Record -} - -type Task = { - id: number - status: TaskStatus - params: TaskParam -} +import { useTextGenerationAppState } from './hooks/use-text-generation-app-state' +import { useTextGenerationBatch } from './hooks/use-text-generation-batch' +import TextGenerationResultPanel from './text-generation-result-panel' +import TextGenerationSidebar from './text-generation-sidebar' export type IMainProps = { isInstalledApp?: boolean @@ -72,8 +27,6 @@ const TextGeneration: FC = ({ isWorkflow = false, }) => { const { notify } = Toast - const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp - const { t } = useTranslation() const media = useBreakpoints() const isPC = media === MediaType.pc @@ -81,428 +34,90 @@ const TextGeneration: FC = ({ const searchParams = useSearchParams() const mode = searchParams.get('mode') || 'create' const [currentTab, setCurrentTab] = useState(['create', 'batch'].includes(mode) ? mode : 'create') - - // Notice this situation isCallBatchAPI but not in batch tab - const [isCallBatchAPI, setIsCallBatchAPI] = useState(false) - const isInBatchTab = currentTab === 'batch' - const [inputs, doSetInputs] = useState>({}) + const [inputs, setInputs] = useState>({}) const inputsRef = useRef(inputs) - const setInputs = useCallback((newInputs: Record) => { - doSetInputs(newInputs) - inputsRef.current = newInputs - }, []) - const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) - const [appId, setAppId] = useState('') - const [siteInfo, setSiteInfo] = useState(null) - const [customConfig, setCustomConfig] = useState | null>(null) - const [promptConfig, setPromptConfig] = useState(null) - const [moreLikeThisConfig, setMoreLikeThisConfig] = useState(null) - const [textToSpeechConfig, setTextToSpeechConfig] = useState(null) - - // save message - const [savedMessages, setSavedMessages] = useState([]) - const fetchSavedMessage = useCallback(async () => { - if (!appId) - return - const res: any = await doFetchSavedMessage(appSourceType, appId) - setSavedMessages(res.data) - }, [appSourceType, appId]) - const handleSaveMessage = async (messageId: string) => { - await saveMessage(messageId, appSourceType, appId) - notify({ type: 'success', message: t('api.saved', { ns: 'common' }) }) - fetchSavedMessage() - } - const handleRemoveSavedMessage = async (messageId: string) => { - await removeMessage(messageId, appSourceType, appId) - notify({ type: 'success', message: t('api.remove', { ns: 'common' }) }) - fetchSavedMessage() - } - - // send message task + const [completionFiles, setCompletionFiles] = useState([]) + const [runControl, setRunControl] = useState(null) const [controlSend, setControlSend] = useState(0) const [controlStopResponding, setControlStopResponding] = useState(0) - const [visionConfig, setVisionConfig] = useState({ - enabled: false, - number_limits: 2, - detail: Resolution.low, - transfer_methods: [TransferMethod.local_file], + const [resultExisted, setResultExisted] = useState(false) + const [isShowResultPanel, { setTrue: showResultPanelState, setFalse: hideResultPanel }] = useBoolean(false) + + const updateInputs = useCallback((newInputs: Record) => { + setInputs(newInputs) + inputsRef.current = newInputs + }, []) + + const { + accessMode, + appId, + appSourceType, + customConfig, + handleRemoveSavedMessage, + handleSaveMessage, + moreLikeThisConfig, + promptConfig, + savedMessages, + siteInfo, + systemFeatures, + textToSpeechConfig, + visionConfig, + } = useTextGenerationAppState({ + isInstalledApp, + isWorkflow, + }) + + const { + allFailedTaskList, + allSuccessTaskList, + allTaskList, + allTasksRun, + controlRetry, + exportRes, + handleCompleted, + handleRetryAllFailedTask, + handleRunBatch: runBatchExecution, + isCallBatchAPI, + noPendingTask, + resetBatchExecution, + setIsCallBatchAPI, + showTaskList, + } = useTextGenerationBatch({ + promptConfig, + notify, + t, }) - const [completionFiles, setCompletionFiles] = useState([]) - const [runControl, setRunControl] = useState<{ onStop: () => Promise | void, isStopping: boolean } | null>(null) useEffect(() => { if (isCallBatchAPI) setRunControl(null) }, [isCallBatchAPI]) - const handleSend = () => { + const showResultPanel = useCallback(() => { + setTimeout(() => { + showResultPanelState() + }, 0) + }, [showResultPanelState]) + const handleRunStart = useCallback(() => { + setResultExisted(true) + }, []) + + const handleRunOnce = useCallback(() => { setIsCallBatchAPI(false) setControlSend(Date.now()) - - // eslint-disable-next-line ts/no-use-before-define - setAllTaskList([]) // clear batch task running status - - // eslint-disable-next-line ts/no-use-before-define + resetBatchExecution() showResultPanel() - } + }, [resetBatchExecution, setIsCallBatchAPI, showResultPanel]) - const [controlRetry, setControlRetry] = useState(0) - const handleRetryAllFailedTask = () => { - setControlRetry(Date.now()) - } - const [allTaskList, doSetAllTaskList] = useState([]) - const allTaskListRef = useRef([]) - const getLatestTaskList = () => allTaskListRef.current - const setAllTaskList = (taskList: Task[]) => { - doSetAllTaskList(taskList) - allTaskListRef.current = taskList - } - const pendingTaskList = allTaskList.filter(task => task.status === TaskStatus.pending) - const noPendingTask = pendingTaskList.length === 0 - const showTaskList = allTaskList.filter(task => task.status !== TaskStatus.pending) - const currGroupNumRef = useRef(0) - - const setCurrGroupNum = (num: number) => { - currGroupNumRef.current = num - } - const getCurrGroupNum = () => { - return currGroupNumRef.current - } - const allSuccessTaskList = allTaskList.filter(task => task.status === TaskStatus.completed) - const allFailedTaskList = allTaskList.filter(task => task.status === TaskStatus.failed) - const allTasksFinished = allTaskList.every(task => task.status === TaskStatus.completed) - const allTasksRun = allTaskList.every(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)) - const batchCompletionResRef = useRef>({}) - const setBatchCompletionRes = (res: Record) => { - batchCompletionResRef.current = res - } - const getBatchCompletionRes = () => batchCompletionResRef.current - const exportRes = allTaskList.map((task) => { - const batchCompletionResLatest = getBatchCompletionRes() - const res: Record = {} - const { inputs } = task.params - promptConfig?.prompt_variables.forEach((v) => { - res[v.name] = inputs[v.key] + const handleRunBatch = useCallback((data: string[][]) => { + runBatchExecution(data, { + onStart: () => { + setControlSend(Date.now()) + setControlStopResponding(Date.now()) + showResultPanel() + }, }) - let result = batchCompletionResLatest[task.id] - // task might return multiple fields, should marshal object to string - if (typeof batchCompletionResLatest[task.id] === 'object') - result = JSON.stringify(result) - - res[t('generation.completionResult', { ns: 'share' })] = result - return res - }) - const checkBatchInputs = (data: string[][]) => { - if (!data || data.length === 0) { - notify({ type: 'error', message: t('generation.errorMsg.empty', { ns: 'share' }) }) - return false - } - const headerData = data[0] - let isMapVarName = true - promptConfig?.prompt_variables.forEach((item, index) => { - if (!isMapVarName) - return - - if (item.name !== headerData[index]) - isMapVarName = false - }) - - if (!isMapVarName) { - notify({ type: 'error', message: t('generation.errorMsg.fileStructNotMatch', { ns: 'share' }) }) - return false - } - - let payloadData = data.slice(1) - if (payloadData.length === 0) { - notify({ type: 'error', message: t('generation.errorMsg.atLeastOne', { ns: 'share' }) }) - return false - } - - // check middle empty line - const allEmptyLineIndexes = payloadData.filter(item => item.every(i => i === '')).map(item => payloadData.indexOf(item)) - if (allEmptyLineIndexes.length > 0) { - let hasMiddleEmptyLine = false - let startIndex = allEmptyLineIndexes[0] - 1 - allEmptyLineIndexes.forEach((index) => { - if (hasMiddleEmptyLine) - return - - if (startIndex + 1 !== index) { - hasMiddleEmptyLine = true - return - } - startIndex++ - }) - - if (hasMiddleEmptyLine) { - notify({ type: 'error', message: t('generation.errorMsg.emptyLine', { ns: 'share', rowIndex: startIndex + 2 }) }) - return false - } - } - - // check row format - payloadData = payloadData.filter(item => !item.every(i => i === '')) - // after remove empty rows in the end, checked again - if (payloadData.length === 0) { - notify({ type: 'error', message: t('generation.errorMsg.atLeastOne', { ns: 'share' }) }) - return false - } - let errorRowIndex = 0 - let requiredVarName = '' - let moreThanMaxLengthVarName = '' - let maxLength = 0 - payloadData.forEach((item, index) => { - if (errorRowIndex !== 0) - return - - promptConfig?.prompt_variables.forEach((varItem, varIndex) => { - if (errorRowIndex !== 0) - return - if (varItem.type === 'string' && varItem.max_length) { - if (item[varIndex].length > varItem.max_length) { - moreThanMaxLengthVarName = varItem.name - maxLength = varItem.max_length - errorRowIndex = index + 1 - return - } - } - if (!varItem.required) - return - - if (item[varIndex].trim() === '') { - requiredVarName = varItem.name - errorRowIndex = index + 1 - } - }) - }) - - if (errorRowIndex !== 0) { - if (requiredVarName) - notify({ type: 'error', message: t('generation.errorMsg.invalidLine', { ns: 'share', rowIndex: errorRowIndex + 1, varName: requiredVarName }) }) - - if (moreThanMaxLengthVarName) - notify({ type: 'error', message: t('generation.errorMsg.moreThanMaxLengthLine', { ns: 'share', rowIndex: errorRowIndex + 1, varName: moreThanMaxLengthVarName, maxLength }) }) - - return false - } - return true - } - const handleRunBatch = (data: string[][]) => { - if (!checkBatchInputs(data)) - return - if (!allTasksFinished) { - notify({ type: 'info', message: t('errorMessage.waitForBatchResponse', { ns: 'appDebug' }) }) - return - } - - const payloadData = data.filter(item => !item.every(i => i === '')).slice(1) - const varLen = promptConfig?.prompt_variables.length || 0 - setIsCallBatchAPI(true) - const allTaskList: Task[] = payloadData.map((item, i) => { - const inputs: Record = {} - if (varLen > 0) { - item.slice(0, varLen).forEach((input, index) => { - const varSchema = promptConfig?.prompt_variables[index] - inputs[varSchema?.key as string] = input - if (!input) { - if (varSchema?.type === 'string' || varSchema?.type === 'paragraph') - inputs[varSchema?.key as string] = '' - else - inputs[varSchema?.key as string] = undefined - } - }) - } - return { - id: i + 1, - status: i < GROUP_SIZE ? TaskStatus.running : TaskStatus.pending, - params: { - inputs, - }, - } - }) - setAllTaskList(allTaskList) - setCurrGroupNum(0) - setControlSend(Date.now()) - // clear run once task status - setControlStopResponding(Date.now()) - - // eslint-disable-next-line ts/no-use-before-define - showResultPanel() - } - const handleCompleted = (completionRes: string, taskId?: number, isSuccess?: boolean) => { - const allTaskListLatest = getLatestTaskList() - const batchCompletionResLatest = getBatchCompletionRes() - const pendingTaskList = allTaskListLatest.filter(task => task.status === TaskStatus.pending) - const runTasksCount = 1 + allTaskListLatest.filter(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)).length - const needToAddNextGroupTask = (getCurrGroupNum() !== runTasksCount) && pendingTaskList.length > 0 && (runTasksCount % GROUP_SIZE === 0 || (allTaskListLatest.length - runTasksCount < GROUP_SIZE)) - // avoid add many task at the same time - if (needToAddNextGroupTask) - setCurrGroupNum(runTasksCount) - - const nextPendingTaskIds = needToAddNextGroupTask ? pendingTaskList.slice(0, GROUP_SIZE).map(item => item.id) : [] - const newAllTaskList = allTaskListLatest.map((item) => { - if (item.id === taskId) { - return { - ...item, - status: isSuccess ? TaskStatus.completed : TaskStatus.failed, - } - } - if (needToAddNextGroupTask && nextPendingTaskIds.includes(item.id)) { - return { - ...item, - status: TaskStatus.running, - } - } - return item - }) - setAllTaskList(newAllTaskList) - if (taskId) { - setBatchCompletionRes({ - ...batchCompletionResLatest, - [`${taskId}`]: completionRes, - }) - } - } - - const appData = useWebAppStore(s => s.appInfo) - const appParams = useWebAppStore(s => s.appParams) - const accessMode = useWebAppStore(s => s.webAppAccessMode) - useEffect(() => { - (async () => { - if (!appData || !appParams) - return - if (!isWorkflow) - fetchSavedMessage() - const { app_id: appId, site: siteInfo, custom_config } = appData - setAppId(appId) - setSiteInfo(siteInfo as SiteInfo) - setCustomConfig(custom_config) - await changeLanguage(siteInfo.default_language) - - const { user_input_form, more_like_this, file_upload, text_to_speech }: any = appParams - setVisionConfig({ - // legacy of image upload compatible - ...file_upload, - transfer_methods: file_upload?.allowed_file_upload_methods || file_upload?.allowed_upload_methods, - // legacy of image upload compatible - image_file_size_limit: appParams?.system_parameters.image_file_size_limit, - fileUploadConfig: appParams?.system_parameters, - } as any) - const prompt_variables = userInputsFormToPromptVariables(user_input_form) - setPromptConfig({ - prompt_template: '', // placeholder for future - prompt_variables, - } as PromptConfig) - setMoreLikeThisConfig(more_like_this) - setTextToSpeechConfig(text_to_speech) - })() - }, [appData, appParams, fetchSavedMessage, isWorkflow]) - - // Can Use metadata(https://beta.nextjs.org/docs/api-reference/metadata) to set title. But it only works in server side client. - useDocumentTitle(siteInfo?.title || t('generation.title', { ns: 'share' })) - - useAppFavicon({ - enable: !isInstalledApp, - icon_type: siteInfo?.icon_type, - icon: siteInfo?.icon, - icon_background: siteInfo?.icon_background, - icon_url: siteInfo?.icon_url, - }) - - const [isShowResultPanel, { setTrue: doShowResultPanel, setFalse: hideResultPanel }] = useBoolean(false) - const showResultPanel = () => { - // fix: useClickAway hideResSidebar will close sidebar - setTimeout(() => { - doShowResultPanel() - }, 0) - } - const [resultExisted, setResultExisted] = useState(false) - - const renderRes = (task?: Task) => ( - setResultExisted(true)} - onRunControlChange={!isCallBatchAPI ? setRunControl : undefined} - hideInlineStopButton={!isCallBatchAPI} - /> - ) - - const renderBatchRes = () => { - return (showTaskList.map(task => renderRes(task))) - } - - const renderResWrap = ( -
- {isCallBatchAPI && ( -
-
{t('generation.executions', { ns: 'share', num: allTaskList.length })}
- {allSuccessTaskList.length > 0 && ( - - )} -
- )} -
- {!isCallBatchAPI ? renderRes() : renderBatchRes()} - {!noPendingTask && ( -
- -
- )} -
- {isCallBatchAPI && allFailedTaskList.length > 0 && ( -
- -
{t('generation.batchFailed.info', { ns: 'share', num: allFailedTaskList.length })}
-
-
{t('generation.batchFailed.retry', { ns: 'share' })}
-
- )} -
- ) + }, [runBatchExecution, showResultPanel]) if (!appId || !siteInfo || !promptConfig) { return ( @@ -511,147 +126,72 @@ const TextGeneration: FC = ({
) } + return ( -
- {/* Left */} -
- {/* header */} -
-
- -
{siteInfo.title}
- -
- {siteInfo.description && ( -
{siteInfo.description}
- )} - , - extra: savedMessages.length > 0 - ? ( - - {savedMessages.length} - - ) - : null, - }] - : []), - ]} - value={currentTab} - onChange={setCurrentTab} - /> -
- {/* form */} -
-
- -
-
- -
- {currentTab === 'saved' && ( - setCurrentTab('create')} - /> - )} -
- {/* powered by */} - {!customConfig?.remove_webapp_brand && ( -
-
{t('chat.poweredBy', { ns: 'share' })}
- { - systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo - ? logo - : customConfig?.replace_webapp_logo - ? logo - : - } -
- )} -
- {/* Result */} -
- {!isPC && ( -
{ - if (isShowResultPanel) - hideResultPanel() - else - showResultPanel() - }} - > -
-
- )} - {renderResWrap} -
+ +
) } diff --git a/web/app/components/share/text-generation/run-once/index.tsx b/web/app/components/share/text-generation/run-once/index.tsx index 4531ff8beb..f65de4e025 100644 --- a/web/app/components/share/text-generation/run-once/index.tsx +++ b/web/app/components/share/text-generation/run-once/index.tsx @@ -1,5 +1,6 @@ import type { ChangeEvent, FC, FormEvent } from 'react' import type { InputValueTypes } from '../types' +import type { FileEntity } from '@/app/components/base/file-uploader/types' import type { PromptConfig } from '@/models/debug' import type { SiteInfo } from '@/models/share' import type { VisionFile, VisionSettings } from '@/types/app' @@ -169,7 +170,9 @@ const RunOnce: FC = ({ )} {item.type === 'file' && ( { handleInputsChange({ ...inputsRef.current, [item.key]: files[0] }) }} fileConfig={{ ...item.config, @@ -179,7 +182,7 @@ const RunOnce: FC = ({ )} {item.type === 'file-list' && ( { handleInputsChange({ ...inputsRef.current, [item.key]: files }) }} fileConfig={{ ...item.config, diff --git a/web/app/components/share/text-generation/text-generation-result-panel.tsx b/web/app/components/share/text-generation/text-generation-result-panel.tsx new file mode 100644 index 0000000000..a47ad2858c --- /dev/null +++ b/web/app/components/share/text-generation/text-generation-result-panel.tsx @@ -0,0 +1,195 @@ +import type { FC } from 'react' +import type { InputValueTypes, Task, TextGenerationRunControl } from './types' +import type { PromptConfig } from '@/models/debug' +import type { SiteInfo } from '@/models/share' +import type { AppSourceType } from '@/service/share' +import type { VisionFile, VisionSettings } from '@/types/app' +import { useTranslation } from 'react-i18next' +import Loading from '@/app/components/base/loading' +import Res from '@/app/components/share/text-generation/result' +import { cn } from '@/utils/classnames' +import ResDownload from './run-batch/res-download' +import { TaskStatus } from './types' + +type TextGenerationResultPanelProps = { + allFailedTaskList: Task[] + allSuccessTaskList: Task[] + allTaskList: Task[] + appId: string + appSourceType: AppSourceType + completionFiles: VisionFile[] + controlRetry: number + controlSend: number + controlStopResponding: number + exportRes: Record[] + handleCompleted: (completionRes: string, taskId?: number, isSuccess?: boolean) => void + handleRetryAllFailedTask: () => void + handleSaveMessage: (messageId: string) => Promise + inputs: Record + isCallBatchAPI: boolean + isPC: boolean + isShowResultPanel: boolean + isWorkflow: boolean + moreLikeThisEnabled: boolean + noPendingTask: boolean + onHideResultPanel: () => void + onRunControlChange: (control: TextGenerationRunControl | null) => void + onRunStart: () => void + onShowResultPanel: () => void + promptConfig: PromptConfig + resultExisted: boolean + showTaskList: Task[] + siteInfo: SiteInfo + textToSpeechEnabled: boolean + visionConfig: VisionSettings +} + +const TextGenerationResultPanel: FC = ({ + allFailedTaskList, + allSuccessTaskList, + allTaskList, + appId, + appSourceType, + completionFiles, + controlRetry, + controlSend, + controlStopResponding, + exportRes, + handleCompleted, + handleRetryAllFailedTask, + handleSaveMessage, + inputs, + isCallBatchAPI, + isPC, + isShowResultPanel, + isWorkflow, + moreLikeThisEnabled, + noPendingTask, + onHideResultPanel, + onRunControlChange, + onRunStart, + onShowResultPanel, + promptConfig, + resultExisted, + showTaskList, + siteInfo, + textToSpeechEnabled, + visionConfig, +}) => { + const { t } = useTranslation() + + const renderResult = (task?: Task) => ( + + ) + + return ( +
+ {!isPC && ( +
{ + if (isShowResultPanel) + onHideResultPanel() + else + onShowResultPanel() + }} + > +
+
+ )} +
+ {isCallBatchAPI && ( +
+
{t('generation.executions', { ns: 'share', num: allTaskList.length })}
+ {allSuccessTaskList.length > 0 && ( + + )} +
+ )} +
+ {isCallBatchAPI ? showTaskList.map(task => renderResult(task)) : renderResult()} + {!noPendingTask && ( +
+ +
+ )} +
+ {isCallBatchAPI && allFailedTaskList.length > 0 && ( +
+ +
{t('generation.batchFailed.info', { ns: 'share', num: allFailedTaskList.length })}
+
+
{t('generation.batchFailed.retry', { ns: 'share' })}
+
+ )} +
+
+ ) +} + +export default TextGenerationResultPanel diff --git a/web/app/components/share/text-generation/text-generation-sidebar.tsx b/web/app/components/share/text-generation/text-generation-sidebar.tsx new file mode 100644 index 0000000000..70b65f59e9 --- /dev/null +++ b/web/app/components/share/text-generation/text-generation-sidebar.tsx @@ -0,0 +1,177 @@ +import type { FC, RefObject } from 'react' +import type { InputValueTypes, TextGenerationCustomConfig, TextGenerationRunControl } from './types' +import type { PromptConfig, SavedMessage, TextToSpeechConfig } from '@/models/debug' +import type { SiteInfo } from '@/models/share' +import type { VisionFile, VisionSettings } from '@/types/app' +import type { SystemFeatures } from '@/types/feature' +import { useTranslation } from 'react-i18next' +import SavedItems from '@/app/components/app/text-generate/saved-items' +import AppIcon from '@/app/components/base/app-icon' +import Badge from '@/app/components/base/badge' +import DifyLogo from '@/app/components/base/logo/dify-logo' +import { appDefaultIconBackground } from '@/config' +import { AccessMode } from '@/models/access-control' +import { cn } from '@/utils/classnames' +import TabHeader from '../../base/tab-header' +import MenuDropdown from './menu-dropdown' +import RunBatch from './run-batch' +import RunOnce from './run-once' + +type TextGenerationSidebarProps = { + accessMode: AccessMode + allTasksRun: boolean + currentTab: string + customConfig: TextGenerationCustomConfig | null + inputs: Record + inputsRef: RefObject> + isInstalledApp: boolean + isPC: boolean + isWorkflow: boolean + onBatchSend: (data: string[][]) => void + onInputsChange: (inputs: Record) => void + onRemoveSavedMessage: (messageId: string) => Promise + onRunOnceSend: () => void + onTabChange: (tab: string) => void + onVisionFilesChange: (files: VisionFile[]) => void + promptConfig: PromptConfig + resultExisted: boolean + runControl: TextGenerationRunControl | null + savedMessages: SavedMessage[] + siteInfo: SiteInfo + systemFeatures: SystemFeatures + textToSpeechConfig: TextToSpeechConfig | null + visionConfig: VisionSettings +} + +const TextGenerationSidebar: FC = ({ + accessMode, + allTasksRun, + currentTab, + customConfig, + inputs, + inputsRef, + isInstalledApp, + isPC, + isWorkflow, + onBatchSend, + onInputsChange, + onRemoveSavedMessage, + onRunOnceSend, + onTabChange, + onVisionFilesChange, + promptConfig, + resultExisted, + runControl, + savedMessages, + siteInfo, + systemFeatures, + textToSpeechConfig, + visionConfig, +}) => { + const { t } = useTranslation() + + return ( +
+
+
+ +
{siteInfo.title}
+ +
+ {siteInfo.description && ( +
{siteInfo.description}
+ )} + , + extra: savedMessages.length > 0 + ? ( + + {savedMessages.length} + + ) + : null, + }] + : []), + ]} + value={currentTab} + onChange={onTabChange} + /> +
+
+
+ +
+
+ +
+ {currentTab === 'saved' && ( + onTabChange('create')} + /> + )} +
+ {!customConfig?.remove_webapp_brand && ( +
+
{t('chat.poweredBy', { ns: 'share' })}
+ {systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo + ? logo + : customConfig?.replace_webapp_logo + ? logo + : } +
+ )} +
+ ) +} + +export default TextGenerationSidebar diff --git a/web/app/components/share/text-generation/types.ts b/web/app/components/share/text-generation/types.ts index 144ced28a2..b2c1a605c3 100644 --- a/web/app/components/share/text-generation/types.ts +++ b/web/app/components/share/text-generation/types.ts @@ -1,3 +1,5 @@ +import type { FileEntity } from '@/app/components/base/file-uploader/types' + type TaskParam = { inputs: Record } @@ -15,5 +17,22 @@ export enum TaskStatus { failed = 'failed', } -// eslint-disable-next-line ts/no-explicit-any -export type InputValueTypes = string | boolean | number | string[] | object | undefined | any +export type InputValueTypes + = | string + | boolean + | number + | string[] + | Record + | FileEntity + | FileEntity[] + | undefined + +export type TextGenerationRunControl = { + onStop: () => Promise | void + isStopping: boolean +} + +export type TextGenerationCustomConfig = Record & { + remove_webapp_brand?: boolean + replace_webapp_logo?: string +} diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 8f1b202044..041ff78a1b 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -5616,12 +5616,6 @@ "app/components/share/text-generation/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 6 - }, - "ts/no-explicit-any": { - "count": 8 } }, "app/components/share/text-generation/info-modal.tsx": {