mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-03 16:57:48 +08:00
Fix: failing p3 test for SDK/HTTP APIs (#13062)
### What problem does this PR solve?

Adjust highlight parsing, add row-count SQL override, tweak retrieval thresholding, and update tests with engine-aware skips/utilities.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -1549,10 +1549,18 @@ async def retrieval_test(tenant_id):
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
if req.get("highlight") == "False" or req.get("highlight") == "false":
|
||||
highlight_val = req.get("highlight", None)
|
||||
if highlight_val is None:
|
||||
highlight = False
|
||||
elif isinstance(highlight_val, bool):
|
||||
highlight = highlight_val
|
||||
elif isinstance(highlight_val, str):
|
||||
if highlight_val.lower() in ["true", "false"]:
|
||||
highlight = highlight_val.lower() == "true"
|
||||
else:
|
||||
return get_error_data_result("`highlight` should be a boolean")
|
||||
else:
|
||||
highlight = True
|
||||
return get_error_data_result("`highlight` should be a boolean")
|
||||
try:
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
|
||||
@ -606,10 +606,21 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
|
||||
table_name = base_table
|
||||
logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
|
||||
|
||||
def is_row_count_question(q: str) -> bool:
    """Return True when *q* asks how many rows a table-like source has.

    Both conditions must hold (matching is case-insensitive):

    1. the text contains one of the phrases "how many rows",
       "number of rows" or "row count" as whole words, and
    2. the text mentions "dataset", "table", "spreadsheet" or "excel"
       as a whole word.

    Args:
        q: The user's question. ``None`` or an empty string never matches.

    Returns:
        True if the question matches both conditions, False otherwise.
    """
    # Normalise once so both patterns can be written in lowercase.
    q = (q or "").lower()
    # Guard clause: without a row-count phrase there is nothing to check.
    if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q):
        return False
    # Require a table-like noun so unrelated uses of "rows" do not match.
    return bool(re.search(r"\bdataset\b|\btable\b|\bspreadsheet\b|\bexcel\b", q))
|
||||
|
||||
# Generate engine-specific SQL prompts
|
||||
if doc_engine == "infinity":
|
||||
# Build Infinity prompts with JSON extraction context
|
||||
json_field_names = list(field_map.keys())
|
||||
row_count_override = (
|
||||
f"SELECT COUNT(*) AS rows FROM {table_name}"
|
||||
if is_row_count_question(question)
|
||||
else None
|
||||
)
|
||||
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
|
||||
|
||||
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
|
||||
@ -641,6 +652,11 @@ Write SQL using json_extract_string() with exact field names. Include doc_id, do
|
||||
elif doc_engine == "oceanbase":
|
||||
# Build OceanBase prompts with JSON extraction context
|
||||
json_field_names = list(field_map.keys())
|
||||
row_count_override = (
|
||||
f"SELECT COUNT(*) AS rows FROM {table_name}"
|
||||
if is_row_count_question(question)
|
||||
else None
|
||||
)
|
||||
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
|
||||
|
||||
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
|
||||
@ -671,6 +687,7 @@ Write SQL using json_extract_string() with exact field names. Include doc_id, do
|
||||
)
|
||||
else:
|
||||
# Build ES/OS prompts with direct field access
|
||||
row_count_override = None
|
||||
sys_prompt = """You are a Database Administrator. Write SQL queries.
|
||||
|
||||
RULES:
|
||||
@ -693,8 +710,11 @@ Write SQL using exact field names above. Include doc_id, docnm_kwd for data quer
|
||||
tried_times = 0
|
||||
|
||||
async def get_table():
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times, row_count_override
|
||||
if row_count_override:
|
||||
sql = row_count_override
|
||||
else:
|
||||
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
|
||||
# Remove think blocks if present (format: </think>...)
|
||||
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
|
||||
|
||||
Reference in New Issue
Block a user