Compare commits

..

5 Commits

Author SHA1 Message Date
34a17c3ce6 fix(security): reject path traversal sequences before plugin daemon forward (GHSA-gvc6-fh3x-89xh) (#35796)
Co-authored-by: Ido Shani <ido@zafran.io>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-28 18:10:00 +08:00
18f083607b fix: normalize summary_index_setting None to fix preview error (#36626) 2026-05-28 18:07:59 +08:00
2fe8dbd7ca fix: fix cannot extract elements from a scalar (#36769) 2026-05-28 15:50:27 +08:00
80cd289e87 fix: replace .distinct() with .group_by(Conversation.id) for PostgreSQL JSON compatibility (#36610)
Co-authored-by: cocoon <kuishou68@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
(cherry picked from commit e617435d03)
2026-05-28 13:19:27 +08:00
a14bc8a371 fix: fix DocumentSegment.keywords can not a valid json (#36715) 2026-05-27 17:11:06 +08:00
6 changed files with 134 additions and 5 deletions

View File

@ -137,7 +137,7 @@ class CompletionConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
.group_by(Conversation.id)
)
elif args.annotation_status == "not_annotated":
query = (
@ -275,7 +275,7 @@ class ChatConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
.group_by(Conversation.id)
)
case "not_annotated":
query = (

View File

@ -3,7 +3,7 @@ import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy import String, case, cast, func, literal, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound
@ -159,9 +159,17 @@ class DatasetDocumentSegmentListApi(Resource):
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
# Feed the set-returning function a JSON array in every row. Filtering in
# the subquery is not enough because PostgreSQL can still evaluate the
# SRF on scalar JSON before applying the predicate.
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
keywords_array = case(
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
else_=cast(literal("[]"), JSONB),
)
keywords_condition = func.array_to_string(
func.array(
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
select(func.jsonb_array_elements_text(keywords_array))
.correlate(DocumentSegment)
.scalar_subquery()
),

View File

@ -3,6 +3,7 @@ import json
import logging
from collections.abc import Callable, Generator
from typing import Any, cast
from urllib.parse import unquote
import httpx
from pydantic import BaseModel
@ -53,6 +54,9 @@ else:
logger = logging.getLogger(__name__)
PLUGIN_DAEMON_MAX_PATH_LENGTH = 4096
PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH = 8
_httpx_client: httpx.Client = get_pooled_http_client(
"plugin_daemon",
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False),
@ -103,6 +107,20 @@ class BasePluginClient:
params: dict[str, Any] | None,
files: dict[str, Any] | None,
) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
if len(path) > PLUGIN_DAEMON_MAX_PATH_LENGTH:
raise ValueError(f"Invalid plugin daemon path: path length exceeds {PLUGIN_DAEMON_MAX_PATH_LENGTH}")
decoded_path = path
for _ in range(PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH):
next_decoded_path = unquote(decoded_path)
if next_decoded_path == decoded_path:
break
decoded_path = next_decoded_path
else:
raise ValueError("Invalid plugin daemon path: path is too deeply encoded")
if any(seg == ".." for seg in decoded_path.split("/")):
raise ValueError(f"Invalid plugin daemon path: traversal sequence detected in {path!r}")
url = plugin_daemon_inner_api_baseurl / path
prepared_headers = dict(headers or {})
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY

View File

@ -170,6 +170,16 @@ class _AutomaticProcessRule(BaseModel):
mode: Literal[ProcessRuleMode.AUTOMATIC]
summary_index_setting: _SummaryIndexSetting | None = None
@field_validator("summary_index_setting", mode="before")
@classmethod
def _normalize_summary_index_setting(cls, v: Any) -> Any:
"""Treat dicts with enable=None (or missing enable) as None (#36602)."""
if v is None:
return None
if isinstance(v, dict) and v.get("enable") is None:
return None
return v
class _CustomProcessRule(BaseModel):
model_config = ConfigDict(extra="allow")
@ -178,6 +188,16 @@ class _CustomProcessRule(BaseModel):
rules: _EstimateRules
summary_index_setting: _SummaryIndexSetting | None = None
@field_validator("summary_index_setting", mode="before")
@classmethod
def _normalize_summary_index_setting(cls, v: Any) -> Any:
"""Treat dicts with enable=None (or missing enable) as None (#36602)."""
if v is None:
return None
if isinstance(v, dict) and v.get("enable") is None:
return None
return v
class _HierarchicalProcessRule(BaseModel):
model_config = ConfigDict(extra="allow")
@ -186,6 +206,16 @@ class _HierarchicalProcessRule(BaseModel):
rules: _EstimateRules
summary_index_setting: _SummaryIndexSetting | None = None
@field_validator("summary_index_setting", mode="before")
@classmethod
def _normalize_summary_index_setting(cls, v: Any) -> Any:
"""Treat dicts with enable=None (or missing enable) as None (#36602)."""
if v is None:
return None
if isinstance(v, dict) and v.get("enable") is None:
return None
return v
_EstimateProcessRule = Annotated[
_AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule,

View File

@ -1036,6 +1036,48 @@ class TestSegmentListAdvancedCases:
assert status == 200
assert response["total"] == 1
def test_segment_list_postgres_keyword_filter_handles_scalar_keywords(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
dataset = MagicMock()
document = MagicMock()
pagination = MagicMock(items=[], total=0, pages=0)
with (
app.test_request_context("/?keyword=test"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "11111111-1111-1111-1111-111111111111"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
return_value=None,
),
patch(
"controllers.console.datasets.datasets_segments.DocumentService.get_document",
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.dify_config",
SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME="postgresql"),
),
patch(
"controllers.console.datasets.datasets_segments.db.paginate",
return_value=pagination,
) as paginate_mock,
):
method(api, "22222222-2222-2222-2222-222222222222", "33333333-3333-3333-3333-333333333333")
query = paginate_mock.call_args.kwargs["select"]
sql = str(query.compile(compile_kwargs={"literal_binds": True}))
assert "jsonb_array_elements_text(CASE" in sql
assert "ELSE CAST('[]' AS JSONB)" in sql
def test_segment_list_permission_denied(self, app: Flask):
"""Test segment list with permission denied"""
api = DatasetDocumentSegmentListApi()

View File

@ -1,11 +1,12 @@
import json
from urllib.parse import quote
import pytest
from pytest_mock import MockerFixture
from core.plugin.endpoint.exc import EndpointSetupFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.base import PLUGIN_DAEMON_MAX_PATH_LENGTH, BasePluginClient
from core.trigger.errors import (
EventIgnoreError,
TriggerInvokeError,
@ -67,6 +68,36 @@ class TestBasePluginClientImpl:
assert result == ["hello", "world"]
assert stream_mock.call_args.kwargs["data"] == {"k": "v"}
@pytest.mark.parametrize(
"path",
[
"plugin/tenant/%252e%252e%252ftarget",
"plugin/tenant/%2e%2e%252ftarget",
],
)
def test_prepare_request_rejects_encoded_traversal_with_encoded_separator(self, path: str):
client = BasePluginClient()
with pytest.raises(ValueError, match="traversal sequence detected"):
client._prepare_request(path, None, None, None, None)
def test_prepare_request_rejects_path_exceeding_max_length(self):
client = BasePluginClient()
path = "a" * (PLUGIN_DAEMON_MAX_PATH_LENGTH + 1)
with pytest.raises(ValueError, match="path length exceeds"):
client._prepare_request(path, None, None, None, None)
def test_prepare_request_rejects_excessively_encoded_path(self):
client = BasePluginClient()
segment = "..%2Ftarget"
for _ in range(9):
segment = quote(segment, safe="")
path = f"plugin/tenant/{segment}"
with pytest.raises(ValueError, match="too deeply encoded"):
client._prepare_request(path, None, None, None, None)
def test_request_with_plugin_daemon_response_handles_request_exception(self, mocker: MockerFixture):
client = BasePluginClient()
mocker.patch.object(client, "_request", side_effect=RuntimeError("boom"))