Files
ragflow/rag/utils/ob_conn.py
orbisai0security e992fe39b2 fix: the oceanbase database connector constructs sql... in ob_conn.py (#14470)
## Summary
Fix critical severity security issue in `rag/utils/ob_conn.py`.

## Vulnerability
| Field | Value |
|-------|-------|
| **ID** | V-003 |
| **Severity** | CRITICAL |
| **Scanner** | multi_agent_ai |
| **Rule** | `V-003` |
| **File** | `rag/utils/ob_conn.py:691` |

**Description**: The OceanBase database connector constructs SQL WHERE
clauses by directly embedding user-controlled filter expressions using
Python f-strings at lines 726, 777, 781, 787, 793, 821, and 827. No
parameterization or allowlist validation is applied before the
expressions are incorporated into live SQL queries. This is the most
critical vulnerability in the codebase because it directly exposes the
RAG knowledge base — the platform's core business asset — to complete
compromise.

## Changes
- `rag/utils/ob_conn.py`

## Verification
- [x] Build passes
- [x] Scanner re-scan confirms fix
- [x] LLM code review passed

---
*Automated security fix by [OrbisAI Security](https://orbisappsec.com)*
2026-04-30 14:25:17 +08:00

1466 lines
62 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import re
import time
from typing import Any, Optional
import numpy as np
from elasticsearch_dsl import Q, Search
from pydantic import BaseModel
from pymysql.converters import escape_string
from pyobvector import ARRAY
from sqlalchemy import Column, String, Integer, JSON, Double, Row
from sqlalchemy.dialects.mysql import LONGTEXT, TEXT
from sqlalchemy.sql.type_api import TypeEngine
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
from common.doc_store.ob_conn_base import (
OBConnectionBase, get_value_str,
vector_search_template, vector_column_pattern,
fulltext_index_name_template, doc_meta_column_names,
doc_meta_column_types,
)
from common.float_utils import get_float
from rag.nlp import rag_tokenizer
logger = logging.getLogger('ragflow.ob_conn')
# These four columns are built as standalone objects because each one is
# referenced twice below: once inside `column_definitions` and once inside
# `EXTRA_COLUMNS`.
column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk order id for maintaining sequence")
column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval")
column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chunk id")
column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data")
# Schema of the chunk table, one Column per stored field.
# `column_names`, `column_types` and `array_columns` are derived from this list,
# and `search` expands a "*" selection to `column_names`, so the order here is
# the order in which fields are selected.
column_definitions: list[Column] = [
Column("id", String(256), primary_key=True, comment="chunk id"),
Column("kb_id", String(256), nullable=False, index=True, comment="knowledge base id"),
Column("doc_id", String(256), nullable=True, index=True, comment="document id"),
Column("docnm_kwd", String(256), nullable=True, comment="document name"),
Column("doc_type_kwd", String(256), nullable=True, comment="document type"),
Column("title_tks", String(256), nullable=True, comment="title tokens"),
Column("title_sm_tks", String(256), nullable=True, comment="fine-grained (small) title tokens"),
Column("content_with_weight", LONGTEXT, nullable=True, comment="the original content"),
Column("content_ltks", LONGTEXT, nullable=True, comment="long text tokens derived from content_with_weight"),
Column("content_sm_ltks", LONGTEXT, nullable=True, comment="fine-grained (small) tokens derived from content_ltks"),
Column("pagerank_fea", Integer, nullable=True, comment="page rank priority, usually set in kb level"),
Column("important_kwd", ARRAY(String(256)), nullable=True, comment="keywords"),
Column("important_tks", TEXT, nullable=True, comment="keyword tokens"),
Column("question_kwd", ARRAY(String(1024)), nullable=True, comment="questions"),
Column("question_tks", TEXT, nullable=True, comment="question tokens"),
Column("tag_kwd", ARRAY(String(256)), nullable=True, comment="tags"),
Column("tag_feas", JSON, nullable=True,
comment="tag features used for 'rank_feature', format: [tag -> relevance score]"),
Column("available_int", Integer, nullable=False, index=True, server_default="1",
comment="status of availability, 0 for unavailable, 1 for available"),
Column("create_time", String(19), nullable=True, comment="creation time in YYYY-MM-DD HH:MM:SS format"),
Column("create_timestamp_flt", Double, nullable=True, comment="creation timestamp in float format"),
Column("img_id", String(128), nullable=True, comment="image id"),
Column("position_int", ARRAY(ARRAY(Integer)), nullable=True, comment="position"),
Column("page_num_int", ARRAY(Integer), nullable=True, comment="page number"),
Column("top_int", ARRAY(Integer), nullable=True, comment="rank from the top"),
Column("knowledge_graph_kwd", String(256), nullable=True, index=True, comment="knowledge graph chunk type"),
Column("source_id", ARRAY(String(256)), nullable=True, comment="source document id"),
Column("entity_kwd", String(256), nullable=True, comment="entity name"),
Column("entity_type_kwd", String(256), nullable=True, index=True, comment="entity type"),
Column("from_entity_kwd", String(256), nullable=True, comment="the source entity of this edge"),
Column("to_entity_kwd", String(256), nullable=True, comment="the target entity of this edge"),
Column("weight_int", Integer, nullable=True, comment="the weight of this edge"),
Column("weight_flt", Double, nullable=True, comment="the weight of community report"),
Column("entities_kwd", ARRAY(String(256)), nullable=True, comment="node ids of entities"),
Column("rank_flt", Double, nullable=True, comment="rank of this entity"),
Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'",
comment="whether it has been deleted"),
# The remaining standalone Column objects are shared with EXTRA_COLUMNS.
column_chunk_data,
Column("metadata", JSON, nullable=True, comment="metadata for this chunk"),
Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"),
column_order_id,
column_group_id,
column_mom_id,
]
# Lookup structures derived once from `column_definitions`.
# Names in declaration order.
column_names: list[str] = [definition.name for definition in column_definitions]
# Column name -> SQLAlchemy type instance, used when coercing fetched values.
column_types: dict[str, TypeEngine] = {
    definition.name: definition.type for definition in column_definitions
}
# Names of the columns declared with an ARRAY type.
array_columns: list[str] = [
    definition.name for definition in column_definitions if isinstance(definition.type, ARRAY)
]
# Index columns for the RAG chunk table (returned by OBConnection.get_index_columns).
INDEX_COLUMNS: list[str] = [
"kb_id",
"doc_id",
"available_int",
"knowledge_graph_kwd",
"entity_type_kwd",
"removed_kwd",
]
# Full-text search columns over the original content, selected when
# `search_original_content` is truthy. Entries use the "column^weight" form;
# an entry without "^" carries no explicit weight suffix.
FTS_COLUMNS_ORIGIN: list[str] = [
"docnm_kwd^10",
"content_with_weight",
"important_tks^20",
"question_tks^20",
]
# Full-text search columns over the tokenized content, selected otherwise.
# Same "column^weight" notation as above.
FTS_COLUMNS_TKS: list[str] = [
"title_tks^10",
"title_sm_tks^5",
"important_tks^20",
"question_tks^20",
"content_ltks^2",
"content_sm_ltks",
]
# Extra columns to add after table creation (for migration)
EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id, column_chunk_data]
class SearchResult(BaseModel):
    """Container that `OBConnection.search` fills in and returns."""
    # Running match count; `search` adds each table's count to it.
    total: int
    # Converted result rows, appended in the order they are fetched.
    chunks: list[dict]
def get_column_value(column_name: str, value: Any) -> Any:
    """
    Coerce a raw value into the Python type expected for the given column.

    The column type is looked up in the chunk table schema first and in the
    doc_meta schema second. Columns absent from both are accepted only when
    they match the vector column pattern or are the synthetic "_score" column.

    Args:
        column_name: name of the column the value belongs to
        value: raw value to convert
    Returns:
        str / int / float for String / Integer / Double columns; for ARRAY, JSON
        and vector columns a string is JSON-decoded when possible and anything
        else is returned unchanged.
    Raises:
        ValueError: if the column is unknown or its declared type is unsupported
    """

    def _decode_if_json_text(raw: Any) -> Any:
        # Only strings are parsed; text that is not valid JSON is handed back as-is.
        if not isinstance(raw, str):
            return raw
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return raw

    # Check chunk table columns first, then doc_meta table columns
    column_type = column_types.get(column_name) or doc_meta_column_types.get(column_name)

    if not column_type:
        # Not a declared column: only vector columns and "_score" are allowed.
        if vector_column_pattern.match(column_name):
            return _decode_if_json_text(value)
        if column_name == "_score":
            return float(value)
        raise ValueError(f"Unknown column '{column_name}' with value '{value}'.")

    if isinstance(column_type, String):
        return str(value)
    if isinstance(column_type, Integer):
        return int(value)
    if isinstance(column_type, Double):
        return float(value)
    if isinstance(column_type, (ARRAY, JSON)):
        return _decode_if_json_text(value)
    raise ValueError(f"Unsupported column type for column '{column_name}': {column_type}")
def get_default_value(column_name: str) -> Any:
    """
    Return the default value for a column.

    Args:
        column_name: name of the column
    Returns:
        1 for "available_int", "N" for "removed_kwd", 0 for "_order_id",
        and None for every other column.
    """
    defaults = {
        "available_int": 1,
        "removed_kwd": "N",
        "_order_id": 0,
    }
    return defaults.get(column_name)
def get_metadata_filter_expression(metadata_filtering_conditions: dict) -> str:
    """
    Convert metadata filtering conditions to a MySQL JSON filter expression.

    Each condition is a dict with 'name' (metadata key), 'comparison_operator'
    and 'value'. Conditions missing a name or an operator are skipped, and
    conditions with an unsupported operator are skipped with a warning.

    Supported operators:
        text:   is, is not, contains, not contains, start with, end with,
                empty, not empty
        number: =, >, <, and the not-equal / greater-or-equal / less-or-equal
                signs (U+2260, U+2265, U+2264) or their ASCII spellings
                !=, >=, <=
        time:   before, after

    Args:
        metadata_filtering_conditions: dict with 'conditions' and 'logical_operator' keys
    Returns:
        A parenthesised SQL boolean expression whose parts are joined by the
        logical operator, or "" when there is nothing to filter on.
    Raises:
        ValueError: if the logical operator is neither 'and' nor 'or'
    """
    if not metadata_filtering_conditions:
        return ""
    conditions = metadata_filtering_conditions.get("conditions", [])
    logical_operator = metadata_filtering_conditions.get("logical_operator", "and").upper()
    if not conditions:
        return ""
    if logical_operator not in ["AND", "OR"]:
        raise ValueError(f"Unsupported logical operator: {logical_operator}. Only 'and' and 'or' are supported.")

    metadata_filters = []
    for condition in conditions:
        name = condition.get("name")
        comparison_operator = condition.get("comparison_operator")
        value = condition.get("value")
        if not all([name, comparison_operator]):
            continue

        # The key lands inside a single-quoted SQL string literal, so it is escaped
        # to keep a quote or backslash in the key from terminating that literal.
        expr = f"JSON_EXTRACT(metadata, '$.{escape_string(str(name))}')"
        value_str = get_value_str(value)

        # Text operators
        if comparison_operator == "is":
            # JSON_EXTRACT(metadata, '$.field_name') = 'value'
            metadata_filters.append(f"{expr} = {value_str}")
        elif comparison_operator == "is not":
            metadata_filters.append(f"{expr} != {value_str}")
        elif comparison_operator == "contains":
            metadata_filters.append(f"JSON_CONTAINS({expr}, {value_str})")
        elif comparison_operator == "not contains":
            metadata_filters.append(f"NOT JSON_CONTAINS({expr}, {value_str})")
        elif comparison_operator == "start with":
            metadata_filters.append(f"{expr} LIKE CONCAT({value_str}, '%')")
        elif comparison_operator == "end with":
            metadata_filters.append(f"{expr} LIKE CONCAT('%', {value_str})")
        elif comparison_operator == "empty":
            metadata_filters.append(f"({expr} IS NULL OR {expr} = '' OR {expr} = '[]' OR {expr} = '{{}}')")
        elif comparison_operator == "not empty":
            metadata_filters.append(f"({expr} IS NOT NULL AND {expr} != '' AND {expr} != '[]' AND {expr} != '{{}}')")
        # Number operators. The non-ASCII operator signs are written as escapes so
        # they survive tools that drop non-ASCII characters; an empty-string operator
        # could never match because the guard above rejects it.
        elif comparison_operator == "=":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) = {value_str}")
        elif comparison_operator in ("\u2260", "!="):
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) != {value_str}")
        elif comparison_operator == ">":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) > {value_str}")
        elif comparison_operator == "<":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) < {value_str}")
        elif comparison_operator in ("\u2265", ">="):
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) >= {value_str}")
        elif comparison_operator in ("\u2264", "<="):
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) <= {value_str}")
        # Time operators
        elif comparison_operator == "before":
            metadata_filters.append(f"CAST({expr} AS DATETIME) < {value_str}")
        elif comparison_operator == "after":
            metadata_filters.append(f"CAST({expr} AS DATETIME) > {value_str}")
        else:
            logger.warning(f"Unsupported comparison operator: {comparison_operator}")
            continue

    if not metadata_filters:
        return ""
    return f"({f' {logical_operator} '.join(metadata_filters)})"
# Allowlist of column names that `get_filters` may interpolate into SQL text:
# every chunk-table column plus every doc_meta column.
_VALID_FILTER_COLUMNS: set[str] = set(column_names) | set(doc_meta_column_names)
def get_filters(condition: dict) -> list[str]:
    """
    Translate a condition dict into a list of SQL boolean expressions.

    Recognised keys:
        "exists": column name -> "<col> IS NOT NULL"
        "must_not": {"exists": column name} -> "<col> IS NULL"
        "metadata_filtering_conditions": passed to get_metadata_filter_expression
        an array column: array_contains(...) terms, OR-ed together for a list value
        any other allowlisted column: "<col> = value", or "<col> IN (...)" for a list

    Column names are emitted only when they are array columns or members of
    `_VALID_FILTER_COLUMNS`; other keys are ignored. Entries with an empty
    value (None, "", [], {}) produce no filter.

    Args:
        condition: mapping of filter key to filter value
    Returns:
        list of SQL expressions, one per effective filter, in iteration order
    """
    filters: list[str] = []
    for k, v in condition.items():
        # `available_int = 0` (unavailable chunks) is a real filter. It has to be
        # handled before the falsy-value skip below, which would otherwise drop it
        # and return chunks of every availability.
        if k == "available_int" and isinstance(v, (int, float)) and v == 0:
            filters.append("available_int = 0")
            continue
        if not v:
            continue
        if k == "exists":
            if isinstance(v, str) and v in _VALID_FILTER_COLUMNS:
                filters.append(f"{v} IS NOT NULL")
        elif k == "must_not" and isinstance(v, dict) and "exists" in v:
            col = v.get("exists")
            if isinstance(col, str) and col in _VALID_FILTER_COLUMNS:
                filters.append(f"{col} IS NULL")
        elif k == "metadata_filtering_conditions":
            # Handle metadata filtering conditions
            metadata_filter = get_metadata_filter_expression(v)
            if metadata_filter:
                filters.append(metadata_filter)
        elif k in array_columns:
            if isinstance(v, list):
                # A row matches when the array holds any one of the requested values.
                array_filter = " OR ".join(f"array_contains({k}, {get_value_str(vv)})" for vv in v)
                filters.append(f"({array_filter})")
            else:
                filters.append(f"array_contains({k}, {get_value_str(v)})")
        elif k in _VALID_FILTER_COLUMNS:
            if isinstance(v, list):
                value = ", ".join(get_value_str(item) for item in v)
                filters.append(f"{k} IN ({value})")
            else:
                filters.append(f"{k} = {get_value_str(v)}")
    return filters
@singleton
class OBConnection(OBConnectionBase):
def __init__(self):
    """Initialise the base connection and choose the full-text column set."""
    super().__init__(logger_name='ragflow.ob_conn')
    # Full-text search runs either over the original content columns or over
    # the tokenized ones, depending on `search_original_content`.
    if self.search_original_content:
        self._fulltext_search_columns = FTS_COLUMNS_ORIGIN
    else:
        self._fulltext_search_columns = FTS_COLUMNS_TKS
"""
Template method implementations
"""
def get_index_columns(self) -> list[str]:
    """Return the module-level `INDEX_COLUMNS` list (the list object itself, not a copy)."""
    return INDEX_COLUMNS
def get_column_definitions(self) -> list[Column]:
    """Return the module-level `column_definitions` list describing the chunk table."""
    return column_definitions
def get_extra_columns(self) -> list[Column]:
    """Return `EXTRA_COLUMNS`, the columns to add after table creation (for migration)."""
    return EXTRA_COLUMNS
def get_lock_prefix(self) -> str:
    """Return the fixed lock-name prefix "ob_" for this connector."""
    # NOTE(review): presumably consumed by OBConnectionBase when naming locks — confirm.
    return "ob_"
def _get_filters(self, condition: dict) -> list[str]:
    """Build SQL filter expressions for `condition` by delegating to the module-level `get_filters`."""
    return get_filters(condition)
def get_fulltext_columns(self) -> list[str]:
    """Return the names of the columns that need fulltext indexes, with any "^weight" suffix removed."""
    bare_names: list[str] = []
    for weighted_column in self._fulltext_search_columns:
        # Everything before the first "^" is the column name.
        name, _, _ = weighted_column.partition("^")
        bare_names.append(name)
    return bare_names
def delete_idx(self, index_name: str, dataset_id: str):
    """
    Delete an index, unless the request targets a single dataset.

    Args:
        index_name: name of the index to delete
        dataset_id: dataset (knowledge base) id; when truthy nothing is deleted
    """
    # All knowledge bases of a tenant live in one index, so the index must
    # stay alive after the deletion of any single knowledge base.
    if not dataset_id:
        super().delete_idx(index_name, dataset_id)
"""
Performance monitoring
"""
def get_performance_metrics(self) -> dict:
    """
    Get comprehensive performance metrics for OceanBase.

    A `SELECT 1` round trip is timed first; if it fails the connection is
    reported as disconnected and the remaining defaults are returned. Each
    further metric is collected independently, so one failing probe only
    leaves its own default in place.

    Returns:
        dict: "connection", "latency_ms", "storage_used", "storage_total",
        "query_per_second", "slow_queries", "active_connections",
        "max_connections", plus whatever keys the storage and pool probes
        return, and "error" when the connectivity check failed.
    """
    metrics = {
        "connection": "connected",
        "latency_ms": 0.0,
        "storage_used": "0B",
        "storage_total": "0B",
        "query_per_second": 0,
        "slow_queries": 0,
        "active_connections": 0,
        "max_connections": 0
    }
    try:
        # Measure connection latency with a monotonic clock: wall-clock time can
        # jump when the system clock is adjusted and distort the interval.
        start_time = time.perf_counter()
        self.client.perform_raw_text_sql("SELECT 1").fetchone()
        metrics["latency_ms"] = round((time.perf_counter() - start_time) * 1000, 2)

        # Get storage information
        try:
            storage_info = self._get_storage_info()
            metrics.update(storage_info)
        except Exception as e:
            logger.warning("Failed to get storage info: %s", e)

        # Get connection pool statistics
        try:
            pool_stats = self._get_connection_pool_stats()
            metrics.update(pool_stats)
        except Exception as e:
            logger.warning("Failed to get connection pool stats: %s", e)

        # Get slow query statistics
        try:
            metrics["slow_queries"] = self._get_slow_query_count()
        except Exception as e:
            logger.warning("Failed to get slow query count: %s", e)

        # Get QPS (Queries Per Second) - approximate from processlist
        try:
            metrics["query_per_second"] = self._estimate_qps()
        except Exception as e:
            logger.warning("Failed to estimate QPS: %s", e)
    except Exception as e:
        metrics["connection"] = "disconnected"
        metrics["error"] = str(e)
        logger.error("Failed to get OceanBase performance metrics: %s", e)
    return metrics
def _get_storage_info(self) -> dict:
    """
    Get storage space usage information.

    Returns:
        dict: "storage_used" (MB, two decimals) and "storage_total" (GB, two
        decimals, or "N/A" when no total is reported); both are "N/A" when the
        usage query itself fails.
    """
    try:
        # Size of the current schema: data plus indexes, in MB.
        used_row = self.client.perform_raw_text_sql(
            f"SELECT ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS 'size_mb' "
            f"FROM information_schema.tables WHERE table_schema = '{self.db_name}'"
        ).fetchone()
        if used_row and used_row[0]:
            used_mb = float(used_row[0])
        else:
            used_mb = 0.0

        # Total disk size may not be queryable in all OceanBase versions.
        try:
            total_row = self.client.perform_raw_text_sql(
                "SELECT ROUND(SUM(total_size) / 1024 / 1024 / 1024, 2) AS 'total_gb' "
                "FROM oceanbase.__all_disk_stat"
            ).fetchone()
            if total_row and total_row[0]:
                total_gb = float(total_row[0])
            else:
                total_gb = None
        except Exception:
            # Fallback: assume 100GB when the disk statistics cannot be read.
            total_gb = 100.0

        if total_gb:
            total_text = f"{total_gb:.2f}GB"
        else:
            total_text = "N/A"
        return {
            "storage_used": f"{used_mb:.2f}MB",
            "storage_total": total_text
        }
    except Exception as e:
        logger.warning(f"Failed to get storage info: {str(e)}")
        return {
            "storage_used": "N/A",
            "storage_total": "N/A"
        }
def _get_connection_pool_stats(self) -> dict:
    """
    Get connection pool statistics.

    Returns:
        dict: "active_connections" (rows of SHOW PROCESSLIST), "max_connections"
        (the server variable, or the client pool size when that is not positive)
        and "pool_size"; all three are 0 when a query fails.
    """
    try:
        # Every processlist row counts as one active connection.
        processlist = self.client.perform_raw_text_sql("SHOW PROCESSLIST")
        active = sum(1 for _ in processlist.fetchall())

        # The max_connections value is read from the second field of the row.
        variable_row = self.client.perform_raw_text_sql(
            "SHOW VARIABLES LIKE 'max_connections'"
        ).fetchone()
        server_limit = 0
        if variable_row and variable_row[1]:
            server_limit = int(variable_row[1])

        # Pool size is optional on the client; treat a missing or empty value as 0.
        pool_size = getattr(self.client, 'pool_size', None) or 0

        effective_limit = server_limit if server_limit > 0 else pool_size
        return {
            "active_connections": active,
            "max_connections": effective_limit,
            "pool_size": pool_size
        }
    except Exception as e:
        logger.warning(f"Failed to get connection pool stats: {str(e)}")
        return {
            "active_connections": 0,
            "max_connections": 0,
            "pool_size": 0
        }
def _get_slow_query_count(self, threshold_seconds: int = 1) -> int:
    """
    Get count of slow queries (queries taking longer than threshold).

    Args:
        threshold_seconds: Threshold in seconds for slow queries (default: 1)
    Returns:
        int: Number of non-Sleep processlist entries whose time exceeds the
        threshold; 0 when there are none or the query fails.
    """
    try:
        count_row = self.client.perform_raw_text_sql(
            f"SELECT COUNT(*) FROM information_schema.processlist "
            f"WHERE time > {threshold_seconds} AND command != 'Sleep'"
        ).fetchone()
        if count_row and count_row[0]:
            return int(count_row[0])
        return 0
    except Exception as e:
        logger.warning(f"Failed to get slow query count: {str(e)}")
        return 0
def _estimate_qps(self) -> int:
    """
    Estimate queries per second from processlist.

    Returns:
        int: Ten times the number of non-Sleep processlist entries; 0 when
        there are none or the query fails.
    """
    try:
        # Count active queries (non-Sleep commands)
        count_row = self.client.perform_raw_text_sql(
            "SELECT COUNT(*) FROM information_schema.processlist WHERE command != 'Sleep'"
        ).fetchone()
        if count_row and count_row[0]:
            running = int(count_row[0])
        else:
            running = 0
        # Simplified estimate: assume an average query takes 0.1 seconds,
        # i.e. each running query stands for 10 queries per second.
        return max(0, running * 10)
    except Exception as e:
        logger.warning(f"Failed to estimate QPS: {str(e)}")
        return 0
"""
CRUD operations
"""
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] = [],
rank_feature: dict | None = None,
**kwargs,
):
if isinstance(index_names, str):
index_names = index_names.split(",")
if not (isinstance(index_names, list) and len(index_names) > 0):
raise ValueError("index_names must be a non-empty list")
index_names = list(set(index_names))
if len(match_expressions) == 3:
if not self.enable_fulltext_search:
# disable fulltext search in fusion search, which means fallback to vector search
match_expressions = [m for m in match_expressions if isinstance(m, MatchDenseExpr)]
else:
for m in match_expressions:
if isinstance(m, FusionExpr):
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
# skip the search if its weight is zero
if vector_similarity_weight <= 0.0:
match_expressions = [m for m in match_expressions if isinstance(m, MatchTextExpr)]
elif vector_similarity_weight >= 1.0:
match_expressions = [m for m in match_expressions if isinstance(m, MatchDenseExpr)]
result: SearchResult = SearchResult(
total=0,
chunks=[],
)
# copied from es_conn.py
if len(match_expressions) == 3 and self.es:
bqry = Q("bool", must=[])
condition["kb_id"] = knowledgebase_ids
for k, v in condition.items():
if k == "available_int":
if v == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
if not (len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
match_expressions[1], MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)):
raise ValueError("match_expressions must contain MatchTextExpr, MatchDenseExpr, and FusionExpr")
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bqry.must.append(Q("query_string", fields=FTS_COLUMNS_TKS,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
if bqry is None:
raise ValueError("bqry must not be None")
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(m.vector_column_name,
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
similarity=similarity,
)
if bqry and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
s = s.query(bqry)
# for field in highlightFields:
# s = s.highlight(field)
if order_by:
orders = list()
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}
elif field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"OBConnection.hybrid_search {str(index_names)} query: " + json.dumps(q))
for index_name in index_names:
start_time = time.time()
res = self.es.search(index=index_name,
body=q,
timeout="600s",
track_total_hits=True,
_source=True)
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: hybrid, elapsed time: {elapsed_time:.3f} seconds,"
f" got count: {len(res)}"
)
for chunk in res:
result.chunks.append(self._es_row_to_entity(chunk))
result.total = result.total + 1
return result
output_fields = select_fields.copy()
if "*" in output_fields:
if index_names[0].startswith("ragflow_doc_meta_"):
output_fields = doc_meta_column_names.copy()
else:
output_fields = column_names.copy()
if "id" not in output_fields:
output_fields = ["id"] + output_fields
if "_score" in output_fields:
output_fields.remove("_score")
if highlight_fields:
for field in highlight_fields:
if field not in output_fields:
output_fields.append(field)
fields_expr = ", ".join(output_fields)
condition["kb_id"] = knowledgebase_ids
filters: list[str] = get_filters(condition)
filters_expr = " AND ".join(filters)
fulltext_query: Optional[str] = None
fulltext_topn: Optional[int] = None
fulltext_search_weight: dict[str, float] = {}
fulltext_search_expr: dict[str, str] = {}
fulltext_search_idx_list: list[str] = []
fulltext_search_score_expr: Optional[str] = None
fulltext_search_filter: Optional[str] = None
vector_column_name: Optional[str] = None
vector_data: Optional[list[float]] = None
vector_topn: Optional[int] = None
vector_similarity_threshold: Optional[float] = None
vector_similarity_weight: Optional[float] = None
vector_search_expr: Optional[str] = None
vector_search_score_expr: Optional[str] = None
vector_search_filter: Optional[str] = None
for m in match_expressions:
if isinstance(m, MatchTextExpr):
if "original_query" not in m.extra_options:
raise ValueError("'original_query' is missing in extra_options.")
fulltext_query = m.extra_options["original_query"]
fulltext_query = escape_string(fulltext_query.strip())
fulltext_topn = m.topn
fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns(
fulltext_query, self._fulltext_search_columns
)
for column_name in fulltext_search_expr.keys():
fulltext_search_idx_list.append(fulltext_index_name_template % column_name)
elif isinstance(m, MatchDenseExpr):
if m.embedding_data_type != "float":
raise ValueError(f"embedding data type '{m.embedding_data_type}' is not float.")
vector_column_name = m.vector_column_name
vector_data = m.embedding_data
vector_topn = m.topn
vector_similarity_threshold = float(m.extra_options.get("similarity", 0.0))
elif isinstance(m, FusionExpr):
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
if fulltext_query:
fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})"
fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"
if vector_data:
vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
# use (1 - cosine_distance) as score, which should be [-1, 1]
# https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
vector_search_score_expr = f"(1 - {vector_search_expr})"
vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}"
pagerank_score_expr = f"(CAST(IFNULL({PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
# TODO use tag rank_feature in sorting
# tag_rank_fea = {k: float(v) for k, v in (rank_feature or {}).items() if k != PAGERANK_FLD}
if fulltext_query and vector_data:
search_type = "fusion"
elif fulltext_query:
search_type = "fulltext"
elif vector_data:
search_type = "vector"
elif len(agg_fields) > 0:
search_type = "aggregation"
else:
search_type = "filter"
if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
output_fields.append("_score")
if limit:
if vector_topn is not None:
limit = min(vector_topn, limit)
if fulltext_topn is not None:
limit = min(fulltext_topn, limit)
for index_name in index_names:
if not self._check_table_exists_cached(index_name):
continue
fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""
if search_type == "fusion":
# fusion search, usually for chat
num_candidates = vector_topn + fulltext_topn
if self.use_fulltext_first_fusion_search:
count_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f")"
f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
)
else:
count_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} id FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY {fulltext_search_score_expr}"
f" LIMIT {fulltext_topn}"
f"),"
f"vector_results AS ("
f" SELECT id FROM {index_name}"
f" WHERE {filters_expr} AND {vector_search_filter}"
f" ORDER BY {vector_search_expr}"
f" APPROXIMATE LIMIT {vector_topn}"
f")"
f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id"
)
logger.debug("OBConnection.search with count sql: %s", count_sql)
rows, elapsed_time = self._execute_search_sql(count_sql)
total_count = rows[0][0] if rows else 0
result.total += total_count
logger.info(
f"OBConnection.search table {index_name}, search type: fusion, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" vector column: '{vector_column_name}',"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" got count: {total_count}"
)
if total_count == 0:
continue
if self.use_fulltext_first_fusion_search:
score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
fusion_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f")"
f" SELECT {fields_expr}, {score_expr} AS _score"
f" FROM fulltext_results"
f" WHERE {vector_search_filter}"
f" ORDER BY _score DESC"
f" LIMIT {offset}, {limit}"
)
else:
pagerank_score_expr = f"(CAST(IFNULL(f.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
score_expr = f"(f.relevance * {1 - vector_similarity_weight} + v.similarity * {vector_similarity_weight} + {pagerank_score_expr})"
fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"])
fusion_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} id, pagerank_fea, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {fulltext_topn}"
f"),"
f"vector_results AS ("
f" SELECT id, pagerank_fea, {vector_search_score_expr} AS similarity"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {vector_search_filter}"
f" ORDER BY {vector_search_expr}"
f" APPROXIMATE LIMIT {vector_topn}"
f"),"
f"combined_results AS ("
f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score"
f" FROM fulltext_results f"
f" FULL OUTER JOIN vector_results v"
f" ON f.id = v.id"
f")"
f" SELECT {fields_expr}, c.score as _score"
f" FROM combined_results c"
f" JOIN {index_name} t"
f" ON c.id = t.id"
f" ORDER BY score DESC"
f" LIMIT {offset}, {limit}"
)
logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
rows, elapsed_time = self._execute_search_sql(fusion_sql)
logger.info(
f"OBConnection.search table {index_name}, search type: fusion, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" vector column: '{vector_column_name}',"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" vector_similarity_weight: {vector_similarity_weight},"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
elif search_type == "vector":
# vector search, usually used for graph search
count_sql = self._build_count_sql(index_name, filters_expr, vector_search_filter)
logger.debug("OBConnection.search with vector count sql: %s", count_sql)
rows, elapsed_time = self._execute_search_sql(count_sql)
total_count = rows[0][0] if rows else 0
result.total += total_count
logger.info(
f"OBConnection.search table {index_name}, search type: vector, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" vector column: '{vector_column_name}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" got count: {total_count}"
)
if total_count == 0:
continue
vector_sql = self._build_vector_search_sql(
index_name, fields_expr, vector_search_score_expr, filters_expr,
vector_search_filter, vector_search_expr, limit, vector_topn, offset
)
logger.debug("OBConnection.search with vector sql: %s", vector_sql)
rows, elapsed_time = self._execute_search_sql(vector_sql)
logger.info(
f"OBConnection.search table {index_name}, search type: vector, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" vector column: '{vector_column_name}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
elif search_type == "fulltext":
# fulltext search, usually used to search chunks in one dataset
count_sql = self._build_count_sql(index_name, filters_expr, fulltext_search_filter, fulltext_search_hint)
logger.debug("OBConnection.search with fulltext count sql: %s", count_sql)
rows, elapsed_time = self._execute_search_sql(count_sql)
total_count = rows[0][0] if rows else 0
result.total += total_count
logger.info(
f"OBConnection.search table {index_name}, search type: fulltext, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" got count: {total_count}"
)
if total_count == 0:
continue
fulltext_sql = self._build_fulltext_search_sql(
index_name, fields_expr, fulltext_search_score_expr, filters_expr,
fulltext_search_filter, offset, limit, fulltext_topn, fulltext_search_hint
)
logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql)
rows, elapsed_time = self._execute_search_sql(fulltext_sql)
logger.info(
f"OBConnection.search table {index_name}, search type: fulltext, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
elif search_type == "aggregation":
# aggregation search
if len(agg_fields) != 1:
raise ValueError("Only one aggregation field is supported in OceanBase.")
agg_field = agg_fields[0]
if agg_field in array_columns:
res = self.client.perform_raw_text_sql(
f"SELECT {agg_field} FROM {index_name}"
f" WHERE {agg_field} IS NOT NULL AND {filters_expr}"
)
counts = {}
for row in res:
if row[0]:
if isinstance(row[0], str):
try:
arr = json.loads(row[0])
except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON array: {row[0]}")
continue
else:
arr = row[0]
if isinstance(arr, list):
for v in arr:
if isinstance(v, str) and v.strip():
counts[v] = counts.get(v, 0) + 1
for v, count in counts.items():
result.chunks.append({
"value": v,
"count": count,
})
result.total += len(counts)
else:
res = self.client.perform_raw_text_sql(
f"SELECT {agg_field}, COUNT(*) as count FROM {index_name}"
f" WHERE {agg_field} IS NOT NULL AND {filters_expr}"
f" GROUP BY {agg_field}"
)
for row in res:
result.chunks.append({
"value": row[0],
"count": int(row[1]),
})
result.total += 1
else:
# only filter
orders: list[str] = []
if order_by:
for field, order in order_by.fields:
if isinstance(column_types[field], ARRAY):
f = field + "_sort"
fields_expr += f", array_avg({field}) AS {f}"
field = f
order = "ASC" if order == 0 else "DESC"
orders.append(f"{field} {order}")
count_sql = self._build_count_sql(index_name, filters_expr)
logger.debug("OBConnection.search with normal count sql: %s", count_sql)
rows, elapsed_time = self._execute_search_sql(count_sql)
total_count = rows[0][0] if rows else 0
result.total += total_count
logger.info(
f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" condition: '{condition}',"
f" got count: {total_count}"
)
if total_count == 0:
continue
order_by_expr = ("ORDER BY " + ", ".join(orders)) if len(orders) > 0 else ""
limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else ""
filter_sql = self._build_filter_search_sql(
index_name, fields_expr, filters_expr, order_by_expr, limit_expr
)
logger.debug("OBConnection.search with normal sql: %s", filter_sql)
rows, elapsed_time = self._execute_search_sql(filter_sql)
logger.info(
f"OBConnection.search table {index_name}, search type: normal, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" condition: '{condition}',"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
if result.total == 0:
result.total = len(result.chunks)
return result
def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
    """Fetch a single chunk by id through the parent class implementation.

    Args:
        chunk_id: Id of the chunk to fetch.
        index_name: Table to read from.
        knowledgebase_ids: Knowledge base ids forwarded to the parent ``get``.

    Returns:
        Whatever the parent ``get`` returns (``None`` included). If the parent
        raises ``json.JSONDecodeError``, a placeholder dict carrying the chunk
        id and an ``error`` message is returned instead.

    Raises:
        Exception: Any non-JSON error from the parent is logged with its
            traceback and re-raised unchanged.
    """
    try:
        # The parent's result is handed back untouched, ``None`` included.
        return super().get(chunk_id, index_name, knowledgebase_ids)
    except json.JSONDecodeError as decode_error:
        reason = str(decode_error)
        logger.error(f"JSON decode error when getting chunk {chunk_id}: {reason}")
        return {
            "id": chunk_id,
            "error": f"Failed to parse chunk data due to invalid JSON: {reason}",
        }
    except Exception:
        logger.exception(f"OBConnection.get({chunk_id}) got exception")
        raise
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
    """Convert chunk documents into table rows and upsert them.

    Tables whose name starts with ``ragflow_doc_meta_`` are delegated to
    ``_insert_doc_meta``. For every other table each document is mapped as
    follows:

    * keys matching ``vector_column_pattern`` are copied as-is;
    * keys that are not in ``column_names`` are collected under ``extra``;
    * ``None`` values are replaced by ``get_default_value(key)``;
    * ``kb_id`` given as a list is reduced to its first element;
    * dict ``content_with_weight``, ``position_int`` and list values are
      stored as JSON text;
    * every column of ``column_names`` that is still missing gets its default.

    When the row has a ``doc_id``, ``group_id`` is taken from the metadata key
    ``_group_id`` (falling back to ``doc_id``) and ``docnm_kwd`` is overwritten
    by the metadata key ``_title`` if present.

    Args:
        documents: Chunk dicts to store.
        index_name: Target table name.
        knowledgebase_id: Not used by this method.

    Returns:
        An empty list on success, or a list holding the error message when the
        upsert raised.
    """
    if not documents:
        return []

    # For doc_meta tables, use simple insert without field transformation
    if index_name.startswith("ragflow_doc_meta_"):
        return self._insert_doc_meta(documents, index_name)

    def _escape_list_item(item):
        # Only strings are touched: trimmed, then backslash and control
        # characters are escaped before the list is JSON-dumped.
        if not isinstance(item, str):
            return item
        return (
            item.strip()
            .replace('\\', '\\\\')
            .replace('\n', '\\n')
            .replace('\r', '\\r')
            .replace('\t', '\\t')
        )

    rows: list[dict] = []
    ids: list[str] = []
    for document in documents:
        row: dict = {}
        for key, value in document.items():
            if vector_column_pattern.match(key):
                row[key] = value
            elif key not in column_names:
                row.setdefault("extra", {})[key] = value
            elif value is None:
                row[key] = get_default_value(key)
            elif key == "kb_id" and isinstance(value, list):
                row[key] = value[0]
            elif key == "content_with_weight" and isinstance(value, dict):
                row[key] = json.dumps(value, ensure_ascii=False)
            elif key == "position_int":
                row[key] = json.dumps([list(position) for position in value], ensure_ascii=False)
            elif isinstance(value, list):
                row[key] = json.dumps(
                    [_escape_list_item(item) for item in value],
                    ensure_ascii=False,
                )
            else:
                row[key] = value

        # ``ids`` is collected but not returned; the lookup raises KeyError
        # for a row that ended up without an ``id`` entry.
        ids.append(row["id"])

        # this is to fix https://github.com/sqlalchemy/sqlalchemy/issues/9703
        for column_name in column_names:
            if column_name not in row:
                row[column_name] = get_default_value(column_name)

        metadata = row.get("metadata", {})
        if metadata is None:
            metadata = {}
        group_id = metadata.get("_group_id")
        title = metadata.get("_title")
        if row.get("doc_id"):
            row["group_id"] = group_id if group_id else row["doc_id"]
            if title:
                row["docnm_kwd"] = title

        rows.append(row)

    logger.debug("OBConnection.insert chunks: %s", rows)

    errors = []
    try:
        self.client.upsert(index_name, rows)
    except Exception as e:
        logger.error(f"OBConnection.insert error: {str(e)}")
        errors.append(str(e))
    return errors
def _insert_doc_meta(self, documents: list[dict], index_name: str) -> list[str]:
    """Upsert rows into a doc_meta table, keeping only three fields.

    Each stored row holds ``id``, ``kb_id`` and ``meta_fields``. ``meta_fields``
    is always a string: a dict is serialised to JSON, a string is passed
    through unchanged, and anything else (a missing value included) becomes
    ``"{}"``.

    Args:
        documents: Source dicts; only ``id``, ``kb_id`` and ``meta_fields`` are read.
        index_name: Target table name.

    Returns:
        An empty list on success, or a list holding the error message when the
        upsert raised.
    """
    rows: list[dict] = []
    for document in documents:
        raw_meta = document.get("meta_fields")
        if isinstance(raw_meta, dict):
            meta_json = json.dumps(raw_meta, ensure_ascii=False)
        elif isinstance(raw_meta, str):
            meta_json = raw_meta
        else:
            # None and every unsupported type collapse to an empty JSON object.
            meta_json = "{}"
        rows.append({
            "id": document.get("id"),
            "kb_id": document.get("kb_id"),
            "meta_fields": meta_json,
        })

    logger.debug("OBConnection._insert_doc_meta: %s", rows)

    errors = []
    try:
        self.client.upsert(index_name, rows)
    except Exception as e:
        logger.error(f"OBConnection._insert_doc_meta error: {str(e)}")
        errors.append(str(e))
    return errors
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
    """Build and run a single ``UPDATE`` statement for the rows matching ``condition``.

    ``new_value`` is translated into ``SET`` items as follows:

    * ``"remove"``: a string sets that column to ``NULL``; a dict maps array
      columns to the element to drop via ``array_remove``.
    * ``"add"``: a dict mapping array columns to the element to add via
      ``array_append``.
    * ``"metadata"``: a dict stored in the ``metadata`` column. When it is
      non-empty and ``condition`` contains ``doc_id``, its ``_group_id`` and
      ``_title`` entries are also written to ``group_id`` and ``docnm_kwd``.
    * any other key: ``<key> = <value>``.

    Args:
        condition: Filter used for the ``WHERE`` clause. It is not modified;
            for tables other than ``ragflow_doc_meta_*`` a copy with ``kb_id``
            set to ``knowledgebase_id`` is used.
        new_value: Changes to apply, see above.
        index_name: Table to update.
        knowledgebase_id: Knowledge base id forced into the filter.

    Returns:
        ``True`` when the table does not exist, when there is nothing to set,
        or when the statement ran; ``False`` when no filter could be built or
        the statement raised (the error is logged).

    Raises:
        ValueError: On a wrongly typed ``remove``/``add``/``metadata`` value,
            a non-array column in ``remove``/``add``, or a column name that is
            not a plain identifier.
    """
    if not self._check_table_exists_cached(index_name):
        return True

    def _checked_column(name) -> str:
        # Column names are spliced into the statement verbatim, so anything
        # that is not a plain identifier is rejected instead of being sent on.
        if not isinstance(name, str) or re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name) is None:
            raise ValueError(f"Invalid column name: {name!r}.")
        return name

    # For doc_meta tables, don't force kb_id in condition
    if not index_name.startswith("ragflow_doc_meta_"):
        # Build a new dict so the caller's condition is left untouched.
        condition = {**condition, "kb_id": knowledgebase_id}
    filters = get_filters(condition)

    set_values: list[str] = []
    for k, v in new_value.items():
        if k == "remove":
            if isinstance(v, str):
                set_values.append(f"{_checked_column(v)} = NULL")
            elif isinstance(v, dict):
                for kk, vv in v.items():
                    if kk not in array_columns:
                        raise ValueError(f"Column '{kk}' is not an array column.")
                    set_values.append(f"{kk} = array_remove({kk}, {get_value_str(vv)})")
            else:
                raise ValueError(f"Expected str or dict for 'remove', got {type(v)}.")
        elif k == "add":
            if not isinstance(v, dict):
                # Only a dict is accepted here, so the message says so.
                raise ValueError(f"Expected dict for 'add', got {type(v)}.")
            for kk, vv in v.items():
                if kk not in array_columns:
                    raise ValueError(f"Column '{kk}' is not an array column.")
                set_values.append(f"{kk} = array_append({kk}, {get_value_str(vv)})")
        elif k == "metadata":
            if not isinstance(v, dict):
                raise ValueError(f"Expected dict for 'metadata', got {type(v)}")
            set_values.append(f"{k} = {get_value_str(v)}")
            if v and "doc_id" in condition:
                group_id = v.get("_group_id")
                title = v.get("_title")
                if group_id:
                    set_values.append(f"group_id = {get_value_str(group_id)}")
                if title:
                    set_values.append(f"docnm_kwd = {get_value_str(title)}")
        else:
            set_values.append(f"{_checked_column(k)} = {get_value_str(v)}")

    if not set_values:
        return True

    where_clause = " AND ".join(filters)
    if not where_clause:
        # An empty filter would produce a malformed "WHERE" clause; fail
        # explicitly rather than relying on the database to reject it.
        logger.error("OBConnection.update error: empty filter, refusing to run UPDATE without a WHERE condition")
        return False

    update_sql = (
        f"UPDATE {index_name}"
        f" SET {', '.join(set_values)}"
        f" WHERE {where_clause}"
    )
    logger.debug("OBConnection.update sql: %s", update_sql)

    try:
        self.client.perform_raw_text_sql(update_sql)
        return True
    except Exception as e:
        logger.error(f"OBConnection.update error: {str(e)}")
    return False
def adjust_chunk_pagerank_fea(
    self,
    chunk_id: str,
    index_name: str,
    knowledgebase_id: str,
    delta: int,
    min_w: int = 0,
    max_w: int = 100,
) -> bool:
    """Shift the ``PAGERANK_FLD`` column of one chunk by ``delta`` in a single UPDATE.

    The new value is ``COALESCE(current, 0) + delta`` clamped to
    ``[min_w, max_w]`` with ``GREATEST``/``LEAST``; the row is selected by
    ``id`` and ``kb_id``.

    Args:
        chunk_id: Id of the chunk row.
        index_name: Table holding the chunk.
        knowledgebase_id: Knowledge base id the row must belong to.
        delta: Amount to add; converted with ``int()``.
        min_w: Lower clamp bound; converted with ``int()``.
        max_w: Upper clamp bound; converted with ``int()``.

    Returns:
        ``True`` when the table does not exist or the statement ran, ``False``
        when the statement raised (the error is logged).
    """
    if not self._check_table_exists_cached(index_name):
        return True

    step = int(delta)
    lower = int(min_w)
    upper = int(max_w)
    clamped_expr = f"GREATEST({lower}, LEAST({upper}, COALESCE({PAGERANK_FLD}, 0) + ({step})))"
    row_filter = f"id = {get_value_str(chunk_id)} AND kb_id = {get_value_str(knowledgebase_id)}"
    sql = f"UPDATE {index_name} SET {PAGERANK_FLD} = {clamped_expr} WHERE {row_filter}"
    logger.debug("OBConnection.adjust_chunk_pagerank_fea sql: %s", sql)

    try:
        self.client.perform_raw_text_sql(sql)
    except Exception as e:
        logger.error("OBConnection.adjust_chunk_pagerank_fea error: %s", e)
        return False
    else:
        return True
def _row_to_entity(self, data: Row, fields: list[str]) -> dict:
    """Turn a positional result row into a dict keyed by ``fields``.

    The i-th entry of ``data`` belongs to the i-th name in ``fields``. Entries
    that are ``None`` are left out; the rest go through ``get_column_value``.
    """
    return {
        name: get_column_value(name, cell)
        for position, name in enumerate(fields)
        if (cell := data[position]) is not None
    }
@staticmethod
def _es_row_to_entity(data: dict) -> dict:
    """Convert a key/value row into an entity dict.

    Keys whose value is ``None`` are dropped; every other value is passed
    through ``get_column_value`` together with its key.
    """
    return {
        name: get_column_value(name, raw)
        for name, raw in data.items()
        if raw is not None
    }
"""
Helper functions for search result
"""
def get_total(self, res) -> int:
    """Return the ``total`` attribute of a search result."""
    total = res.total
    return total
def get_doc_ids(self, res) -> list[str]:
    """Return the ``id`` of every chunk in ``res.chunks``, in order."""
    doc_ids = []
    for chunk in res.chunks:
        doc_ids.append(chunk["id"])
    return doc_ids
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
    """Project each chunk of ``res.chunks`` onto ``fields``, keyed by chunk id.

    A field is included for a chunk only when the chunk has a non-``None``
    value for it. Chunks sharing an id overwrite each other; the last wins.
    """
    projected = {}
    for chunk in res.chunks:
        picked = {
            name: value
            for name in fields
            if (value := chunk.get(name)) is not None
        }
        projected[chunk["id"]] = picked
    return projected
# copied from query.FulltextQueryer
def is_chinese(self, line):
    """Heuristically decide whether ``line`` should be treated as Chinese text.

    The line is split on runs of spaces/tabs. Three pieces or fewer always
    count as Chinese; otherwise the line counts as Chinese when at least 70%
    of the pieces are not purely ASCII letters.
    """
    pieces = re.split(r"[ \t]+", line)
    if len(pieces) <= 3:
        return True
    non_alpha = sum(1 for piece in pieces if not re.match(r"[a-zA-Z]+$", piece))
    return non_alpha / len(pieces) >= 0.7
def highlight(self, txt: str, tks: str, question: str, keywords: list[str]) -> Optional[str]:
    """Wrap matched terms of ``txt`` in ``<em>`` tags.

    Two strategies are used:

    * When ``question`` is non-empty and ``is_chinese(question)`` is false,
      matching is regex based. The whole question is tried first as a
      whole-word, case-insensitive match and the text is returned if anything
      was tagged. Otherwise each keyword is tagged the same way and the text
      is returned only when two tags ended up next to each other (optionally
      separated by whitespace); else ``None``.
    * Otherwise the text is walked token by token from the end (tokens come
      from ``tks`` or, when that is empty, from ``rag_tokenizer.tokenize``).
      Tokens found in ``keywords`` are tagged, and directly adjacent tags are
      merged.

    Args:
        txt: Text to highlight.
        tks: Whitespace-separated tokens of ``txt``; may be empty.
        question: Full query text.
        keywords: Terms to highlight.

    Returns:
        The tagged text, or ``None`` when ``txt``/``keywords`` is empty, no
        tokens are available, or the regex strategy did not qualify.
    """
    if not (txt and keywords):
        return None

    def _tag_whole_words(phrase: str, text: str) -> str:
        # Case-insensitive match of ``phrase`` delimited by non-word chars or line edges.
        return re.sub(
            r"(^|\W)(%s)(\W|$)" % re.escape(phrase),
            r"\1<em>\2</em>\3", text,
            flags=re.IGNORECASE | re.MULTILINE,
        )

    marked = txt

    if question and not self.is_chinese(question):
        marked = _tag_whole_words(question, marked)
        if re.search(r"<em>[^<>]+</em>", marked, flags=re.IGNORECASE | re.MULTILINE):
            return marked

        for keyword in keywords:
            marked = _tag_whole_words(keyword, marked)
        has_adjacent_tags = (
            re.findall(r'</em><em>', marked)
            or re.findall(r'</em>\s*<em>', marked)
        )
        return marked if has_adjacent_tags else None

    if not tks:
        tks = rag_tokenizer.tokenize(txt)
    tokens = tks.split()
    if not tokens:
        return None

    # Walk the tokens back to front so inserting tags never shifts the part
    # of the string that is still to be searched.
    search_end = len(txt)
    for token in reversed(tokens):
        found_at = marked.rfind(token, 0, search_end)
        if found_at == -1:
            continue
        if token in keywords:
            marked = "".join((
                marked[:found_at],
                f'<em>{token}</em>',
                marked[found_at + len(token):],
            ))
        search_end = found_at

    return re.sub(r'</em><em>', '', marked)
def get_highlight(self, res, keywords: list[str], fieldnm: str):
    """Build a ``{chunk id: highlighted text}`` map for one field of the result chunks.

    Chunks without a value for ``fieldnm`` and chunks for which ``highlight``
    returns nothing are left out. The chunk's ``content_ltks`` is passed as
    tokens only when ``fieldnm`` is ``"content_with_weight"``; otherwise an
    empty token string is passed.
    """
    highlights = {}
    if not (len(res.chunks) and len(keywords)):
        return highlights

    use_chunk_tokens = fieldnm == "content_with_weight"
    for chunk in res.chunks:
        text = chunk.get(fieldnm)
        if text:
            tokens = chunk.get("content_ltks") if use_chunk_tokens else ""
            marked = self.highlight(text, tokens, " ".join(keywords), keywords)
            if marked:
                highlights[chunk["id"]] = marked
    return highlights
def get_aggregation(self, res, fieldnm: str):
    """Return ``(value, count)`` pairs for a search result.

    Chunks that already carry ``value`` and ``count`` keys are emitted as-is,
    in chunk order. For the remaining chunks the values of ``fieldnm`` are
    tallied (a list contributes each of its items); only non-blank strings
    are counted. The tallied pairs follow the pre-aggregated ones, in
    first-seen order.
    """
    if len(res.chunks) == 0:
        return []

    tallies = {}
    pairs = []

    def _tally(candidate):
        # Only non-blank strings take part in the count.
        if isinstance(candidate, str) and candidate.strip():
            tallies[candidate] = tallies.get(candidate, 0) + 1

    for chunk in res.chunks:
        if "value" in chunk and "count" in chunk:
            # directly use the aggregation result
            pairs.append((chunk["value"], chunk["count"]))
            continue
        if fieldnm not in chunk:
            continue
        # aggregate the values of specific field
        raw = chunk[fieldnm]
        for candidate in (raw if isinstance(raw, list) else [raw]):
            _tally(candidate)

    pairs.extend(tallies.items())
    return pairs
"""
SQL
"""
def sql(self, sql: str, fetch_size: int = 1024, format: str = "json"):
    """Run a raw SQL statement and return its columns and rows.

    Before execution the statement is normalised: surrounding whitespace and
    trailing semicolons are stripped, backticks are removed, and
    ``json_extract_string(a, b)`` / ``json_extract_isnull(a, b)`` calls are
    rewritten to ``JSON_UNQUOTE(JSON_EXTRACT(a, b))`` /
    ``(JSON_EXTRACT(a, b) IS NULL)``. When ``fetch_size`` is positive and the
    statement starts with ``select``/``with`` but contains no ``limit``
    keyword, ``LIMIT <fetch_size>`` is appended.

    Args:
        sql: Statement text.
        fetch_size: Maximum number of rows to fetch; a falsy or non-positive
            value fetches everything.
        format: ``"markdown"`` additionally renders the rows as a pipe table.

    Returns:
        ``None`` when the client returns ``None``; otherwise a dict with
        ``columns`` (each typed ``"text"``), ``rows`` and, for the markdown
        format, ``markdown``.

    Raises:
        Exception: Whatever the client raises while executing; it is logged
            with its traceback and re-raised.
    """
    logger.debug("OBConnection.sql get sql: %s", sql)

    # (pattern, replacement) pairs applied in order to translate the JSON helpers.
    json_rewrites = (
        (
            r"json_extract_string\s*\(\s*([^,]+?)\s*,\s*([^)]+?)\s*\)",
            r"JSON_UNQUOTE(JSON_EXTRACT(\1, \2))",
        ),
        (
            r"json_extract_isnull\s*\(\s*([^,]+?)\s*,\s*([^)]+?)\s*\)",
            r"(JSON_EXTRACT(\1, \2) IS NULL)",
        ),
    )

    def _plain(value: Any) -> Any:
        # numpy scalars become Python scalars, bytes become text.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="ignore")
        return value

    statement = re.sub(r"[`]+", "", sql.strip().rstrip(";"))
    for pattern, replacement in json_rewrites:
        statement = re.sub(pattern, replacement, statement, flags=re.IGNORECASE)

    capped = bool(fetch_size and fetch_size > 0)
    if capped:
        lowered = statement.lstrip().lower()
        is_query = re.match(r"^(select|with)\b", lowered)
        if is_query and not re.search(r"\blimit\b", lowered):
            statement = f"{statement} LIMIT {int(fetch_size)}"

    logger.debug("OBConnection.sql to ob: %s", statement)

    try:
        cursor = self.client.perform_raw_text_sql(statement)
    except Exception:
        logger.exception("OBConnection.sql got exception")
        raise

    if cursor is None:
        return None

    columns = list(cursor.keys()) if hasattr(cursor, "keys") else []
    try:
        fetched = cursor.fetchmany(fetch_size) if capped else cursor.fetchall()
    except Exception:
        # Fall back to reading everything when the bounded fetch fails.
        fetched = cursor.fetchall()

    table = [[_plain(cell) for cell in record] for record in fetched]
    result = {
        "columns": [{"name": column, "type": "text"} for column in columns],
        "rows": table,
    }

    if format == "markdown":
        lines = []
        if columns:
            lines.append("|" + "|".join(columns) + "|")
            lines.append("|" + "|".join("---" for _ in columns) + "|")
        body = "\n".join("|" + "|".join(str(cell) for cell in record) + "|" for record in table)
        if body:
            lines.append(body)
        result["markdown"] = "\n".join(lines)

    return result