mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 21:03:22 +08:00
Compare commits
5 Commits
1.14.2
...
hotfix/1.1
| Author | SHA1 | Date | |
|---|---|---|---|
| 34a17c3ce6 | |||
| 18f083607b | |||
| 2fe8dbd7ca | |||
| 80cd289e87 | |||
| a14bc8a371 |
@ -137,7 +137,7 @@ class CompletionConversationApi(Resource):
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
.group_by(Conversation.id)
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
@ -275,7 +275,7 @@ class ChatConversationApi(Resource):
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
.group_by(Conversation.id)
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
|
||||
@ -3,7 +3,7 @@ import uuid
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy import String, case, cast, func, literal, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -159,9 +159,17 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
# Feed the set-returning function a JSON array in every row. Filtering in
|
||||
# the subquery is not enough because PostgreSQL can still evaluate the
|
||||
# SRF on scalar JSON before applying the predicate.
|
||||
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
|
||||
keywords_array = case(
|
||||
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
|
||||
else_=cast(literal("[]"), JSONB),
|
||||
)
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
select(func.jsonb_array_elements_text(keywords_array))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
|
||||
@ -3,6 +3,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
@ -53,6 +54,9 @@ else:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PLUGIN_DAEMON_MAX_PATH_LENGTH = 4096
|
||||
PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH = 8
|
||||
|
||||
_httpx_client: httpx.Client = get_pooled_http_client(
|
||||
"plugin_daemon",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False),
|
||||
@ -103,6 +107,20 @@ class BasePluginClient:
|
||||
params: dict[str, Any] | None,
|
||||
files: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
|
||||
if len(path) > PLUGIN_DAEMON_MAX_PATH_LENGTH:
|
||||
raise ValueError(f"Invalid plugin daemon path: path length exceeds {PLUGIN_DAEMON_MAX_PATH_LENGTH}")
|
||||
|
||||
decoded_path = path
|
||||
for _ in range(PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH):
|
||||
next_decoded_path = unquote(decoded_path)
|
||||
if next_decoded_path == decoded_path:
|
||||
break
|
||||
decoded_path = next_decoded_path
|
||||
else:
|
||||
raise ValueError("Invalid plugin daemon path: path is too deeply encoded")
|
||||
|
||||
if any(seg == ".." for seg in decoded_path.split("/")):
|
||||
raise ValueError(f"Invalid plugin daemon path: traversal sequence detected in {path!r}")
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
prepared_headers = dict(headers or {})
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
|
||||
@ -170,6 +170,16 @@ class _AutomaticProcessRule(BaseModel):
|
||||
mode: Literal[ProcessRuleMode.AUTOMATIC]
|
||||
summary_index_setting: _SummaryIndexSetting | None = None
|
||||
|
||||
@field_validator("summary_index_setting", mode="before")
|
||||
@classmethod
|
||||
def _normalize_summary_index_setting(cls, v: Any) -> Any:
|
||||
"""Treat dicts with enable=None (or missing enable) as None (#36602)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, dict) and v.get("enable") is None:
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
class _CustomProcessRule(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
@ -178,6 +188,16 @@ class _CustomProcessRule(BaseModel):
|
||||
rules: _EstimateRules
|
||||
summary_index_setting: _SummaryIndexSetting | None = None
|
||||
|
||||
@field_validator("summary_index_setting", mode="before")
|
||||
@classmethod
|
||||
def _normalize_summary_index_setting(cls, v: Any) -> Any:
|
||||
"""Treat dicts with enable=None (or missing enable) as None (#36602)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, dict) and v.get("enable") is None:
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
class _HierarchicalProcessRule(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
@ -186,6 +206,16 @@ class _HierarchicalProcessRule(BaseModel):
|
||||
rules: _EstimateRules
|
||||
summary_index_setting: _SummaryIndexSetting | None = None
|
||||
|
||||
@field_validator("summary_index_setting", mode="before")
|
||||
@classmethod
|
||||
def _normalize_summary_index_setting(cls, v: Any) -> Any:
|
||||
"""Treat dicts with enable=None (or missing enable) as None (#36602)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, dict) and v.get("enable") is None:
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
_EstimateProcessRule = Annotated[
|
||||
_AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule,
|
||||
|
||||
@ -1036,6 +1036,48 @@ class TestSegmentListAdvancedCases:
|
||||
assert status == 200
|
||||
assert response["total"] == 1
|
||||
|
||||
def test_segment_list_postgres_keyword_filter_handles_scalar_keywords(self, app: Flask):
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
dataset = MagicMock()
|
||||
document = MagicMock()
|
||||
pagination = MagicMock(items=[], total=0, pages=0)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?keyword=test"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "11111111-1111-1111-1111-111111111111"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DocumentService.get_document",
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.dify_config",
|
||||
SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME="postgresql"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.paginate",
|
||||
return_value=pagination,
|
||||
) as paginate_mock,
|
||||
):
|
||||
method(api, "22222222-2222-2222-2222-222222222222", "33333333-3333-3333-3333-333333333333")
|
||||
|
||||
query = paginate_mock.call_args.kwargs["select"]
|
||||
sql = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "jsonb_array_elements_text(CASE" in sql
|
||||
assert "ELSE CAST('[]' AS JSONB)" in sql
|
||||
|
||||
def test_segment_list_permission_denied(self, app: Flask):
|
||||
"""Test segment list with permission denied"""
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import json
|
||||
from urllib.parse import quote
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.plugin.endpoint.exc import EndpointSetupFailedError
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.impl.base import PLUGIN_DAEMON_MAX_PATH_LENGTH, BasePluginClient
|
||||
from core.trigger.errors import (
|
||||
EventIgnoreError,
|
||||
TriggerInvokeError,
|
||||
@ -67,6 +68,36 @@ class TestBasePluginClientImpl:
|
||||
assert result == ["hello", "world"]
|
||||
assert stream_mock.call_args.kwargs["data"] == {"k": "v"}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"plugin/tenant/%252e%252e%252ftarget",
|
||||
"plugin/tenant/%2e%2e%252ftarget",
|
||||
],
|
||||
)
|
||||
def test_prepare_request_rejects_encoded_traversal_with_encoded_separator(self, path: str):
|
||||
client = BasePluginClient()
|
||||
|
||||
with pytest.raises(ValueError, match="traversal sequence detected"):
|
||||
client._prepare_request(path, None, None, None, None)
|
||||
|
||||
def test_prepare_request_rejects_path_exceeding_max_length(self):
|
||||
client = BasePluginClient()
|
||||
path = "a" * (PLUGIN_DAEMON_MAX_PATH_LENGTH + 1)
|
||||
|
||||
with pytest.raises(ValueError, match="path length exceeds"):
|
||||
client._prepare_request(path, None, None, None, None)
|
||||
|
||||
def test_prepare_request_rejects_excessively_encoded_path(self):
|
||||
client = BasePluginClient()
|
||||
segment = "..%2Ftarget"
|
||||
for _ in range(9):
|
||||
segment = quote(segment, safe="")
|
||||
path = f"plugin/tenant/{segment}"
|
||||
|
||||
with pytest.raises(ValueError, match="too deeply encoded"):
|
||||
client._prepare_request(path, None, None, None, None)
|
||||
|
||||
def test_request_with_plugin_daemon_response_handles_request_exception(self, mocker: MockerFixture):
|
||||
client = BasePluginClient()
|
||||
mocker.patch.object(client, "_request", side_effect=RuntimeError("boom"))
|
||||
|
||||
Reference in New Issue
Block a user