Compare commits

..

15 Commits

Author SHA1 Message Date
c29af7a15f feat(auth): implement credential management modals and hooks
- Added  component for handling delete confirmation and edit modals.
- Introduced  and  for displaying credentials.
- Created custom hooks  and  to manage credential actions and modal states.
- Refactored  component to utilize new hooks and components for improved readability and maintainability.
2026-01-14 17:05:17 +08:00
01f17b7ddc refactor(http_request_node): apply DI for http request node (#30509)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-14 14:19:48 +08:00
yyh
14b2e5bd0d refactor(web): MCP tool availability to context-based version gating (#30955) 2026-01-14 13:40:16 +08:00
d095bd413b fix: fix LOOP_CHILDREN_Z_INDEX (#30719) 2026-01-14 10:22:31 +08:00
3473ff7ad1 fix: use Factory to create repository in Aliyun Trace (#30899) 2026-01-14 10:21:46 +08:00
138c56bd6e fix(logstore): prevent SQL injection, fix serialization issues, and optimize initialization (#30697) 2026-01-14 10:21:26 +08:00
c327d0bb44 fix: Correction to the full name of Volc TOS (#30741) 2026-01-14 10:11:30 +08:00
e4b97fba29 chore(deps): bump azure-core from 1.36.0 to 1.38.0 in /api (#30941)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-14 10:10:49 +08:00
7f9884e7a1 feat: Add option to delete or keep API keys when uninstalling plugin (#28201)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-01-14 10:09:30 +08:00
e389cd1665 chore(deps): bump filelock from 3.20.0 to 3.20.3 in /api (#30939)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-14 09:56:02 +08:00
87f348a0de feat: change param to pydantic model (#30870) 2026-01-14 09:46:41 +08:00
206706987d refactor(variables): clarify base vs union type naming (#30634)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-13 23:39:34 +09:00
91da784f84 refactor: init orpc contract (#30885)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-01-13 23:38:28 +09:00
a129e684cc feat: inject traceparent in enterprise api (#30895)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-13 23:37:39 +09:00
fe07c810ba fix: fix instance is not bind to session (#30913) 2026-01-13 21:15:21 +08:00
123 changed files with 4150 additions and 1208 deletions

View File

@ -90,7 +90,7 @@ jobs:
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 24
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -16,6 +16,10 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [16, 18, 20, 22]
defaults:
run:
working-directory: sdks/nodejs-client
@ -25,10 +29,10 @@ jobs:
with:
persist-credentials: false
- name: Use Node.js
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v6
with:
node-version: 24
node-version: ${{ matrix.node-version }}
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'

View File

@ -57,7 +57,7 @@ jobs:
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 24
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -31,7 +31,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 24
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(

View File

@ -7,7 +7,7 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
name: str
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
status: str | None = Field(None, title="Status", description="Document status.")
fetch_val: str = Field("false", alias="fetch")
register_schema_models(
console_ns,
KnowledgeConfig,
@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
page = param.page
limit = param.limit
search = param.search
sort = param.sort_by
status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch_val = request.args.get("fetch", default="false")
fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:

View File

@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""

View File

@ -1,6 +1,6 @@
import logging
from core.variables import Variable
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,

View File

@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
request_error = httpx.RequestError
max_retries_exceeded_error = MaxRetriesExceededError
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(

View File

@ -55,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@ -275,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,

View File

@ -7,8 +7,8 @@ from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
@ -230,30 +229,32 @@ class WorkflowTool(Tool):
"""
Resolve user from database (worker/Celery context).
"""
with session_factory.create_session() as session:
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = session.scalar(user_stmt)
if user:
user.current_tenant = tenant
session.expunge(user)
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = session.scalar(end_user_stmt)
if end_user:
session.expunge(end_user)
return end_user
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if user:
user.current_tenant = tenant
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = db.session.scalar(end_user_stmt)
if end_user:
return end_user
return None
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@ -265,22 +266,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
if not workflow:
raise ValueError("workflow not found or not published")
return workflow
session.expunge(workflow)
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")
if not app:
raise ValueError("app not found")
return app
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""

View File

@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
__all__ = [
@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
"VariableBase",
]

View File

@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]

View File

@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
class Variable(Segment):
class VariableBase(Segment):
"""
A variable is a segment that has a name.
@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable):
class StringVariable(StringSegment, VariableBase):
pass
class FloatVariable(FloatSegment, Variable):
class FloatVariable(FloatSegment, VariableBase):
pass
class IntegerVariable(IntegerSegment, Variable):
class IntegerVariable(IntegerSegment, VariableBase):
pass
class ObjectVariable(ObjectSegment, Variable):
class ObjectVariable(ObjectSegment, VariableBase):
pass
class ArrayVariable(ArraySegment, Variable):
class ArrayVariable(ArraySegment, VariableBase):
pass
@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
class FileVariable(FileSegment, Variable):
class FileVariable(FileSegment, VariableBase):
pass
class BooleanVariable(BooleanSegment, Variable):
class BooleanVariable(BooleanSegment, VariableBase):
pass
@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
# - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `VariableBase`.
Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]

View File

@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.variables import Variable
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"):
def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
:param variable: The `VariableBase` instance containing the updated value.
"""
pass

View File

@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
class CommandType(StrEnum):
@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):

View File

@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol = ssrf_proxy,
file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -104,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
self._http_client = http_client
self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@ -200,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
self.content = file_manager.download(file)
self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@ -239,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
file_manager.download(file),
self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@ -332,19 +337,18 @@ class Executor:
do http request depending on api bundle
"""
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
"get": self._http_client.get,
"head": self._http_client.head,
"post": self._http_client.post,
"put": self._http_client.put,
"delete": self._http_client.delete,
"patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
@ -357,8 +361,12 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
response: httpx.Response = _METHOD_MAP[method_lc](
url=self.url,
**request_args,
max_retries=self.max_retries,
)
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response

View File

@ -1,10 +1,11 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
http_client: HttpClientProtocol = ssrf_proxy,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol = file_manager,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_client = http_client
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
http_client=self._http_client,
file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
tool_file_manager = ToolFileManager()
tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,

View File

@ -11,7 +11,7 @@ from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
dict[str, VariableUnion],
dict[str, Variable],
LLMUsage,
]
],
@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})

View File

@ -1,16 +1,21 @@
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.file import file_manager
from core.helper import ssrf_proxy
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
@ -43,6 +48,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
http_request_http_client: HttpClientProtocol = ssrf_proxy,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol = file_manager,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -61,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._http_request_http_client = http_request_http_client
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -113,6 +124,7 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
@ -122,6 +134,17 @@ class DifyNodeFactory(NodeFactory):
template_renderer=self._template_renderer,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -0,0 +1,29 @@
from typing import Protocol
import httpx
from core.file import File
class HttpClientProtocol(Protocol):
@property
def max_retries_exceeded_error(self) -> type[Exception]: ...
@property
def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
class FileManagerProtocol(Protocol):
def download(self, f: File, /) -> bytes: ...

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
variable: Variable,
variable: VariableBase,
operation: Operation,
value: Any,
):

View File

@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
conversation_variables: Sequence[VariableUnion] = Field(
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:

View File

@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
:return: a list of Variable objects that match the provided selectors.
:return: a list of VariableBase objects that match the provided selectors.
"""
pass
@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []

View File

@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Logstore is considered enabled when:
1. All required Aliyun SLS environment variables are set
2. At least one repository configuration points to a logstore implementation
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
# Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
sls_vars_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
if not sls_vars_set:
return False
return all_set
# Check if any repository configuration points to logstore implementation
repository_configs = [
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_RUN_REPOSITORY,
]
uses_logstore = any("logstore" in config.lower() for config in repository_configs)
if not uses_logstore:
return False
logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
logger.info("Initializing Aliyun SLS Logstore...")
# Create logstore client and initialize project/logstores/indexes
# Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured
logger.exception(
"Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
"Application will continue but logstore features will NOT work.",
os.environ.get("ALIYUN_SLS_ENDPOINT"),
os.environ.get("ALIYUN_SLS_REGION"),
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
)
# Don't raise - allow application to continue even if logstore setup fails

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import os
import socket
import threading
import time
from collections.abc import Sequence
@ -179,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.log_enabled: bool = (
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
)
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Get timeout configuration
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
# Pre-check endpoint connectivity to prevent indefinite hangs
self._check_endpoint_connectivity(self.endpoint, check_timeout)
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@ -199,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
@staticmethod
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
"""
Check if the SLS endpoint is reachable before creating LogClient.
Prevents indefinite hangs when the endpoint is unreachable.
Args:
endpoint: SLS endpoint URL
timeout: Connection timeout in seconds
Raises:
ConnectionError: If endpoint is not reachable
"""
# Parse endpoint URL to extract hostname and port
from urllib.parse import urlparse
parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
hostname = parsed_url.hostname
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
if not hostname:
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
sock = None
try:
# Create socket and set timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
except Exception as e:
# Catch all exceptions and provide clear error message
error_type = type(e).__name__
raise ConnectionError(
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
) from e
finally:
# Ensure socket is properly closed
if sock:
try:
sock.close()
except Exception: # noqa: S110
pass # Ignore errors during cleanup
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@ -220,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
logger.info("Using SDK mode for project %s", self.project_name)
logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@ -246,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@ -284,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@ -299,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@ -318,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
"Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@ -748,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@ -770,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@ -845,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@ -853,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
query,
sql,
full_query,
)
try:
@ -865,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",

View File

@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from sqlalchemy import create_engine
from configs import dify_config
@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
"""PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
"""Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
"""Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
# Pool configuration
pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
# Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
logger.debug("Using SDK mode for host=%s", pg_host)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
# Build connection URL
from urllib.parse import quote_plus
username = quote_plus(self._access_key_id)
password = quote_plus(self._access_key_secret)
database_url = (
f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
# Create SQLAlchemy engine with connection pool
self._engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
pool_timeout=30,
connect_args={
"connect_timeout": 5,
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
},
)
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
"PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
pool_size,
pool_recycle,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
if self._engine:
try:
self._pg_pool.closeall()
self._engine.dispose()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.debug("Failed to dispose engine during cleanup, ignoring")
self._engine = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
"""Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
if not self._engine:
raise RuntimeError("SQLAlchemy engine is not initialized")
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
connection = self._engine.raw_connection()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
connection.autocommit = True # SLS PG protocol does not support transactions
yield connection
except Exception:
raise
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
connection.close()
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
"""Dispose SQLAlchemy engine and close all connections."""
if self._engine:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
self._engine.dispose()
logger.info("SQLAlchemy engine disposed")
except Exception:
logger.exception("Failed to close PG connection pool")
logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
"""Check if error is retriable (connection-related issues)."""
# Check for psycopg2 connection errors directly
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
"operational error",
"interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
"Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
"Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []

View File

@ -0,0 +1,29 @@
"""
LogStore repository utilities.
"""
from typing import Any
def safe_float(value: Any, default: float = 0.0) -> float:
"""
Safely convert a value to float, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def safe_int(value: Any, default: int = 0) -> int:
"""
Safely convert a value to int, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default

View File

@ -14,6 +14,8 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
model.index = safe_int(data.get("index", 0))
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_id = escape_identifier(workflow_id)
escaped_node_id = escape_identifier(node_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_run_id = escape_identifier(workflow_run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
query = (
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_run_id: {escaped_workflow_run_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Escape parameters to prevent SQL injection
escaped_execution_id = escape_identifier(execution_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
if tenant_id:
escaped_tenant_id = escape_identifier(tenant_id)
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
else:
tenant_filter = ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
# Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
query = (
f"id:{escape_logstore_query_value(execution_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
)
else:
query = f"id: {execution_id}"
query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now

View File

@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
- SQL injection prevention via parameter escaping
"""
import logging
@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
model.total_tokens = safe_int(data.get("total_tokens", 0))
model.total_steps = safe_int(data.get("total_steps", 0))
model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
# Use safe conversion to handle 'null' strings and None values
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build triggered_from filter with escaped values
# Support both enum and string values for triggered_from
triggered_from_filter = " OR ".join(
[
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
for tf in triggered_from_list
]
)
# Build status filter with escaped value
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Escape parameters to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
# Note: Values must be quoted in LogStore query syntax to prevent injection
query = (
f"id:{escape_logstore_query_value(run_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
f"and app_id:{escape_logstore_query_value(app_id)}"
)
from_time = 0
to_time = int(time.time()) # now
@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Escape parameter to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
# Note: Values must be quoted in LogStore query syntax
query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter
time_filter = ""
if time_range:
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
escaped_status = escape_sql_string(status)
if status == "running":
# Running status requires window function
sql = f"""
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
# Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
("version", domain_model.workflow_version),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),

View File

@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)),
index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)),
elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
For LogStore implementation, this is similar to save() since we always write
complete records. We append a new record with updated data fields.
For LogStore implementation, this is a no-op for the LogStore write because save()
already writes all fields including inputs, process_data, and outputs. The caller
typically calls save() first to persist status/metadata, then calls save_execution_data()
to persist data fields. Since LogStore writes complete records atomically, we don't
need a separate write here to avoid duplicate records.
However, if dual-write is enabled, we still need to call the SQL repository's
save_execution_data() method to properly update the SQL database.
Args:
execution: The NodeExecution instance with data to save
"""
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
# In LogStore, we simply write a new complete record with the data
# The log_version timestamp will ensure this is treated as the latest version
self.save(execution)
logger.debug(
"save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
execution.id,
execution.node_execution_id,
)
# No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
# Calling save() again would create a duplicate record in the append-only LogStore
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save_execution_data(execution)
logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
This ensures we only get the final version of each node execution.
Uses LogStore SQL query with window function to get the latest version of each node execution.
This ensures we only get the most recent version of each node execution record.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from
version updates. For complete history including intermediate states,
a different query strategy would be needed.
This method uses ROW_NUMBER() window function partitioned by node_execution_id
to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL
# This optimization avoids window functions for common case where we only
# want the final state of each node execution
# Build SQL query with deduplication using window function
# ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
# ensures we get the latest version of each node execution
# Build ORDER BY clause
# Escape parameters to prevent SQL injection
escaped_workflow_run_id = escape_identifier(workflow_run_id)
escaped_tenant_id = escape_identifier(self._tenant_id)
# Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f"""
SELECT *
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
# Build app_id filter for subquery
app_id_filter = ""
if self._app_id:
sql += f" AND app_id='{self._app_id}'"
escaped_app_id = escape_identifier(self._app_id)
app_id_filter = f" AND app_id='{escaped_app_id}'"
# Use window function to get latest version of each node execution
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{escaped_workflow_run_id}'
AND tenant_id='{escaped_tenant_id}'
{app_id_filter}
) t
WHERE rn = 1
"""
if order_clause:
sql += f" {order_clause}"

View File

@ -0,0 +1,134 @@
"""
SQL Escape Utility for LogStore Queries
This module provides escaping utilities to prevent injection attacks in LogStore queries.
LogStore supports two query modes:
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
Key Security Concerns:
- Prevent tenant A from accessing tenant B's data via injection
- SLS queries are read-only, so we focus on data access control
- Different escaping strategies for SQL vs LogStore query syntax
"""
def escape_sql_string(value: str) -> str:
"""
Escape a string value for safe use in SQL queries.
This function escapes single quotes by doubling them, which is the standard
SQL escaping method. This prevents SQL injection by ensuring that user input
cannot break out of string literals.
Args:
value: The string value to escape
Returns:
Escaped string safe for use in SQL queries
Examples:
>>> escape_sql_string("normal_value")
"normal_value"
>>> escape_sql_string("value' OR '1'='1")
"value'' OR ''1''=''1"
>>> escape_sql_string("tenant's_id")
"tenant''s_id"
Security:
- Prevents breaking out of string literals
- Stops injection attacks like: ' OR '1'='1
- Protects against cross-tenant data access
"""
if not value:
return value
# Escape single quotes by doubling them (standard SQL escaping)
# This prevents breaking out of string literals in SQL queries
return value.replace("'", "''")
def escape_identifier(value: str) -> str:
"""
Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
This function is for PG protocol mode (SQL syntax).
For SDK mode, use escape_logstore_query_value() instead.
Args:
value: The identifier value to escape
Returns:
Escaped identifier safe for use in SQL queries
Examples:
>>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
"550e8400-e29b-41d4-a716-446655440000"
>>> escape_identifier("tenant_id' OR '1'='1")
"tenant_id'' OR ''1''=''1"
Security:
- Prevents SQL injection via identifiers
- Stops cross-tenant access attempts
- Works for UUIDs, alphanumeric IDs, and similar identifiers
"""
# For identifiers, use the same escaping as strings
# This is simple and effective for preventing injection
return escape_sql_string(value)
def escape_logstore_query_value(value: str) -> str:
"""
Escape value for LogStore query syntax (SDK mode).
LogStore query syntax rules:
1. Keywords (and/or/not) are case-insensitive
2. Single quotes are ordinary characters (no special meaning)
3. Double quotes wrap values: key:"value"
4. Backslash is the escape character:
- \" for double quote inside value
- \\ for backslash itself
5. Parentheses can change query structure
To prevent injection:
- Wrap value in double quotes to treat special chars as literals
- Escape backslashes and double quotes using backslash
Args:
value: The value to escape for LogStore query syntax
Returns:
Quoted and escaped value safe for LogStore query syntax (includes the quotes)
Examples:
>>> escape_logstore_query_value("normal_value")
'"normal_value"'
>>> escape_logstore_query_value("value or field:evil")
'"value or field:evil"' # 'or' and ':' are now literals
>>> escape_logstore_query_value('value"test')
'"value\\"test"' # Internal double quote escaped
>>> escape_logstore_query_value('value\\test')
'"value\\\\test"' # Backslash escaped
Security:
- Prevents injection via and/or/not keywords
- Prevents injection via colons (:)
- Prevents injection via parentheses
- Protects against cross-tenant data access
Note:
Escape order is critical: backslash first, then double quotes.
Otherwise, we'd double-escape the escape character itself.
"""
if not value:
return '""'
# IMPORTANT: Escape backslashes FIRST, then double quotes
# This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
escaped = value.replace("\\", "\\\\") # \ -> \\
escaped = escaped.replace('"', '\\"') # " -> \"
# Wrap in double quotes to treat as literal string
# This prevents and/or/not/:/() from being interpreted as operators
return f'"{escaped}"'

View File

@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
result: Variable
result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
) -> VariableBase:
if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
Variable,
VariableBase,
variable_class(
id=id,
name=name,

View File

@ -2,7 +2,6 @@ from __future__ import annotations
from datetime import datetime
from typing import TypeAlias
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -21,8 +20,8 @@ class SimpleFeedback(ResponseModel):
class RetrieverResource(ResponseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
message_id: str = Field(default_factory=lambda: str(uuid4()))
id: str
message_id: str
position: int
dataset_id: str | None = None
dataset_name: str | None = None

View File

@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable
from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
if isinstance(value, Variable):
if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable
from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> WorkflowType:
def value_of(cls, value: str) -> "WorkflowType":
"""
Get value of given mode.
@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
"""
Get workflow type from app mode.
@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> Workflow:
) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
var: Variable,
var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var: Variable) -> Variable:
def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]):
def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@ -622,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[WorkflowPause | None] = orm.relationship(
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@ -692,7 +690,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
file: Mapped[UploadFile | None] = orm.relationship(
file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
"""
Get value of given mode.
@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase):
)
return obj
def to_variable(self) -> Variable:
def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=node_id,
@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
upload_file: Mapped[UploadFile] = orm.relationship(
upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
workflow_run: Mapped[WorkflowRun] = orm.relationship(
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.4"
version = "1.11.3"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -1,7 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from models import ConversationVariable
@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None:
def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)

View File

@ -1,9 +1,14 @@
import logging
import os
from collections.abc import Mapping
from typing import Any
import httpx
from core.helper.trace_id_helper import generate_traceparent_header
logger = logging.getLogger(__name__)
class BaseRequest:
proxies: Mapping[str, str] | None = {
@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
try:
# ensure traceparent even when OTEL is disabled
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
logger.debug("Failed to generate traceparent header", exc_info=True)
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
from sqlalchemy import select
from yarl import URL
from configs import dify_config
@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
# Get plugin info before uninstalling to delete associated credentials
try:
plugins = manager.list_plugins(tenant_id)
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
if plugin:
plugin_id = plugin.plugin_id
logger.info("Deleting credentials for plugin: %s", plugin_id)
# Delete provider credentials that match this plugin
credentials = db.session.scalars(
select(ProviderCredential).where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.like(f"{plugin_id}/%"),
)
).all()
for cred in credentials:
db.session.delete(cred)
db.session.commit()
logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
except Exception as e:
logger.warning("Failed to delete credentials: %s", e)
# Continue with uninstall even if credential deletion fails
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod

View File

@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""

View File

@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
_fallback_variables: Sequence[Variable]
_fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
variable_by_selector: dict[tuple[str, str], Variable] = {}
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.

View File

@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.variables import VariableBase
from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@ -198,8 +198,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
conversation_variables: list[Variable],
conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool

View File

@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M
def scalar(self, _stmt):
return self.results.pop(0)
# SQLAlchemy Session APIs used by code under test
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
# support `with session_factory.create_session() as session:`
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
tenant = SimpleNamespace(id="tenant_id")
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
# Monkeypatch session factory to return our stub session
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([tenant, None, end_user]),
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt
def scalar(self, _stmt):
return self.results.pop(0)
db_stub = SimpleNamespace(session=StubSession([None]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
# Monkeypatch session factory to return our stub session with no tenant
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([None]),
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),

View File

@ -35,7 +35,6 @@ from core.variables.variables import (
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -96,7 +95,7 @@ class _Segments(BaseModel):
class _Variables(BaseModel):
variables: list[VariableUnion]
variables: list[Variable]
def create_test_file(
@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
# Create one instance of each variable type
test_file = create_test_file()
all_variables: list[VariableUnion] = [
all_variables: list[Variable] = [
NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"),

View File

@ -11,7 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
from core.variables.variables import VariableBase
def test_frozen_variables():
@ -76,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var: Variable = StringVariable(name="text", value="text")
var: VariableBase = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42

View File

@ -24,7 +24,7 @@ from core.variables.variables import (
IntegerVariable,
ObjectVariable,
StringVariable,
VariableUnion,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.runtime import VariablePool
@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
)
# Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [
env_vars: list[Variable] = [
StringVariable(
id="env_string_id",
name="env_string",
@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
]
# Create conversation variables with complex data
conv_vars: list[VariableUnion] = [
conv_vars: list[Variable] = [
StringVariable(
id="conv_string_id",
name="conv_string",

View File

@ -0,0 +1 @@
"""LogStore extension unit tests."""

View File

@ -0,0 +1,469 @@
"""
Unit tests for SQL escape utility functions.
These tests ensure that SQL injection attacks are properly prevented
in LogStore queries, particularly for cross-tenant access scenarios.
"""
import pytest
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
class TestEscapeSQLString:
"""Test escape_sql_string function."""
def test_escape_empty_string(self):
"""Test escaping empty string."""
assert escape_sql_string("") == ""
def test_escape_normal_string(self):
"""Test escaping string without special characters."""
assert escape_sql_string("tenant_abc123") == "tenant_abc123"
assert escape_sql_string("app-uuid-1234") == "app-uuid-1234"
def test_escape_single_quote(self):
"""Test escaping single quote."""
# Single quote should be doubled
assert escape_sql_string("tenant'id") == "tenant''id"
assert escape_sql_string("O'Reilly") == "O''Reilly"
def test_escape_multiple_quotes(self):
"""Test escaping multiple single quotes."""
assert escape_sql_string("a'b'c") == "a''b''c"
assert escape_sql_string("'''") == "''''''"
# === SQL Injection Attack Scenarios ===
def test_prevent_boolean_injection(self):
"""Test prevention of boolean injection attacks."""
# Classic OR 1=1 attack
malicious_input = "tenant' OR '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR ''1''=''1"
# When used in SQL, this becomes a safe string literal
sql = f"WHERE tenant_id='{escaped}'"
assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'"
# The entire input is now a string literal that won't match any tenant
def test_prevent_or_injection(self):
"""Test prevention of OR-based injection."""
malicious_input = "tenant_a' OR tenant_id='tenant_b"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant_a'' OR tenant_id=''tenant_b"
sql = f"WHERE tenant_id='{escaped}'"
# The OR is now part of the string literal, not SQL logic
assert "OR tenant_id=" in sql
# The SQL has: opening ', doubled internal quotes '', and closing '
assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'"
def test_prevent_union_injection(self):
"""Test prevention of UNION-based injection."""
malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1"
# UNION becomes part of the string literal
assert "UNION" in escaped
assert escaped.count("''") == 4 # All internal quotes are doubled
def test_prevent_comment_injection(self):
"""Test prevention of comment-based injection."""
# SQL comment to bypass remaining conditions
malicious_input = "tenant' --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' --"
sql = f"WHERE tenant_id='{escaped}' AND deleted=false"
# The -- is now inside the string, not a SQL comment
assert "--" in sql
assert "AND deleted=false" in sql # This part is NOT commented out
def test_prevent_semicolon_injection(self):
"""Test prevention of semicolon-based multi-statement injection."""
malicious_input = "tenant'; DROP TABLE users; --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant''; DROP TABLE users; --"
# Semicolons and DROP are now part of the string
assert "DROP TABLE" in escaped
def test_prevent_time_based_blind_injection(self):
"""Test prevention of time-based blind SQL injection."""
malicious_input = "tenant' AND SLEEP(5) --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' AND SLEEP(5) --"
# SLEEP becomes part of the string
assert "SLEEP" in escaped
def test_prevent_wildcard_injection(self):
"""Test prevention of wildcard-based injection."""
malicious_input = "tenant' OR tenant_id LIKE '%"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR tenant_id LIKE ''%"
# The LIKE and wildcard are now part of the string
assert "LIKE" in escaped
def test_prevent_null_byte_injection(self):
"""Test handling of null bytes."""
# Null bytes can sometimes bypass filters
malicious_input = "tenant\x00' OR '1'='1"
escaped = escape_sql_string(malicious_input)
# Null byte is preserved, but quote is escaped
assert "''1''=''1" in escaped
# === Real-world SAAS Scenarios ===
def test_cross_tenant_access_attempt(self):
"""Test prevention of cross-tenant data access."""
# Attacker tries to access another tenant's data
attacker_input = "tenant_b' OR tenant_id='tenant_a"
escaped = escape_sql_string(attacker_input)
sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'"
# The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a"
# which doesn't exist - preventing access to either tenant's data
assert "tenant_b'' OR tenant_id=''tenant_a" in sql
def test_cross_app_access_attempt(self):
"""Test prevention of cross-application data access."""
attacker_input = "app1' OR app_id='app2"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE app_id='{escaped}'"
# Cannot access app2's data
assert "app1'' OR app_id=''app2" in sql
def test_bypass_status_filter(self):
"""Test prevention of bypassing status filters."""
# Try to see all statuses instead of just 'running'
attacker_input = "running' OR status LIKE '%"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE status='{escaped}'"
# Status condition is not bypassed
assert "running'' OR status LIKE ''%" in sql
# === Edge Cases ===
def test_escape_only_quotes(self):
"""Test string with only quotes."""
assert escape_sql_string("'") == "''"
assert escape_sql_string("''") == "''''"
def test_escape_mixed_content(self):
"""Test string with mixed quotes and other chars."""
input_str = "It's a 'test' of O'Reilly's code"
escaped = escape_sql_string(input_str)
assert escaped == "It''s a ''test'' of O''Reilly''s code"
def test_escape_unicode_with_quotes(self):
"""Test Unicode strings with quotes."""
input_str = "租户' OR '1'='1"
escaped = escape_sql_string(input_str)
assert escaped == "租户'' OR ''1''=''1"
class TestEscapeIdentifier:
"""Test escape_identifier function."""
def test_escape_uuid(self):
"""Test escaping UUID identifiers."""
uuid = "550e8400-e29b-41d4-a716-446655440000"
assert escape_identifier(uuid) == uuid
def test_escape_alphanumeric_id(self):
"""Test escaping alphanumeric identifiers."""
assert escape_identifier("tenant_123") == "tenant_123"
assert escape_identifier("app-abc-123") == "app-abc-123"
def test_escape_identifier_with_quote(self):
"""Test escaping identifier with single quote."""
malicious = "tenant' OR '1'='1"
escaped = escape_identifier(malicious)
assert escaped == "tenant'' OR ''1''=''1"
def test_identifier_injection_attempt(self):
"""Test prevention of injection through identifiers."""
# Common identifier injection patterns
test_cases = [
("id' OR '1'='1", "id'' OR ''1''=''1"),
("id'; DROP TABLE", "id''; DROP TABLE"),
("id' UNION SELECT", "id'' UNION SELECT"),
]
for malicious, expected in test_cases:
assert escape_identifier(malicious) == expected
class TestSQLInjectionIntegration:
"""Integration tests simulating real SQL construction scenarios."""
def test_complete_where_clause_safety(self):
"""Test that a complete WHERE clause is safe from injection."""
# Simulating typical query construction
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
run_id = "run' --"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
escaped_run = escape_identifier(run_id)
sql = f"""
SELECT * FROM workflow_runs
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND id='{escaped_run}'
"""
# Verify all special characters are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
assert "run'' --" in sql
# Verify SQL structure is preserved (3 conditions with AND)
assert sql.count("AND") == 2
def test_multiple_conditions_with_injection_attempts(self):
"""Test multiple conditions all attempting injection."""
conditions = {
"tenant_id": "t1' OR tenant_id='t2",
"app_id": "a1' OR app_id='a2",
"status": "running' OR '1'='1",
}
where_parts = []
for field, value in conditions.items():
escaped = escape_sql_string(value)
where_parts.append(f"{field}='{escaped}'")
where_clause = " AND ".join(where_parts)
# All injection attempts are neutralized
assert "t1'' OR tenant_id=''t2" in where_clause
assert "a1'' OR app_id=''a2" in where_clause
assert "running'' OR ''1''=''1" in where_clause
# AND structure is preserved
assert where_clause.count(" AND ") == 2
@pytest.mark.parametrize(
("attack_vector", "description"),
[
("' OR '1'='1", "Boolean injection"),
("' OR '1'='1' --", "Boolean with comment"),
("' UNION SELECT * FROM users --", "Union injection"),
("'; DROP TABLE workflow_runs; --", "Destructive command"),
("' AND SLEEP(10) --", "Time-based blind"),
("' OR tenant_id LIKE '%", "Wildcard injection"),
("admin' --", "Comment bypass"),
("' OR 1=1 LIMIT 1 --", "Limit bypass"),
],
)
def test_common_injection_vectors(self, attack_vector, description):
"""Test protection against common injection attack vectors."""
escaped = escape_sql_string(attack_vector)
# Build SQL
sql = f"WHERE tenant_id='{escaped}'"
# Verify the attack string is now a safe literal
# The key indicator: all internal single quotes are doubled
internal_quotes = escaped.count("''")
original_quotes = attack_vector.count("'")
# Each original quote should be doubled
assert internal_quotes == original_quotes
# Verify SQL has exactly 2 quotes (opening and closing)
assert sql.count("'") >= 2 # At least opening and closing
def test_logstore_specific_scenario(self):
"""Test SQL injection prevention in LogStore-specific scenarios."""
# Simulate LogStore query with window function
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM workflow_execution_logstore
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND __time__ > 0
) AS subquery WHERE rn = 1
"""
# Complex query structure is maintained
assert "ROW_NUMBER()" in sql
assert "PARTITION BY id" in sql
# Injection attempts are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
# ====================================================================================
# Tests for LogStore Query Syntax (SDK Mode)
# ====================================================================================
class TestLogStoreQueryEscape:
"""Test escape_logstore_query_value for SDK mode query syntax."""
def test_normal_value(self):
"""Test escaping normal alphanumeric value."""
value = "550e8400-e29b-41d4-a716-446655440000"
escaped = escape_logstore_query_value(value)
# Should be wrapped in double quotes
assert escaped == '"550e8400-e29b-41d4-a716-446655440000"'
def test_empty_value(self):
"""Test escaping empty string."""
assert escape_logstore_query_value("") == '""'
def test_value_with_and_keyword(self):
"""Test that 'and' keyword is neutralized when quoted."""
malicious = "value and field:evil"
escaped = escape_logstore_query_value(malicious)
# Should be wrapped in quotes, making 'and' a literal
assert escaped == '"value and field:evil"'
# Simulate using in query
query = f"tenant_id:{escaped}"
assert query == 'tenant_id:"value and field:evil"'
def test_value_with_or_keyword(self):
"""Test that 'or' keyword is neutralized when quoted."""
malicious = "tenant_a or tenant_id:tenant_b"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"tenant_a or tenant_id:tenant_b"'
query = f"tenant_id:{escaped}"
assert "or" in query # Present but as literal string
def test_value_with_not_keyword(self):
"""Test that 'not' keyword is neutralized when quoted."""
malicious = "not field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"not field:value"'
def test_value_with_parentheses(self):
"""Test that parentheses are neutralized when quoted."""
malicious = "(tenant_a or tenant_b)"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"(tenant_a or tenant_b)"'
assert "(" in escaped # Present as literal
assert ")" in escaped # Present as literal
def test_value_with_colon(self):
"""Test that colons are neutralized when quoted."""
malicious = "field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"field:value"'
assert ":" in escaped # Present as literal
def test_value_with_double_quotes(self):
"""Test that internal double quotes are escaped."""
value_with_quotes = 'tenant"test"value'
escaped = escape_logstore_query_value(value_with_quotes)
# Double quotes should be escaped with backslash
assert escaped == '"tenant\\"test\\"value"'
# Should have outer quotes plus escaped inner quotes
assert '\\"' in escaped
def test_value_with_backslash(self):
"""Test that backslashes are escaped."""
value_with_backslash = "tenant\\test"
escaped = escape_logstore_query_value(value_with_backslash)
# Backslash should be escaped
assert escaped == '"tenant\\\\test"'
assert "\\\\" in escaped
def test_value_with_backslash_and_quote(self):
"""Test escaping both backslash and double quote."""
value = 'path\\to\\"file"'
escaped = escape_logstore_query_value(value)
# Both should be escaped
assert escaped == '"path\\\\to\\\\\\"file\\""'
# Verify escape order is correct
assert "\\\\" in escaped # Escaped backslash
assert '\\"' in escaped # Escaped double quote
def test_complex_injection_attempt(self):
"""Test complex injection combining multiple operators."""
malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")'
escaped = escape_logstore_query_value(malicious)
# All special chars should be literals or escaped
assert escaped.startswith('"')
assert escaped.endswith('"')
# Inner double quotes escaped, operators become literals
assert "or" in escaped
assert "and" in escaped
assert '\\"' in escaped # Escaped quotes
def test_only_backslash(self):
"""Test escaping a single backslash."""
assert escape_logstore_query_value("\\") == '"\\\\"'
def test_only_double_quote(self):
"""Test escaping a single double quote."""
assert escape_logstore_query_value('"') == '"\\""'
def test_multiple_backslashes(self):
"""Test escaping multiple consecutive backslashes."""
assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6
def test_escape_sequence_like_input(self):
"""Test that existing escape sequences are properly escaped."""
# Input looks like already escaped, but we still escape it
value = 'value\\"test'
escaped = escape_logstore_query_value(value)
# \\ -> \\\\, " -> \"
assert escaped == '"value\\\\\\"test"'
@pytest.mark.parametrize(
("attack_scenario", "field", "malicious_value"),
[
("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"),
("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"),
("Boolean logic", "status", "succeeded or status:failed"),
("Negation", "tenant_id", "not tenant_a"),
("Field injection", "run_id", "run123 and tenant_id:evil_tenant"),
("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"),
("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'),
("Backslash escape bypass", "app_id", "app\\ and app_id:evil"),
],
)
def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str):
"""Test that various LogStore query injection attempts are neutralized."""
escaped = escape_logstore_query_value(malicious_value)
# Build query
query = f"{field}:{escaped}"
# All operators should be within quoted string (literals)
assert escaped.startswith('"')
assert escaped.endswith('"')
# Verify the full query structure is safe
assert query.count(":") >= 1 # At least the main field:value separator

View File

@ -0,0 +1,59 @@
"""Unit tests for traceparent header propagation in EnterpriseRequest.
This test module verifies that the W3C traceparent header is properly
generated and included in HTTP requests made by EnterpriseRequest.
"""
from unittest.mock import MagicMock, patch
import pytest
from services.enterprise.base import EnterpriseRequest
class TestTraceparentPropagation:
"""Unit tests for traceparent header propagation."""
@pytest.fixture
def mock_enterprise_config(self):
"""Mock EnterpriseRequest configuration."""
with (
patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"),
patch.object(EnterpriseRequest, "secret_key", "test-secret-key"),
patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"),
):
yield
@pytest.fixture
def mock_httpx_client(self):
"""Mock httpx.Client for testing."""
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value.__enter__.return_value = mock_client
mock_client_class.return_value.__exit__.return_value = None
# Setup default response
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_client.request.return_value = mock_response
yield mock_client
def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client):
"""Test that traceparent header is included when successfully generated."""
# Arrange
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
# Act
EnterpriseRequest.send_request("GET", "/test")
# Assert
mock_httpx_client.request.assert_called_once()
call_args = mock_httpx_client.request.call_args
headers = call_args[1]["headers"]
assert "traceparent" in headers
assert headers["traceparent"] == expected_traceparent
assert headers["Content-Type"] == "application/json"
assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key"

14
api/uv.lock generated
View File

@ -453,15 +453,15 @@ wheels = [
[[package]]
name = "azure-core"
version = "1.36.0"
version = "1.38.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "requests" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" }
sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" },
{ url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" },
]
[[package]]
@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.11.4"
version = "1.11.3"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
@ -1965,11 +1965,11 @@ wheels = [
[[package]]
name = "filelock"
version = "3.20.0"
version = "3.20.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" }
sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" },
{ url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" },
]
[[package]]

View File

@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# Workflow log cleanup configuration

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.4
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.4
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.4
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.4
image: langgenius/dify-web:1.11.3
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -704,7 +704,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.4
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -746,7 +746,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.4
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -785,7 +785,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.4
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -815,7 +815,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.4
image: langgenius/dify-web:1.11.3
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -1 +1 @@
24
22.21.1

View File

@ -1,5 +1,5 @@
# base image
FROM node:24-alpine AS base
FROM node:22.21.1-alpine3.23 AS base
LABEL maintainer="takatost@gmail.com"
# if you located in China, you can use aliyun mirror to speed up

View File

@ -8,8 +8,8 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
Before starting the web frontend service, please make sure the following environment is ready.
- [Node.js](https://nodejs.org)
- [pnpm](https://pnpm.io)
- [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x
> [!TIP]
> It is recommended to install and enable Corepack to manage package manager versions automatically:

View File

@ -66,9 +66,7 @@ export default function CheckCode() {
setIsLoading(true)
const ret = await webAppEmailLoginWithCode({ email, code: encryptVerificationCode(code), token })
if (ret.result === 'success') {
if (ret?.data?.access_token) {
setWebAppAccessToken(ret.data.access_token)
}
setWebAppAccessToken(ret.data.access_token)
const { access_token } = await fetchAccessToken({
appCode: appCode!,
userId: embeddedUserId || undefined,

View File

@ -82,9 +82,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
body: loginData,
})
if (res.result === 'success') {
if (res?.data?.access_token) {
setWebAppAccessToken(res.data.access_token)
}
setWebAppAccessToken(res.data.access_token)
const { access_token } = await fetchAccessToken({
appCode: appCode!,

View File

@ -34,13 +34,6 @@ vi.mock('@/context/app-context', () => ({
}),
}))
vi.mock('@/service/common', () => ({
fetchCurrentWorkspace: vi.fn(),
fetchLangGeniusVersion: vi.fn(),
fetchUserProfile: vi.fn(),
getSystemFeatures: vi.fn(),
}))
vi.mock('@/service/access-control', () => ({
useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args),
useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args),

View File

@ -183,7 +183,6 @@ const AgentTools: FC = () => {
onSelect={handleSelectTool}
onSelectMultiple={handleSelectMultipleTool}
selectedTools={tools as unknown as ToolValue[]}
canChooseMCPTool
/>
</>
)}

View File

@ -170,8 +170,12 @@ describe('useChatWithHistory', () => {
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1')
})
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
expect(result.current.conversationList).toEqual(listData.data)
await waitFor(() => {
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
})
await waitFor(() => {
expect(result.current.conversationList).toEqual(listData.data)
})
})
})

View File

@ -3,7 +3,8 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type'
@ -21,10 +22,15 @@ vi.mock('@/context/app-context', () => ({
}))
vi.mock('@/service/billing', () => ({
fetchBillingUrl: vi.fn(),
fetchSubscriptionUrls: vi.fn(),
}))
vi.mock('@/service/client', () => ({
consoleClient: {
billingUrl: vi.fn(),
},
}))
vi.mock('@/hooks/use-async-window-open', () => ({
useAsyncWindowOpen: vi.fn(),
}))
@ -37,7 +43,7 @@ vi.mock('../../assets', () => ({
const mockUseAppContext = useAppContext as Mock
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
const mockFetchBillingUrl = fetchBillingUrl as Mock
const mockBillingUrl = consoleClient.billingUrl as Mock
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
const mockToastNotify = Toast.notify as Mock
@ -69,7 +75,7 @@ beforeEach(() => {
vi.clearAllMocks()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
mockFetchBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' })
assignedHref = ''
})
@ -143,7 +149,7 @@ describe('CloudPlanItem', () => {
type: 'error',
message: 'billing.buyPermissionDeniedTip',
}))
expect(mockFetchBillingUrl).not.toHaveBeenCalled()
expect(mockBillingUrl).not.toHaveBeenCalled()
})
it('should open billing portal when upgrading current paid plan', async () => {
@ -162,7 +168,7 @@ describe('CloudPlanItem', () => {
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))
await waitFor(() => {
expect(mockFetchBillingUrl).toHaveBeenCalledTimes(1)
expect(mockBillingUrl).toHaveBeenCalledTimes(1)
})
expect(openWindow).toHaveBeenCalledTimes(1)
})

View File

@ -6,7 +6,8 @@ import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type'
@ -76,7 +77,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
try {
if (isCurrentPaidPlan) {
await openAsyncWindow(async () => {
const res = await fetchBillingUrl()
const res = await consoleClient.billingUrl()
if (res.url)
return res.url
throw new Error('Failed to open billing page')

View File

@ -362,18 +362,6 @@ describe('PreviewDocumentPicker', () => {
expect(screen.getByText('--')).toBeInTheDocument()
})
it('should render when value prop is omitted (optional)', () => {
const files = createMockDocumentList(2)
const onChange = vi.fn()
// Do not pass `value` at all to verify optional behavior
render(<PreviewDocumentPicker files={files} onChange={onChange} />)
// Renders placeholder for missing name
expect(screen.getByText('--')).toBeInTheDocument()
// Portal wrapper renders
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
})
it('should handle empty files array', () => {
renderComponent({ files: [] })

View File

@ -18,7 +18,7 @@ import DocumentList from './document-list'
type Props = {
className?: string
value?: DocumentItem
value: DocumentItem
files: DocumentItem[]
onChange: (value: DocumentItem) => void
}
@ -30,8 +30,7 @@ const PreviewDocumentPicker: FC<Props> = ({
onChange,
}) => {
const { t } = useTranslation()
const name = value?.name || ''
const extension = value?.extension
const { name, extension } = value
const [open, {
set: setOpen,

View File

@ -30,8 +30,8 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
category: PluginCategoryEnum.datasource,
exclude,
type: 'plugin',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
else {
@ -39,10 +39,10 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
query: '',
category: PluginCategoryEnum.datasource,
type: 'plugin',
pageSize: 1000,
page_size: 1000,
exclude,
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -275,8 +275,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
category: PluginCategoryEnum.model,
exclude,
type: 'plugin',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
else {
@ -284,10 +284,10 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
query: '',
category: PluginCategoryEnum.model,
type: 'plugin',
pageSize: 1000,
page_size: 1000,
exclude,
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -55,7 +55,6 @@ type FormProps<
nodeId?: string
nodeOutputVars?: NodeOutPutVar[]
availableNodes?: Node[]
canChooseMCPTool?: boolean
}
function Form<
@ -81,7 +80,6 @@ function Form<
nodeId,
nodeOutputVars,
availableNodes,
canChooseMCPTool,
}: FormProps<CustomFormSchema>) {
const language = useLanguage()
const [changeKey, setChangeKey] = useState('')
@ -407,7 +405,6 @@ function Form<
value={value[variable] || []}
onChange={item => handleFormChange(variable, item as any)}
supportCollapse
canChooseMCPTool={canChooseMCPTool}
/>
{fieldMoreInfo?.(formSchema)}
{validating && changeKey === variable && <ValidatingTip />}

View File

@ -100,11 +100,11 @@ export const useMarketplacePlugins = () => {
const [queryParams, setQueryParams] = useState<PluginsSearchParams>()
const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => {
const pageSize = pluginsSearchParams.pageSize || 40
const page_size = pluginsSearchParams.page_size || 40
return {
...pluginsSearchParams,
pageSize,
page_size,
}
}, [])
@ -116,20 +116,20 @@ export const useMarketplacePlugins = () => {
plugins: [] as Plugin[],
total: 0,
page: 1,
pageSize: 40,
page_size: 40,
}
}
const params = normalizeParams(queryParams)
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
exclude,
type,
pageSize,
page_size,
} = params
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
@ -137,10 +137,10 @@ export const useMarketplacePlugins = () => {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
body: {
page: pageParam,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
exclude,
@ -154,7 +154,7 @@ export const useMarketplacePlugins = () => {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total,
page: pageParam,
pageSize,
page_size,
}
}
catch {
@ -162,13 +162,13 @@ export const useMarketplacePlugins = () => {
plugins: [],
total: 0,
page: pageParam,
pageSize,
page_size,
}
}
},
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,

View File

@ -2,8 +2,8 @@ import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server'
import { marketplaceQuery } from '@/service/client'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceKeys } from './query'
import { marketplaceSearchParamsParsers } from './search-params'
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
@ -23,7 +23,7 @@ async function getDehydratedState(searchParams?: Promise<SearchParams>) {
const queryClient = getQueryClientServer()
await queryClient.prefetchQuery({
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)),
queryKey: marketplaceQuery.collections.queryKey({ input: { query: getCollectionsParams(params.category) } }),
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
})
return dehydrate(queryClient)

View File

@ -60,10 +60,10 @@ vi.mock('@/service/use-plugins', () => ({
// Mock tanstack query
const mockFetchNextPage = vi.fn()
const mockHasNextPage = false
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, pageSize: number }> } | undefined
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, page_size: number }> } | undefined
let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedGetNextPageParam: ((lastPage: { page: number, pageSize: number, total: number }) => number | undefined) | null = null
let capturedGetNextPageParam: ((lastPage: { page: number, page_size: number, total: number }) => number | undefined) | null = null
vi.mock('@tanstack/react-query', () => ({
useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: { signal: AbortSignal }) => Promise<unknown>, enabled: boolean }) => {
@ -83,7 +83,7 @@ vi.mock('@tanstack/react-query', () => ({
}),
useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: {
queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>
getNextPageParam: (lastPage: { page: number, pageSize: number, total: number }) => number | undefined
getNextPageParam: (lastPage: { page: number, page_size: number, total: number }) => number | undefined
enabled: boolean
}) => {
// Capture queryFn and getNextPageParam for later testing
@ -97,9 +97,9 @@ vi.mock('@tanstack/react-query', () => ({
// Call getNextPageParam to increase coverage
if (getNextPageParam) {
// Test with more data available
getNextPageParam({ page: 1, pageSize: 40, total: 100 })
getNextPageParam({ page: 1, page_size: 40, total: 100 })
// Test with no more data
getNextPageParam({ page: 3, pageSize: 40, total: 100 })
getNextPageParam({ page: 3, page_size: 40, total: 100 })
}
return {
data: mockInfiniteQueryData,
@ -151,6 +151,7 @@ vi.mock('@/service/base', () => ({
// Mock config
vi.mock('@/config', () => ({
API_PREFIX: '/api',
APP_VERSION: '1.0.0',
IS_MARKETPLACE: false,
MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1',
@ -731,10 +732,10 @@ describe('useMarketplacePlugins', () => {
expect(() => {
result.current.queryPlugins({
query: 'test',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
category: 'tool',
pageSize: 20,
page_size: 20,
})
}).not.toThrow()
})
@ -747,7 +748,7 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({
query: 'test',
type: 'bundle',
pageSize: 40,
page_size: 40,
})
}).not.toThrow()
})
@ -798,8 +799,8 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({
query: 'test',
category: 'all',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}).not.toThrow()
})
@ -824,7 +825,7 @@ describe('useMarketplacePlugins', () => {
expect(() => {
result.current.queryPlugins({
query: 'test',
pageSize: 100,
page_size: 100,
})
}).not.toThrow()
})
@ -843,7 +844,7 @@ describe('Hooks queryFn Coverage', () => {
// Set mock data to have pages
mockInfiniteQueryData = {
pages: [
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, pageSize: 40 },
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, page_size: 40 },
],
}
@ -863,8 +864,8 @@ describe('Hooks queryFn Coverage', () => {
it('should expose page and total from infinite query data', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, pageSize: 40 },
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, pageSize: 40 },
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, page_size: 40 },
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, page_size: 40 },
],
}
@ -893,7 +894,7 @@ describe('Hooks queryFn Coverage', () => {
it('should return total from first page when query is set and data exists', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [], total: 50, page: 1, pageSize: 40 },
{ plugins: [], total: 50, page: 1, page_size: 40 },
],
}
@ -917,8 +918,8 @@ describe('Hooks queryFn Coverage', () => {
type: 'plugin',
query: 'search test',
category: 'model',
sortBy: 'version_updated_at',
sortOrder: 'ASC',
sort_by: 'version_updated_at',
sort_order: 'ASC',
})
expect(result.current).toBeDefined()
@ -1027,13 +1028,13 @@ describe('Advanced Hook Integration', () => {
// Test with all possible parameters
result.current.queryPlugins({
query: 'comprehensive test',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
category: 'tool',
tags: ['tag1', 'tag2'],
exclude: ['excluded-plugin'],
type: 'plugin',
pageSize: 50,
page_size: 50,
})
expect(result.current).toBeDefined()
@ -1081,9 +1082,9 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({
query: 'direct test',
category: 'tool',
sortBy: 'install_count',
sortOrder: 'DESC',
pageSize: 40,
sort_by: 'install_count',
sort_order: 'DESC',
page_size: 40,
})
// Now queryFn should be captured and enabled
@ -1255,7 +1256,7 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({
query: 'structure test',
pageSize: 20,
page_size: 20,
})
if (capturedInfiniteQueryFn) {
@ -1264,14 +1265,14 @@ describe('Direct queryFn Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// Verify the returned structure
expect(response).toHaveProperty('plugins')
expect(response).toHaveProperty('total')
expect(response).toHaveProperty('page')
expect(response).toHaveProperty('pageSize')
expect(response).toHaveProperty('page_size')
}
})
})
@ -1296,7 +1297,7 @@ describe('flatMap Coverage', () => {
],
total: 5,
page: 1,
pageSize: 40,
page_size: 40,
},
{
plugins: [
@ -1304,7 +1305,7 @@ describe('flatMap Coverage', () => {
],
total: 5,
page: 2,
pageSize: 40,
page_size: 40,
},
],
}
@ -1336,8 +1337,8 @@ describe('flatMap Coverage', () => {
it('should test hook with pages data for flatMap path', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [], total: 100, page: 1, pageSize: 40 },
{ plugins: [], total: 100, page: 2, pageSize: 40 },
{ plugins: [], total: 100, page: 1, page_size: 40 },
{ plugins: [], total: 100, page: 2, page_size: 40 },
],
}
@ -1371,7 +1372,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// When error is caught, should return fallback data
expect(response.plugins).toEqual([])
@ -1392,15 +1393,15 @@ describe('flatMap Coverage', () => {
// Test getNextPageParam function directly
if (capturedGetNextPageParam) {
// When there are more pages
const nextPage = capturedGetNextPageParam({ page: 1, pageSize: 40, total: 100 })
const nextPage = capturedGetNextPageParam({ page: 1, page_size: 40, total: 100 })
expect(nextPage).toBe(2)
// When all data is loaded
const noMorePages = capturedGetNextPageParam({ page: 3, pageSize: 40, total: 100 })
const noMorePages = capturedGetNextPageParam({ page: 3, page_size: 40, total: 100 })
expect(noMorePages).toBeUndefined()
// Edge case: exactly at boundary
const atBoundary = capturedGetNextPageParam({ page: 2, pageSize: 50, total: 100 })
const atBoundary = capturedGetNextPageParam({ page: 2, page_size: 50, total: 100 })
expect(atBoundary).toBeUndefined()
}
})
@ -1427,7 +1428,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// Catch block should return fallback values
expect(response.plugins).toEqual([])
@ -1446,7 +1447,7 @@ describe('flatMap Coverage', () => {
plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }],
total: 10,
page: 1,
pageSize: 40,
page_size: 40,
},
],
}
@ -1489,9 +1490,12 @@ describe('Async Utils', () => {
{ type: 'plugin', org: 'test', name: 'plugin2' },
]
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection', {
@ -1514,19 +1518,26 @@ describe('Async Utils', () => {
})
it('should pass abort signal when provided', async () => {
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const controller = new AbortController()
const { getMarketplacePluginsByCollectionId } = await import('./utils')
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({ signal: controller.signal }),
expect.any(Request),
expect.any(Object),
)
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('test-collection')
})
})
@ -1535,19 +1546,25 @@ describe('Async Utils', () => {
const mockCollections = [
{ name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' },
]
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
let callCount = 0
globalThis.fetch = vi.fn().mockImplementation(() => {
callCount++
if (callCount === 1) {
return Promise.resolve({
json: () => Promise.resolve({ data: { collections: mockCollections } }),
})
return Promise.resolve(
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
}
return Promise.resolve({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
return Promise.resolve(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
})
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
@ -1571,9 +1588,12 @@ describe('Async Utils', () => {
})
it('should append condition and type to URL when provided', async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { collections: [] } }),
})
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { collections: [] } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
await getMarketplaceCollectionsAndPlugins({
@ -1581,10 +1601,11 @@ describe('Async Utils', () => {
type: 'bundle',
})
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.stringContaining('condition=category=tool'),
expect.any(Object),
)
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalled()
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('condition=category%3Dtool')
})
})
})

View File

@ -1,22 +1,14 @@
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types'
import type { PluginsSearchParams } from './types'
import type { MarketPlaceInputs } from '@/contract/router'
import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
import { marketplaceQuery } from '@/service/client'
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
// TODO: Avoid manual maintenance of query keys and better service management,
// https://github.com/langgenius/dify/issues/30342
export const marketplaceKeys = {
all: ['marketplace'] as const,
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
}
export function useMarketplaceCollectionsAndPlugins(
collectionsParams: CollectionsAndPluginsSearchParams,
collectionsParams: MarketPlaceInputs['collections']['query'],
) {
return useQuery({
queryKey: marketplaceKeys.collections(collectionsParams),
queryKey: marketplaceQuery.collections.queryKey({ input: { query: collectionsParams } }),
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
})
}
@ -25,11 +17,16 @@ export function useMarketplacePlugins(
queryParams: PluginsSearchParams | undefined,
) {
return useInfiniteQuery({
queryKey: marketplaceKeys.plugins(queryParams),
queryKey: marketplaceQuery.searchAdvanced.queryKey({
input: {
body: queryParams!,
params: { kind: queryParams?.type === 'bundle' ? 'bundles' : 'plugins' },
},
}),
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,

View File

@ -26,8 +26,8 @@ export function useMarketplaceData() {
query: searchPluginText,
category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType,
tags: filterPluginTags,
sortBy: sort.sortBy,
sortOrder: sort.sortOrder,
sort_by: sort.sortBy,
sort_order: sort.sortOrder,
type: getMarketplaceListFilterType(activePluginType),
}
}, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort])

View File

@ -30,9 +30,9 @@ export type MarketplaceCollectionPluginsResponse = {
export type PluginsSearchParams = {
query: string
page?: number
pageSize?: number
sortBy?: string
sortOrder?: string
page_size?: number
sort_by?: string
sort_order?: string
category?: string
tags?: string[]
exclude?: string[]

View File

@ -4,14 +4,12 @@ import type {
MarketplaceCollection,
PluginsSearchParams,
} from '@/app/components/plugins/marketplace/types'
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
import type { Plugin } from '@/app/components/plugins/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import {
APP_VERSION,
IS_MARKETPLACE,
MARKETPLACE_API_PREFIX,
} from '@/config'
import { postMarketplace } from '@/service/base'
import { marketplaceClient } from '@/service/client'
import { getMarketplaceUrl } from '@/utils/var'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
@ -19,10 +17,6 @@ type MarketplaceFetchOptions = {
signal?: AbortSignal
}
const getMarketplaceHeaders = () => new Headers({
'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0',
})
export const getPluginIconInMarketplace = (plugin: Plugin) => {
if (plugin.type === 'bundle')
return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon`
@ -65,24 +59,15 @@ export const getMarketplacePluginsByCollectionId = async (
let plugins: Plugin[] = []
try {
const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins`
const headers = getMarketplaceHeaders()
const marketplaceCollectionPluginsData = await globalThis.fetch(
url,
{
cache: 'no-store',
method: 'POST',
headers,
signal: options?.signal,
body: JSON.stringify({
category: query?.category,
exclude: query?.exclude,
type: query?.type,
}),
const marketplaceCollectionPluginsDataJson = await marketplaceClient.collectionPlugins({
params: {
collectionId,
},
)
const marketplaceCollectionPluginsDataJson = await marketplaceCollectionPluginsData.json()
plugins = (marketplaceCollectionPluginsDataJson.data.plugins || []).map((plugin: Plugin) => getFormattedPlugin(plugin))
body: query,
}, {
signal: options?.signal,
})
plugins = (marketplaceCollectionPluginsDataJson.data?.plugins || []).map(plugin => getFormattedPlugin(plugin))
}
// eslint-disable-next-line unused-imports/no-unused-vars
catch (e) {
@ -99,22 +84,16 @@ export const getMarketplaceCollectionsAndPlugins = async (
let marketplaceCollections: MarketplaceCollection[] = []
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
try {
let marketplaceUrl = `${MARKETPLACE_API_PREFIX}/collections?page=1&page_size=100`
if (query?.condition)
marketplaceUrl += `&condition=${query.condition}`
if (query?.type)
marketplaceUrl += `&type=${query.type}`
const headers = getMarketplaceHeaders()
const marketplaceCollectionsData = await globalThis.fetch(
marketplaceUrl,
{
headers,
cache: 'no-store',
signal: options?.signal,
const marketplaceCollectionsDataJson = await marketplaceClient.collections({
query: {
...query,
page: 1,
page_size: 100,
},
)
const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json()
marketplaceCollections = marketplaceCollectionsDataJson.data.collections || []
}, {
signal: options?.signal,
})
marketplaceCollections = marketplaceCollectionsDataJson.data?.collections || []
await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => {
const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options)
@ -143,42 +122,42 @@ export const getMarketplacePlugins = async (
plugins: [] as Plugin[],
total: 0,
page: 1,
pageSize: 40,
page_size: 40,
}
}
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
type,
pageSize = 40,
page_size = 40,
} = queryParams
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
try {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
const res = await marketplaceClient.searchAdvanced({
params: {
kind: type === 'bundle' ? 'bundles' : 'plugins',
},
body: {
page: pageParam,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
type,
},
signal,
})
}, { signal })
const resPlugins = res.data.bundles || res.data.plugins || []
return {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total,
page: pageParam,
pageSize,
page_size,
}
}
catch {
@ -186,7 +165,7 @@ export const getMarketplacePlugins = async (
plugins: [],
total: 0,
page: pageParam,
pageSize,
page_size,
}
}
}

View File

@ -0,0 +1,65 @@
import type { PluginPayload } from '../types'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import Confirm from '@/app/components/base/confirm'
import ApiKeyModal from '../authorize/api-key-modal'
type AuthorizedModalsProps = {
pluginPayload: PluginPayload
// Delete confirmation
deleteCredentialId: string | null
doingAction: boolean
onDeleteConfirm: () => void
onDeleteCancel: () => void
// Edit modal
editValues: Record<string, unknown> | null
disabled?: boolean
onEditClose: () => void
onRemove: () => void
onUpdate?: () => void
}
/**
* Component for managing authorized modals (delete confirmation and edit modal)
* Extracted to reduce complexity in the main Authorized component
*/
const AuthorizedModals = ({
pluginPayload,
deleteCredentialId,
doingAction,
onDeleteConfirm,
onDeleteCancel,
editValues,
disabled,
onEditClose,
onRemove,
onUpdate,
}: AuthorizedModalsProps) => {
const { t } = useTranslation()
return (
<>
{deleteCredentialId && (
<Confirm
isShow
title={t('list.delete.title', { ns: 'datasetDocuments' })}
isDisabled={doingAction}
onCancel={onDeleteCancel}
onConfirm={onDeleteConfirm}
/>
)}
{!!editValues && (
<ApiKeyModal
pluginPayload={pluginPayload}
editValues={editValues}
onClose={onEditClose}
onRemove={onRemove}
disabled={disabled || doingAction}
onUpdate={onUpdate}
/>
)}
</>
)
}
export default memo(AuthorizedModals)

View File

@ -0,0 +1,123 @@
import type { Credential } from '../types'
import { memo } from 'react'
import { cn } from '@/utils/classnames'
import Item from './item'
type CredentialItemHandlers = {
onDelete?: (id: string) => void
onEdit?: (id: string, values: Record<string, unknown>) => void
onSetDefault?: (id: string) => void
onRename?: (payload: { credential_id: string, name: string }) => void
onItemClick?: (id: string) => void
}
type CredentialSectionProps = CredentialItemHandlers & {
title: string
credentials: Credential[]
disabled?: boolean
disableRename?: boolean
disableEdit?: boolean
disableDelete?: boolean
disableSetDefault?: boolean
showSelectedIcon?: boolean
selectedCredentialId?: string
}
/**
* Reusable component for rendering a section of credentials
* Used for OAuth, API Key, and extra authorization items
*/
const CredentialSection = ({
title,
credentials,
disabled,
disableRename,
disableEdit,
disableDelete,
disableSetDefault,
showSelectedIcon,
selectedCredentialId,
onDelete,
onEdit,
onSetDefault,
onRename,
onItemClick,
}: CredentialSectionProps) => {
if (!credentials.length)
return null
return (
<div className="p-1">
<div className={cn(
'system-xs-medium px-3 pb-0.5 pt-1 text-text-tertiary',
showSelectedIcon && 'pl-7',
)}
>
{title}
</div>
{credentials.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
disableRename={disableRename}
disableEdit={disableEdit}
disableDelete={disableDelete}
disableSetDefault={disableSetDefault}
showSelectedIcon={showSelectedIcon}
selectedCredentialId={selectedCredentialId}
onDelete={onDelete}
onEdit={onEdit}
onSetDefault={onSetDefault}
onRename={onRename}
onItemClick={onItemClick}
/>
))}
</div>
)
}
export default memo(CredentialSection)
type ExtraCredentialSectionProps = {
credentials?: Credential[]
disabled?: boolean
onItemClick?: (id: string) => void
showSelectedIcon?: boolean
selectedCredentialId?: string
}
/**
* Specialized section for extra authorization items (read-only)
*/
export const ExtraCredentialSection = memo(({
credentials,
disabled,
onItemClick,
showSelectedIcon,
selectedCredentialId,
}: ExtraCredentialSectionProps) => {
if (!credentials?.length)
return null
return (
<div className="p-1">
{credentials.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
onItemClick={onItemClick}
disableRename
disableEdit
disableDelete
disableSetDefault
showSelectedIcon={showSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
))}
</div>
)
})
ExtraCredentialSection.displayName = 'ExtraCredentialSection'

View File

@ -0,0 +1,2 @@
export { useCredentialActions } from './use-credential-actions'
export { useModalState } from './use-modal-state'

View File

@ -0,0 +1,116 @@
import type { MutableRefObject } from 'react'
import type { PluginPayload } from '../../types'
import {
useCallback,
useRef,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import { useToastContext } from '@/app/components/base/toast'
import {
useDeletePluginCredentialHook,
useSetPluginDefaultCredentialHook,
useUpdatePluginCredentialHook,
} from '../../hooks/use-credential'
type UseCredentialActionsOptions = {
pluginPayload: PluginPayload
onUpdate?: () => void
}
type UseCredentialActionsReturn = {
doingAction: boolean
doingActionRef: MutableRefObject<boolean>
pendingOperationCredentialIdRef: MutableRefObject<string | null>
handleSetDoingAction: (doing: boolean) => void
handleDelete: (credentialId: string) => Promise<void>
handleSetDefault: (id: string) => Promise<void>
handleRename: (payload: { credential_id: string, name: string }) => Promise<void>
}
/**
* Custom hook for credential CRUD operations
* Consolidates delete, setDefault, rename actions with shared loading state
*/
export const useCredentialActions = ({
pluginPayload,
onUpdate,
}: UseCredentialActionsOptions): UseCredentialActionsReturn => {
const { t } = useTranslation()
const { notify } = useToastContext()
const [doingAction, setDoingAction] = useState(false)
const doingActionRef = useRef(doingAction)
const pendingOperationCredentialIdRef = useRef<string | null>(null)
const handleSetDoingAction = useCallback((doing: boolean) => {
doingActionRef.current = doing
setDoingAction(doing)
}, [])
const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload)
const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload)
const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload)
const showSuccessNotification = useCallback(() => {
notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
}, [notify, t])
const handleDelete = useCallback(async (credentialId: string) => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await deletePluginCredential({ credential_id: credentialId })
showSuccessNotification()
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [deletePluginCredential, onUpdate, showSuccessNotification, handleSetDoingAction])
const handleSetDefault = useCallback(async (id: string) => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await setPluginDefaultCredential(id)
showSuccessNotification()
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [setPluginDefaultCredential, onUpdate, showSuccessNotification, handleSetDoingAction])
const handleRename = useCallback(async (payload: {
credential_id: string
name: string
}) => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await updatePluginCredential(payload)
showSuccessNotification()
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [updatePluginCredential, showSuccessNotification, handleSetDoingAction, onUpdate])
return {
doingAction,
doingActionRef,
pendingOperationCredentialIdRef,
handleSetDoingAction,
handleDelete,
handleSetDefault,
handleRename,
}
}

View File

@ -0,0 +1,71 @@
import type { MutableRefObject } from 'react'
import {
useCallback,
useState,
} from 'react'
type CredentialValues = Record<string, unknown>
type UseModalStateOptions = {
pendingOperationCredentialIdRef: MutableRefObject<string | null>
}
type UseModalStateReturn = {
// Delete modal state
deleteCredentialId: string | null
openDeleteConfirm: (credentialId?: string) => void
closeDeleteConfirm: () => void
// Edit modal state
editValues: CredentialValues | null
openEditModal: (id: string, values: CredentialValues) => void
closeEditModal: () => void
// Remove action (used from edit modal)
handleRemoveFromEdit: () => void
}
/**
* Custom hook for managing modal states
* Handles delete confirmation and edit modal with shared pending credential tracking
*/
export const useModalState = ({
pendingOperationCredentialIdRef,
}: UseModalStateOptions): UseModalStateReturn => {
const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null)
const [editValues, setEditValues] = useState<CredentialValues | null>(null)
const openDeleteConfirm = useCallback((credentialId?: string) => {
if (credentialId)
pendingOperationCredentialIdRef.current = credentialId
setDeleteCredentialId(pendingOperationCredentialIdRef.current)
}, [pendingOperationCredentialIdRef])
const closeDeleteConfirm = useCallback(() => {
setDeleteCredentialId(null)
pendingOperationCredentialIdRef.current = null
}, [pendingOperationCredentialIdRef])
const openEditModal = useCallback((id: string, values: CredentialValues) => {
pendingOperationCredentialIdRef.current = id
setEditValues(values)
}, [pendingOperationCredentialIdRef])
const closeEditModal = useCallback(() => {
setEditValues(null)
pendingOperationCredentialIdRef.current = null
}, [pendingOperationCredentialIdRef])
const handleRemoveFromEdit = useCallback(() => {
setDeleteCredentialId(pendingOperationCredentialIdRef.current)
}, [pendingOperationCredentialIdRef])
return {
deleteCredentialId,
openDeleteConfirm,
closeDeleteConfirm,
editValues,
openEditModal,
closeEditModal,
handleRemoveFromEdit,
}
}

File diff suppressed because it is too large Load Diff

View File

@ -8,29 +8,23 @@ import {
import {
memo,
useCallback,
useRef,
useMemo,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import { useToastContext } from '@/app/components/base/toast'
import Indicator from '@/app/components/header/indicator'
import { cn } from '@/utils/classnames'
import Authorize from '../authorize'
import ApiKeyModal from '../authorize/api-key-modal'
import {
useDeletePluginCredentialHook,
useSetPluginDefaultCredentialHook,
useUpdatePluginCredentialHook,
} from '../hooks/use-credential'
import { CredentialTypeEnum } from '../types'
import Item from './item'
import AuthorizedModals from './authorized-modals'
import CredentialSection, { ExtraCredentialSection } from './credential-section'
import { useCredentialActions, useModalState } from './hooks'
type AuthorizedProps = {
pluginPayload: PluginPayload
@ -53,6 +47,7 @@ type AuthorizedProps = {
onUpdate?: () => void
notAllowCustomCredential?: boolean
}
const Authorized = ({
pluginPayload,
credentials,
@ -75,105 +70,55 @@ const Authorized = ({
notAllowCustomCredential,
}: AuthorizedProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
// Dropdown open state
const [isLocalOpen, setIsLocalOpen] = useState(false)
const mergedIsOpen = isOpen ?? isLocalOpen
const setMergedIsOpen = useCallback((open: boolean) => {
if (onOpenChange)
onOpenChange(open)
onOpenChange?.(open)
setIsLocalOpen(open)
}, [onOpenChange])
const oAuthCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.OAUTH2)
const apiKeyCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.API_KEY)
const pendingOperationCredentialId = useRef<string | null>(null)
const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null)
const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload)
const openConfirm = useCallback((credentialId?: string) => {
if (credentialId)
pendingOperationCredentialId.current = credentialId
setDeleteCredentialId(pendingOperationCredentialId.current)
}, [])
const closeConfirm = useCallback(() => {
setDeleteCredentialId(null)
pendingOperationCredentialId.current = null
}, [])
const [doingAction, setDoingAction] = useState(false)
const doingActionRef = useRef(doingAction)
const handleSetDoingAction = useCallback((doing: boolean) => {
doingActionRef.current = doing
setDoingAction(doing)
}, [])
const handleConfirm = useCallback(async () => {
if (doingActionRef.current)
// Credential actions hook
const {
doingAction,
doingActionRef,
pendingOperationCredentialIdRef,
handleSetDefault,
handleRename,
handleDelete,
} = useCredentialActions({ pluginPayload, onUpdate })
// Modal state management hook
const {
deleteCredentialId,
openDeleteConfirm,
closeDeleteConfirm,
editValues,
openEditModal,
closeEditModal,
handleRemoveFromEdit,
} = useModalState({ pendingOperationCredentialIdRef })
// Handle delete confirmation
const handleDeleteConfirm = useCallback(async () => {
if (doingActionRef.current || !pendingOperationCredentialIdRef.current)
return
if (!pendingOperationCredentialId.current) {
setDeleteCredentialId(null)
return
}
try {
handleSetDoingAction(true)
await deletePluginCredential({ credential_id: pendingOperationCredentialId.current })
notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
onUpdate?.()
setDeleteCredentialId(null)
pendingOperationCredentialId.current = null
}
finally {
handleSetDoingAction(false)
}
}, [deletePluginCredential, onUpdate, notify, t, handleSetDoingAction])
const [editValues, setEditValues] = useState<Record<string, any> | null>(null)
const handleEdit = useCallback((id: string, values: Record<string, any>) => {
pendingOperationCredentialId.current = id
setEditValues(values)
}, [])
const handleRemove = useCallback(() => {
setDeleteCredentialId(pendingOperationCredentialId.current)
}, [])
const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload)
const handleSetDefault = useCallback(async (id: string) => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await setPluginDefaultCredential(id)
notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [setPluginDefaultCredential, onUpdate, notify, t, handleSetDoingAction])
const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload)
const handleRename = useCallback(async (payload: {
credential_id: string
name: string
}) => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await updatePluginCredential(payload)
notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate])
const unavailableCredentials = credentials.filter(credential => credential.not_allowed_to_use)
const unavailableCredential = credentials.find(credential => credential.not_allowed_to_use && credential.is_default)
await handleDelete(pendingOperationCredentialIdRef.current)
closeDeleteConfirm()
}, [doingActionRef, pendingOperationCredentialIdRef, handleDelete, closeDeleteConfirm])
// Filter credentials by type
const { oAuthCredentials, apiKeyCredentials } = useMemo(() => ({
oAuthCredentials: credentials.filter(c => c.credential_type === CredentialTypeEnum.OAUTH2),
apiKeyCredentials: credentials.filter(c => c.credential_type === CredentialTypeEnum.API_KEY),
}), [credentials])
// Unavailable credentials info
const { unavailableCredentials, hasUnavailableDefault } = useMemo(() => ({
unavailableCredentials: credentials.filter(c => c.not_allowed_to_use),
hasUnavailableDefault: credentials.some(c => c.not_allowed_to_use && c.is_default),
}), [credentials])
return (
<>
@ -188,33 +133,27 @@ const Authorized = ({
onClick={() => setMergedIsOpen(!mergedIsOpen)}
asChild
>
{
renderTrigger
? renderTrigger(mergedIsOpen)
: (
<Button
className={cn(
'w-full',
isOpen && 'bg-components-button-secondary-bg-hover',
)}
>
<Indicator className="mr-2" color={unavailableCredential ? 'gray' : 'green'} />
{credentials.length}
&nbsp;
{
credentials.length > 1
? t('auth.authorizations', { ns: 'plugin' })
: t('auth.authorization', { ns: 'plugin' })
}
{
!!unavailableCredentials.length && (
` (${unavailableCredentials.length} ${t('auth.unavailable', { ns: 'plugin' })})`
)
}
<RiArrowDownSLine className="ml-0.5 h-4 w-4" />
</Button>
)
}
{renderTrigger
? renderTrigger(mergedIsOpen)
: (
<Button
className={cn(
'w-full',
isOpen && 'bg-components-button-secondary-bg-hover',
)}
>
<Indicator className="mr-2" color={hasUnavailableDefault ? 'gray' : 'green'} />
{credentials.length}
&nbsp;
{credentials.length > 1
? t('auth.authorizations', { ns: 'plugin' })
: t('auth.authorization', { ns: 'plugin' })}
{!!unavailableCredentials.length && (
` (${unavailableCredentials.length} ${t('auth.unavailable', { ns: 'plugin' })})`
)}
<RiArrowDownSLine className="ml-0.5 h-4 w-4" />
</Button>
)}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className="z-[100]">
<div className={cn(
@ -223,137 +162,72 @@ const Authorized = ({
)}
>
<div className="py-1">
{
!!extraAuthorizationItems?.length && (
<div className="p-1">
{
extraAuthorizationItems.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
onItemClick={onItemClick}
disableRename
disableEdit
disableDelete
disableSetDefault
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
))
}
</div>
)
}
{
!!oAuthCredentials.length && (
<div className="p-1">
<div className={cn(
'system-xs-medium px-3 pb-0.5 pt-1 text-text-tertiary',
showItemSelectedIcon && 'pl-7',
)}
>
OAuth
</div>
{
oAuthCredentials.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
disableEdit
onDelete={openConfirm}
onSetDefault={handleSetDefault}
onRename={handleRename}
disableSetDefault={disableSetDefault}
onItemClick={onItemClick}
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
))
}
</div>
)
}
{
!!apiKeyCredentials.length && (
<div className="p-1">
<div className={cn(
'system-xs-medium px-3 pb-0.5 pt-1 text-text-tertiary',
showItemSelectedIcon && 'pl-7',
)}
>
API Keys
</div>
{
apiKeyCredentials.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
onDelete={openConfirm}
onEdit={handleEdit}
onSetDefault={handleSetDefault}
disableSetDefault={disableSetDefault}
disableRename
onItemClick={onItemClick}
onRename={handleRename}
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
))
}
</div>
)
}
<ExtraCredentialSection
credentials={extraAuthorizationItems}
disabled={disabled}
onItemClick={onItemClick}
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
<CredentialSection
title="OAuth"
credentials={oAuthCredentials}
disabled={disabled}
disableEdit
disableSetDefault={disableSetDefault}
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
onDelete={openDeleteConfirm}
onSetDefault={handleSetDefault}
onRename={handleRename}
onItemClick={onItemClick}
/>
<CredentialSection
title="API Keys"
credentials={apiKeyCredentials}
disabled={disabled}
disableRename
disableSetDefault={disableSetDefault}
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
onDelete={openDeleteConfirm}
onEdit={openEditModal}
onSetDefault={handleSetDefault}
onRename={handleRename}
onItemClick={onItemClick}
/>
</div>
{
!notAllowCustomCredential && (
<>
<div className="h-[1px] bg-divider-subtle"></div>
<div className="p-2">
<Authorize
pluginPayload={pluginPayload}
theme="secondary"
showDivider={false}
canOAuth={canOAuth}
canApiKey={canApiKey}
disabled={disabled}
onUpdate={onUpdate}
/>
</div>
</>
)
}
{!notAllowCustomCredential && (
<>
<div className="h-[1px] bg-divider-subtle"></div>
<div className="p-2">
<Authorize
pluginPayload={pluginPayload}
theme="secondary"
showDivider={false}
canOAuth={canOAuth}
canApiKey={canApiKey}
disabled={disabled}
onUpdate={onUpdate}
/>
</div>
</>
)}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
{
deleteCredentialId && (
<Confirm
isShow
title={t('list.delete.title', { ns: 'datasetDocuments' })}
isDisabled={doingAction}
onCancel={closeConfirm}
onConfirm={handleConfirm}
/>
)
}
{
!!editValues && (
<ApiKeyModal
pluginPayload={pluginPayload}
editValues={editValues}
onClose={() => {
setEditValues(null)
pendingOperationCredentialId.current = null
}}
onRemove={handleRemove}
disabled={disabled || doingAction}
onUpdate={onUpdate}
/>
)
}
<AuthorizedModals
pluginPayload={pluginPayload}
deleteCredentialId={deleteCredentialId}
doingAction={doingAction}
onDeleteConfirm={handleDeleteConfirm}
onDeleteCancel={closeDeleteConfirm}
editValues={editValues}
disabled={disabled}
onEditClose={closeEditModal}
onRemove={handleRemoveFromEdit}
onUpdate={onUpdate}
/>
</>
)
}

View File

@ -7,6 +7,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
// ==================== Imports (after mocks) ====================
import { MCPToolAvailabilityProvider } from '@/app/components/workflow/nodes/_base/components/mcp-tool-availability'
import MultipleToolSelector from './index'
// ==================== Mock Setup ====================
@ -190,10 +191,11 @@ type RenderOptions = {
nodeOutputVars?: NodeOutPutVar[]
availableNodes?: Node[]
nodeId?: string
canChooseMCPTool?: boolean
versionSupported?: boolean
}
const renderComponent = (options: RenderOptions = {}) => {
const { versionSupported, ...overrides } = options
const defaultProps = {
disabled: false,
value: [],
@ -206,16 +208,17 @@ const renderComponent = (options: RenderOptions = {}) => {
nodeOutputVars: [createNodeOutputVar()],
availableNodes: [createNode()],
nodeId: 'test-node-id',
canChooseMCPTool: false,
}
const props = { ...defaultProps, ...options }
const props = { ...defaultProps, ...overrides }
const queryClient = createQueryClient()
return {
...render(
<QueryClientProvider client={queryClient}>
<MultipleToolSelector {...props} />
<MCPToolAvailabilityProvider versionSupported={versionSupported}>
<MultipleToolSelector {...props} />
</MCPToolAvailabilityProvider>
</QueryClientProvider>,
),
props,
@ -410,7 +413,7 @@ describe('MultipleToolSelector', () => {
expect(screen.getByText('2/3')).toBeInTheDocument()
})
it('should track enabled count with MCP tools when canChooseMCPTool is true', () => {
it('should track enabled count with MCP tools when version is supported', () => {
// Arrange
const mcpTools = [createMCPTool({ id: 'mcp-provider' })]
mockMCPToolsData.mockReturnValue(mcpTools)
@ -421,13 +424,13 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: true })
renderComponent({ value: tools, versionSupported: true })
// Assert
expect(screen.getByText('2/2')).toBeInTheDocument()
})
it('should not count MCP tools when canChooseMCPTool is false', () => {
it('should not count MCP tools when version is unsupported', () => {
// Arrange
const mcpTools = [createMCPTool({ id: 'mcp-provider' })]
mockMCPToolsData.mockReturnValue(mcpTools)
@ -438,7 +441,7 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: false })
renderComponent({ value: tools, versionSupported: false })
// Assert
expect(screen.getByText('1/2')).toBeInTheDocument()
@ -721,14 +724,6 @@ describe('MultipleToolSelector', () => {
expect(screen.getByTestId('tool-selector-add')).toBeInTheDocument()
})
it('should pass canChooseMCPTool prop correctly', () => {
// Arrange & Act
renderComponent({ canChooseMCPTool: true })
// Assert
expect(screen.getByTestId('tool-selector-add')).toBeInTheDocument()
})
it('should render with supportEnableSwitch for edit selectors', () => {
// Arrange
const tools = [createToolValue()]
@ -771,13 +766,13 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: true })
renderComponent({ value: tools, versionSupported: true })
// Assert
expect(screen.getByText('2/2')).toBeInTheDocument()
})
it('should exclude MCP tools from enabled count when canChooseMCPTool is false', () => {
it('should exclude MCP tools from enabled count when strategy version is unsupported', () => {
// Arrange
const mcpTools = [createMCPTool({ id: 'mcp-provider' })]
mockMCPToolsData.mockReturnValue(mcpTools)
@ -788,7 +783,7 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: false })
renderComponent({ value: tools, versionSupported: false })
// Assert - Only regular tool should be counted
expect(screen.getByText('1/2')).toBeInTheDocument()

View File

@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general'
import Tooltip from '@/app/components/base/tooltip'
import ToolSelector from '@/app/components/plugins/plugin-detail-panel/tool-selector'
import { useMCPToolAvailability } from '@/app/components/workflow/nodes/_base/components/mcp-tool-availability'
import { useAllMCPTools } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
@ -27,7 +28,6 @@ type Props = {
nodeOutputVars: NodeOutPutVar[]
availableNodes: Node[]
nodeId?: string
canChooseMCPTool?: boolean
}
const MultipleToolSelector = ({
@ -42,14 +42,14 @@ const MultipleToolSelector = ({
nodeOutputVars,
availableNodes,
nodeId,
canChooseMCPTool,
}: Props) => {
const { t } = useTranslation()
const { allowed: isMCPToolAllowed } = useMCPToolAvailability()
const { data: mcpTools } = useAllMCPTools()
const enabledCount = value.filter((item) => {
const isMCPTool = mcpTools?.find(tool => tool.id === item.provider_name)
if (isMCPTool)
return item.enabled && canChooseMCPTool
return item.enabled && isMCPToolAllowed
return item.enabled
}).length
// collapse control
@ -167,7 +167,6 @@ const MultipleToolSelector = ({
onSelectMultiple={handleAddMultiple}
onDelete={() => handleDelete(index)}
supportEnableSwitch
canChooseMCPTool={canChooseMCPTool}
isEdit
/>
</div>
@ -190,7 +189,6 @@ const MultipleToolSelector = ({
panelShowState={panelShowState}
onPanelShowStateChange={setPanelShowState}
isEdit={false}
canChooseMCPTool={canChooseMCPTool}
onSelectMultiple={handleAddMultiple}
/>
</>

View File

@ -64,7 +64,6 @@ type Props = {
nodeOutputVars: NodeOutPutVar[]
availableNodes: Node[]
nodeId?: string
canChooseMCPTool?: boolean
}
const ToolSelector: FC<Props> = ({
value,
@ -86,7 +85,6 @@ const ToolSelector: FC<Props> = ({
nodeOutputVars,
availableNodes,
nodeId = '',
canChooseMCPTool,
}) => {
const { t } = useTranslation()
const [isShow, onShowChange] = useState(false)
@ -267,7 +265,6 @@ const ToolSelector: FC<Props> = ({
</p>
</div>
)}
canChooseMCPTool={canChooseMCPTool}
/>
)}
</PortalToFollowElemTrigger>
@ -300,7 +297,6 @@ const ToolSelector: FC<Props> = ({
onSelectMultiple={handleSelectMultipleTool}
scope={scope}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
</div>
<div className="flex flex-col gap-1">

View File

@ -16,6 +16,7 @@ import Tooltip from '@/app/components/base/tooltip'
import { ToolTipContent } from '@/app/components/base/tooltip/content'
import Indicator from '@/app/components/header/indicator'
import { InstallPluginButton } from '@/app/components/workflow/nodes/_base/components/install-plugin-button'
import { useMCPToolAvailability } from '@/app/components/workflow/nodes/_base/components/mcp-tool-availability'
import McpToolNotSupportTooltip from '@/app/components/workflow/nodes/_base/components/mcp-tool-not-support-tooltip'
import { SwitchPluginVersion } from '@/app/components/workflow/nodes/_base/components/switch-plugin-version'
import { cn } from '@/utils/classnames'
@ -39,7 +40,6 @@ type Props = {
versionMismatch?: boolean
open: boolean
authRemoved?: boolean
canChooseMCPTool?: boolean
}
const ToolItem = ({
@ -61,13 +61,13 @@ const ToolItem = ({
errorTip,
versionMismatch,
authRemoved,
canChooseMCPTool,
}: Props) => {
const { t } = useTranslation()
const { allowed: isMCPToolAllowed } = useMCPToolAvailability()
const providerNameText = isMCPTool ? providerShowName : providerName?.split('/').pop()
const isTransparent = uninstalled || versionMismatch || isError
const [isDeleting, setIsDeleting] = useState(false)
const isShowCanNotChooseMCPTip = isMCPTool && !canChooseMCPTool
const isShowCanNotChooseMCPTip = isMCPTool && !isMCPToolAllowed
return (
<div className={cn(
@ -125,9 +125,7 @@ const ToolItem = ({
/>
</div>
)}
{isShowCanNotChooseMCPTip && (
<McpToolNotSupportTooltip />
)}
{isShowCanNotChooseMCPTip && <McpToolNotSupportTooltip />}
{!isError && !uninstalled && !versionMismatch && noAuth && (
<Button variant="secondary" size="small">
{t('notAuthorized', { ns: 'tools' })}

View File

@ -47,7 +47,6 @@ type AllToolsProps = {
canNotSelectMultiple?: boolean
onSelectMultiple?: (type: BlockEnum, tools: ToolDefaultValue[]) => void
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
onTagsChange?: Dispatch<SetStateAction<string[]>>
isInRAGPipeline?: boolean
featuredPlugins?: Plugin[]
@ -71,7 +70,6 @@ const AllTools = ({
customTools,
mcpTools = [],
selectedTools,
canChooseMCPTool,
onTagsChange,
isInRAGPipeline = false,
featuredPlugins = [],
@ -249,7 +247,6 @@ const AllTools = ({
providerMap={providerMap}
onSelect={onSelect}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
isLoading={featuredLoading}
onInstallSuccess={async () => {
await onFeaturedInstallSuccess?.()
@ -275,7 +272,6 @@ const AllTools = ({
viewType={isSupportGroupView ? activeView : ViewType.flat}
hasSearchText={hasSearchText}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
</>
)}

View File

@ -30,7 +30,6 @@ type FeaturedToolsProps = {
providerMap: Map<string, ToolWithProvider>
onSelect: (type: BlockEnum, tool: ToolDefaultValue) => void
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
isLoading?: boolean
onInstallSuccess?: () => void
}
@ -42,7 +41,6 @@ const FeaturedTools = ({
providerMap,
onSelect,
selectedTools,
canChooseMCPTool,
isLoading = false,
onInstallSuccess,
}: FeaturedToolsProps) => {
@ -166,7 +164,6 @@ const FeaturedTools = ({
viewType={ViewType.flat}
hasSearchText={false}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
)}

View File

@ -223,7 +223,6 @@ const Tabs: FC<TabsProps> = ({
customTools={customTools || []}
workflowTools={workflowTools || []}
mcpTools={mcpTools || []}
canChooseMCPTool
onTagsChange={onTagsChange}
isInRAGPipeline={inRAGPipeline}
featuredPlugins={featuredPlugins}

View File

@ -50,7 +50,6 @@ type Props = {
supportAddCustomTool?: boolean
scope?: string
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
}
const ToolPicker: FC<Props> = ({
@ -66,7 +65,6 @@ const ToolPicker: FC<Props> = ({
scope = 'all',
selectedTools,
panelClassName,
canChooseMCPTool,
}) => {
const { t } = useTranslation()
const [searchText, setSearchText] = useState('')
@ -198,7 +196,6 @@ const ToolPicker: FC<Props> = ({
workflowTools={workflowToolList || []}
mcpTools={mcpTools || []}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
onTagsChange={setTags}
featuredPlugins={featuredPlugins}
featuredLoading={isFeaturedLoading}

View File

@ -18,7 +18,6 @@ type Props = {
letters: string[]
toolRefs: any
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
}
const ToolViewFlatView: FC<Props> = ({
@ -32,7 +31,6 @@ const ToolViewFlatView: FC<Props> = ({
onSelectMultiple,
toolRefs,
selectedTools,
canChooseMCPTool,
}) => {
const firstLetterToolIds = useMemo(() => {
const res: Record<string, string> = {}
@ -63,7 +61,6 @@ const ToolViewFlatView: FC<Props> = ({
canNotSelectMultiple={canNotSelectMultiple}
onSelectMultiple={onSelectMultiple}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
</div>
))}

View File

@ -14,7 +14,6 @@ type Props = {
canNotSelectMultiple?: boolean
onSelectMultiple?: (type: BlockEnum, tools: ToolDefaultValue[]) => void
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
}
const Item: FC<Props> = ({
@ -25,7 +24,6 @@ const Item: FC<Props> = ({
canNotSelectMultiple,
onSelectMultiple,
selectedTools,
canChooseMCPTool,
}) => {
return (
<div>
@ -43,7 +41,6 @@ const Item: FC<Props> = ({
canNotSelectMultiple={canNotSelectMultiple}
onSelectMultiple={onSelectMultiple}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
))}
</div>

View File

@ -15,7 +15,6 @@ type Props = {
canNotSelectMultiple?: boolean
onSelectMultiple?: (type: BlockEnum, tools: ToolDefaultValue[]) => void
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
}
const ToolListTreeView: FC<Props> = ({
@ -25,7 +24,6 @@ const ToolListTreeView: FC<Props> = ({
canNotSelectMultiple,
onSelectMultiple,
selectedTools,
canChooseMCPTool,
}) => {
const { t } = useTranslation()
const getI18nGroupName = useCallback((name: string) => {
@ -56,7 +54,6 @@ const ToolListTreeView: FC<Props> = ({
canNotSelectMultiple={canNotSelectMultiple}
onSelectMultiple={onSelectMultiple}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
))}
</div>

View File

@ -9,6 +9,7 @@ import * as React from 'react'
import { useCallback, useEffect, useMemo, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { Mcp } from '@/app/components/base/icons/src/vender/other'
import { useMCPToolAvailability } from '@/app/components/workflow/nodes/_base/components/mcp-tool-availability'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
@ -38,7 +39,6 @@ type Props = {
canNotSelectMultiple?: boolean
onSelectMultiple?: (type: BlockEnum, tools: ToolDefaultValue[]) => void
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
isShowLetterIndex?: boolean
}
@ -51,9 +51,9 @@ const Tool: FC<Props> = ({
canNotSelectMultiple,
onSelectMultiple,
selectedTools,
canChooseMCPTool,
}) => {
const { t } = useTranslation()
const { allowed: isMCPToolAllowed } = useMCPToolAvailability()
const language = useGetLanguage()
const isFlatView = viewType === ViewType.flat
const notShowProvider = payload.type === CollectionType.workflow
@ -63,7 +63,7 @@ const Tool: FC<Props> = ({
const ref = useRef(null)
const isHovering = useHover(ref)
const isMCPTool = payload.type === CollectionType.mcp
const isShowCanNotChooseMCPTip = !canChooseMCPTool && isMCPTool
const isShowCanNotChooseMCPTip = !isMCPToolAllowed && isMCPTool
const { theme } = useTheme()
const normalizedIcon = useMemo<ToolWithProvider['icon']>(() => {
return normalizeProviderIcon(payload.icon) ?? payload.icon

View File

@ -21,7 +21,6 @@ type ToolsProps = {
className?: string
indexBarClassName?: string
selectedTools?: ToolValue[]
canChooseMCPTool?: boolean
}
const Tools = ({
onSelect,
@ -35,7 +34,6 @@ const Tools = ({
className,
indexBarClassName,
selectedTools,
canChooseMCPTool,
}: ToolsProps) => {
// const tools: any = []
const language = useGetLanguage()
@ -109,7 +107,6 @@ const Tools = ({
canNotSelectMultiple={canNotSelectMultiple}
onSelectMultiple={onSelectMultiple}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
indexBar={<IndexBar letters={letters} itemRefs={toolRefs} className={indexBarClassName} />}
/>
)
@ -121,7 +118,6 @@ const Tools = ({
canNotSelectMultiple={canNotSelectMultiple}
onSelectMultiple={onSelectMultiple}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
)
)}

View File

@ -1602,6 +1602,7 @@ export const useNodesInteractions = () => {
const offsetX = currentPosition.x - x
const offsetY = currentPosition.y - y
let idMapping: Record<string, string> = {}
const parentChildrenToAppend: { parentId: string, childId: string, childType: BlockEnum }[] = []
clipboardElements.forEach((nodeToPaste, index) => {
const nodeType = nodeToPaste.data.type
@ -1615,6 +1616,7 @@ export const useNodesInteractions = () => {
_isBundled: false,
_connectedSourceHandleIds: [],
_connectedTargetHandleIds: [],
_dimmed: false,
title: genNewNodeTitleFromOld(nodeToPaste.data.title),
},
position: {
@ -1682,27 +1684,24 @@ export const useNodesInteractions = () => {
return
// handle paste to nested block
if (selectedNode.data.type === BlockEnum.Iteration) {
newNode.data.isInIteration = true
newNode.data.iteration_id = selectedNode.data.iteration_id
newNode.parentId = selectedNode.id
newNode.positionAbsolute = {
x: newNode.position.x,
y: newNode.position.y,
}
// set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode)
}
else if (selectedNode.data.type === BlockEnum.Loop) {
newNode.data.isInLoop = true
newNode.data.loop_id = selectedNode.data.loop_id
if (selectedNode.data.type === BlockEnum.Iteration || selectedNode.data.type === BlockEnum.Loop) {
const isIteration = selectedNode.data.type === BlockEnum.Iteration
newNode.data.isInIteration = isIteration
newNode.data.iteration_id = isIteration ? selectedNode.id : undefined
newNode.data.isInLoop = !isIteration
newNode.data.loop_id = !isIteration ? selectedNode.id : undefined
newNode.parentId = selectedNode.id
newNode.zIndex = isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX
newNode.positionAbsolute = {
x: newNode.position.x,
y: newNode.position.y,
}
// set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode)
// update parent children array like native add
parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type })
}
}
}
@ -1733,7 +1732,17 @@ export const useNodesInteractions = () => {
}
})
setNodes([...nodes, ...nodesToPaste])
const newNodes = produce(nodes, (draft: Node[]) => {
parentChildrenToAppend.forEach(({ parentId, childId, childType }) => {
const p = draft.find(n => n.id === parentId)
if (p) {
p.data._children?.push({ nodeId: childId, nodeType: childType })
}
})
draft.push(...nodesToPaste)
})
setNodes(newNodes)
setEdges([...edges, ...edgesToPaste])
saveStateToHistory(WorkflowHistoryEvent.NodePaste, {
nodeId: nodesToPaste?.[0]?.id,

Some files were not shown because too many files have changed in this diff Show More