From 2fe8dbd7ca77cc1d429d5c1d06c1124b0ee95189 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 28 May 2026 15:31:36 +0800 Subject: [PATCH] fix: fix cannot extract elements from a scalar (#36769) --- .../console/datasets/datasets_segments.py | 15 ++++--- .../datasets/test_datasets_segments.py | 42 +++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index df359938de..e98878db4e 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,7 +3,7 @@ import uuid from flask import request from flask_restx import Resource, marshal from pydantic import BaseModel, Field -from sqlalchemy import String, cast, func, or_, select +from sqlalchemy import String, case, cast, func, literal, or_, select from sqlalchemy.dialects.postgresql import JSONB from werkzeug.exceptions import Forbidden, NotFound @@ -159,12 +159,17 @@ class DatasetDocumentSegmentListApi(Resource): # Use database-specific methods for JSON array search if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text - # Guard with jsonb_typeof to avoid "cannot extract elements from a scalar" error - # when keywords is null or a non-array JSON value. + # Feed the set-returning function a JSON array in every row. Filtering in + # the subquery is not enough because PostgreSQL can still evaluate the + # SRF on scalar JSON before applying the predicate. + keywords_jsonb = cast(DocumentSegment.keywords, JSONB) + keywords_array = case( + (func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb), + else_=cast(literal("[]"), JSONB), + ) keywords_condition = func.array_to_string( func.array( - select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB))) - .where(func.jsonb_typeof(cast(DocumentSegment.keywords, JSONB)) == "array") + select(func.jsonb_array_elements_text(keywords_array)) .correlate(DocumentSegment) .scalar_subquery() ), diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 66d257ee66..c89ac69b68 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -1036,6 +1036,48 @@ class TestSegmentListAdvancedCases: assert status == 200 assert response["total"] == 1 + def test_segment_list_postgres_keyword_filter_handles_scalar_keywords(self, app: Flask): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + document = MagicMock() + pagination = MagicMock(items=[], total=0, pages=0) + + with ( + app.test_request_context("/?keyword=test"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "11111111-1111-1111-1111-111111111111"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.dify_config", + SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME="postgresql"), + ), + patch( + "controllers.console.datasets.datasets_segments.db.paginate", + return_value=pagination, + ) as paginate_mock, + ): + method(api, "22222222-2222-2222-2222-222222222222", "33333333-3333-3333-3333-333333333333") + + query = paginate_mock.call_args.kwargs["select"] + sql = str(query.compile(compile_kwargs={"literal_binds": True})) + assert "jsonb_array_elements_text(CASE" in sql + assert "ELSE CAST('[]' AS JSONB)" in sql + def test_segment_list_permission_denied(self, app: Flask): """Test segment list with permission denied""" api = DatasetDocumentSegmentListApi()