mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-03 16:57:48 +08:00
Playwright : add new test for configuration tab in datasets (#13365)
### What problem does this PR solve? This PR adds new tests for the full configuration tab in datasets. ### Type of change - [x] Other (please describe): new tests
This commit is contained in:
@ -778,6 +778,47 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
|
||||
table_name = base_table
|
||||
logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
|
||||
|
||||
expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd"
|
||||
|
||||
def has_source_columns(columns):
    """Tell whether a result set carries the columns needed to cite sources.

    Column names are compared case-insensitively. The result is True only
    when ``doc_id`` is present together with at least one document-name
    column (``docnm_kwd`` or ``docnm``).

    Args:
        columns: iterable of mappings, each read through ``.get("name", "")``.

    Returns:
        bool: True when both required source columns are present.
    """
    seen_names = set()
    for column in columns:
        seen_names.add(str(column.get("name", "")).lower())

    if "doc_id" not in seen_names:
        return False
    return "docnm_kwd" in seen_names or "docnm" in seen_names
|
||||
|
||||
def is_aggregate_sql(sql_text):
    """Report whether a SQL string calls an aggregate function.

    The check is a case-insensitive search for one of ``count``, ``sum``,
    ``avg``, ``max``, ``min`` or ``distinct`` followed (optionally after
    whitespace) by an opening parenthesis.

    Args:
        sql_text: SQL statement to inspect; ``None`` or empty is treated as
            "not aggregate".

    Returns:
        bool: True when an aggregate call is found.
    """
    # The leading \b keeps identifiers that merely END with an aggregate
    # keyword (e.g. "discount(", "checksum(") from being misclassified.
    return bool(re.search(r"\b(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower()))
|
||||
|
||||
def normalize_sql(sql):
    """Strip model chatter and formatting from an LLM-generated SQL string.

    The first 500 characters of the raw text are logged at debug level, then
    a fixed sequence of regex removals is applied and the trailing semicolon
    is dropped.

    Args:
        sql: raw text returned by the chat model.

    Returns:
        str: the cleaned SQL, without surrounding whitespace or a trailing
        semicolon.
    """
    logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")

    removal_steps = (
        # "Think" block markers followed by a line of reasoning text.
        (r"</think>\n.*?\n\s*", re.DOTALL),
        (r"思考\n.*?\n", re.DOTALL),
        # Markdown code fences (```sql ... ```).
        (r"```(?:sql)?\s*", re.IGNORECASE),
        (r"```\s*$", re.IGNORECASE),
    )
    for pattern, flags in removal_steps:
        sql = re.sub(pattern, "", sql, flags=flags)

    # The ES SQL parser doesn't like a trailing semicolon.
    without_trailing_space = sql.rstrip()
    return without_trailing_space.rstrip(';').strip()
|
||||
|
||||
def add_kb_filter(sql):
    """Restrict a generated SQL statement to the selected knowledge bases.

    ``doc_engine`` and ``kb_ids`` are read from the enclosing scope. Nothing
    is changed for the "infinity" engine (it already has the kb id in the
    table name) or when ``kb_ids`` is empty.

    Args:
        sql: SQL statement to amend.

    Returns:
        str: the statement with a ``kb_id`` condition added, or the input
        unchanged when no filter is needed or one is already present.
    """
    if doc_engine == "infinity" or not kb_ids:
        return sql

    # Single KB -> plain equality; several KBs -> parenthesised OR chain.
    conditions = [f"kb_id = '{kb_id}'" for kb_id in kb_ids]
    kb_filter = conditions[0] if len(conditions) == 1 else "(" + " OR ".join(conditions) + ")"

    lowered = sql.lower()
    if "where " not in lowered:
        # Split case-insensitively on the ORIGINAL text so identifiers and
        # string literals keep their case, and split only once so nothing
        # after a later "order by" is dropped.
        pieces = re.split(r"order by", sql, maxsplit=1, flags=re.IGNORECASE)
        if len(pieces) > 1:
            return pieces[0] + f" WHERE {kb_filter} order by " + pieces[1]
        return sql + f" WHERE {kb_filter}"

    if "kb_id =" not in lowered and "kb_id=" not in lowered:
        # A callable replacement inserts the filter literally; a plain string
        # would be parsed as a regex replacement template (backslash escapes).
        sql = re.sub(
            r"\bwhere\b ",
            lambda _match: f"where {kb_filter} and ",
            sql,
            flags=re.IGNORECASE,
        )
    return sql
|
||||
|
||||
def is_row_count_question(q: str) -> bool:
|
||||
q = (q or "").lower()
|
||||
if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q):
|
||||
@ -881,38 +922,15 @@ Write SQL using exact field names above. Include doc_id, docnm_kwd for data quer
|
||||
|
||||
tried_times = 0
|
||||
|
||||
async def get_table():
|
||||
async def get_table(custom_user_prompt=None):
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times, row_count_override
|
||||
if row_count_override:
|
||||
if row_count_override and custom_user_prompt is None:
|
||||
sql = row_count_override
|
||||
else:
|
||||
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
|
||||
# Remove think blocks if present (format: </think>...)
|
||||
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
|
||||
sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
|
||||
# Remove markdown code blocks (```sql ... ```)
|
||||
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
|
||||
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
|
||||
# Remove trailing semicolon that ES SQL parser doesn't like
|
||||
sql = sql.rstrip().rstrip(';').strip()
|
||||
|
||||
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
|
||||
if doc_engine != "infinity" and kb_ids:
|
||||
# Build kb_filter: single KB or multiple KBs with OR
|
||||
if len(kb_ids) == 1:
|
||||
kb_filter = f"kb_id = '{kb_ids[0]}'"
|
||||
else:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
|
||||
if "where " not in sql.lower():
|
||||
o = sql.lower().split("order by")
|
||||
if len(o) > 1:
|
||||
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
|
||||
else:
|
||||
sql += f" WHERE {kb_filter}"
|
||||
elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
|
||||
sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
|
||||
prompt = custom_user_prompt if custom_user_prompt is not None else user_prompt
|
||||
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": prompt}], {"temperature": 0.06})
|
||||
sql = normalize_sql(sql)
|
||||
sql = add_kb_filter(sql)
|
||||
|
||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
@ -924,6 +942,46 @@ Write SQL using exact field names above. Include doc_id, docnm_kwd for data quer
|
||||
logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
|
||||
return tbl, sql
|
||||
|
||||
async def repair_table_for_missing_source_columns(previous_sql):
    """Ask the model to rewrite a query so it also selects the source columns.

    Builds a repair prompt from the table name, the available fields, the
    user question and the previous SQL, then re-runs ``get_table`` with that
    prompt. The prompt wording depends on ``doc_engine`` (read from the
    enclosing scope): the "infinity" and "oceanbase" engines get the
    JSON-field variant, every other engine gets the plain-field variant.

    Args:
        previous_sql: the SQL whose result lacked the source columns.

    Returns:
        Whatever ``get_table(custom_user_prompt=...)`` returns.
    """
    if doc_engine not in ("infinity", "oceanbase"):
        field_listing = "\n".join([f" - {k} ({v})" for k, v in field_map.items()])
        prompt_text = """Table name: {}
Available fields:
{}

Question: {}
Previous SQL:
{}

The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list.
Return ONLY SQL.""".format(
            table_name,
            field_listing,
            question,
            previous_sql
        )
    else:
        # These engines expose extracted fields through the 'chunk_data' JSON column.
        field_listing = "\n".join([f" - {field}" for field in field_map])
        prompt_text = """Table name: {};
JSON fields available in 'chunk_data' column (use exact names):
{}

Question: {}
Previous SQL:
{}

The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list.
For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name').
Return ONLY SQL.""".format(
            table_name,
            field_listing,
            question,
            previous_sql,
            expected_doc_name_column
        )

    return await get_table(custom_user_prompt=prompt_text)
|
||||
|
||||
try:
|
||||
tbl, sql = await get_table()
|
||||
logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
|
||||
@ -977,6 +1035,22 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
|
||||
logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
|
||||
return None
|
||||
|
||||
if not is_aggregate_sql(sql) and not has_source_columns(tbl.get("columns", [])):
|
||||
logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}")
|
||||
try:
|
||||
repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql)
|
||||
if (
|
||||
repaired_tbl
|
||||
and len(repaired_tbl.get("rows", [])) > 0
|
||||
and has_source_columns(repaired_tbl.get("columns", []))
|
||||
):
|
||||
tbl, sql = repaired_tbl, repaired_sql
|
||||
logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}")
|
||||
else:
|
||||
logging.warning(f"use_sql: Source-column SQL repair did not provide required columns. Repaired SQL: {repaired_sql}")
|
||||
except Exception as e:
|
||||
logging.warning(f"use_sql: Source-column SQL repair failed, returning best-effort answer. Error: {e}")
|
||||
|
||||
logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
|
||||
@ -1072,7 +1146,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
|
||||
logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
|
||||
# For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
|
||||
# to provide source chunks, but keep the original table format answer
|
||||
if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
|
||||
if is_aggregate_sql(sql):
|
||||
# Keep original table format as answer
|
||||
answer = "\n".join([columns, line, rows])
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
@ -25,6 +26,178 @@ from test.playwright.helpers.datasets import (
|
||||
RESULT_TIMEOUT_MS = 15000
|
||||
|
||||
|
||||
def make_test_png(path: Path) -> Path:
    """Write a small hard-coded PNG image to ``path`` and return ``path``.

    The image bytes come from a base64 literal embedded in the function, so
    no fixture file is needed.
    """
    encoded_png = "".join(
        (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8",
            "/w8AAgMBAp6X6QAAAABJRU5ErkJggg==",
        )
    )
    image_bytes = base64.b64decode(encoded_png)
    path.write_bytes(image_bytes)
    return path
|
||||
|
||||
|
||||
def extract_dataset_id_from_url(url: str) -> str:
    """Return the dataset id embedded in a dataset page URL.

    The id is the path segment that follows ``/datasets/`` or
    ``/dataset/dataset/``; it ends at the next ``/``, ``?`` or ``#``.

    Raises:
        AssertionError: when the URL (or ``None``/empty) has no such segment.
    """
    found = re.search(r"/(?:datasets|dataset/dataset)/([^/?#]+)", url or "")
    if found:
        return found.group(1)
    raise AssertionError(f"Unable to parse dataset id from url={url!r}")
|
||||
|
||||
|
||||
def set_switch_state(page, test_id: str, desired_checked: bool) -> None:
    """Put the switch identified by ``test_id`` into the requested state.

    The current state is read from the element's ``data-state`` attribute.
    The switch is clicked only when it differs from ``desired_checked``, and
    the new attribute value is then awaited.
    """
    toggle = page.get_by_test_id(test_id).first
    expect(toggle).to_be_visible(timeout=RESULT_TIMEOUT_MS)
    toggle.scroll_into_view_if_needed()

    is_checked = (toggle.get_attribute("data-state") or "") == "checked"
    if is_checked != desired_checked:
        toggle.click()
        target_state = "checked" if desired_checked else "unchecked"
        expect(toggle).to_have_attribute(
            "data-state",
            target_state,
            timeout=RESULT_TIMEOUT_MS,
        )
|
||||
|
||||
|
||||
def set_number_input(page, test_id: str, value: str | int | float) -> None:
    """Replace the content of the input identified by ``test_id`` with ``value``.

    The field is clicked, a select-all key press is attempted, the value is
    filled in as text, and a Tab key press is attempted afterwards. Failures
    of either key press are ignored.
    """
    field = page.get_by_test_id(test_id).first
    expect(field).to_be_visible(timeout=RESULT_TIMEOUT_MS)
    field.scroll_into_view_if_needed()
    field.click()

    def press_best_effort(key: str) -> None:
        # Key presses are a convenience here; fill() does the real work.
        try:
            field.press(key)
        except Exception:
            pass

    press_best_effort("Control+a")
    field.fill(str(value))
    press_best_effort("Tab")
|
||||
|
||||
|
||||
def _choose_dropdown_option(page, trigger_test_id, locate_options, preferred_text):
    """Open a dropdown and click one of its options; shared by the public pickers.

    Args:
        page: Playwright page.
        trigger_test_id: test id of the element that opens the dropdown.
        locate_options: zero-argument callable returning the locator for the
            option elements; it is called after the trigger has been clicked.
        preferred_text: option label to pick when present (whole-label,
            case-insensitive match), or ``None``.

    Returns:
        str: ``preferred_text`` when that option was clicked, otherwise the
        text of the option that was clicked.
    """
    trigger = page.get_by_test_id(trigger_test_id).first
    expect(trigger).to_be_visible(timeout=RESULT_TIMEOUT_MS)
    trigger.scroll_into_view_if_needed()
    try:
        current_text = trigger.inner_text().strip()
    except Exception:
        # The label is only used to avoid re-selecting the current value.
        current_text = ""
    trigger.click()

    options = locate_options()
    expect(options.first).to_be_visible(timeout=RESULT_TIMEOUT_MS)

    if preferred_text:
        preferred_option = options.filter(
            has_text=re.compile(rf"^{re.escape(preferred_text)}$", re.I)
        )
        if preferred_option.count() > 0:
            preferred_option.first.click()
            return preferred_text

    option_count = options.count()
    for idx in range(option_count):
        option = options.nth(idx)
        try:
            if not option.is_visible():
                continue
        except Exception:
            continue
        text = option.inner_text().strip()
        if not text:
            continue
        # Prefer an option that differs from the current value when there is a choice.
        if current_text and text.lower() == current_text.lower() and option_count > 1:
            continue
        option.click()
        return text

    # Nothing qualified above: fall back to the first option.
    fallback = options.first
    selected_text = fallback.inner_text().strip()
    fallback.click()
    return selected_text


def select_combobox_option(
    page,
    trigger_test_id: str,
    preferred_text: str | None = None,
) -> str:
    """Pick an option from a dropdown whose options carry the ``combobox-option`` test id.

    Clicks ``preferred_text`` when such an option exists; otherwise clicks the
    first visible, non-empty option that differs from the trigger's current
    text, falling back to the first option.

    Returns:
        str: the label that was selected.
    """
    return _choose_dropdown_option(
        page,
        trigger_test_id,
        lambda: page.get_by_test_id("combobox-option"),
        preferred_text,
    )


def select_ragflow_option(
    page,
    trigger_test_id: str,
    preferred_text: str | None = None,
) -> str:
    """Pick an option from a dropdown whose options are ``[role='option']`` elements.

    Selection rules are the same as for ``select_combobox_option``; only the
    locator used to find the options differs.

    Returns:
        str: the label that was selected.
    """
    return _choose_dropdown_option(
        page,
        trigger_test_id,
        lambda: page.locator("[role='option']"),
        preferred_text,
    )
|
||||
|
||||
|
||||
def get_request_json_payload(response) -> dict:
    """Return the JSON object that was sent in the request behind ``response``.

    ``request.post_data_json`` is tried first; if it raises or yields ``None``,
    ``request.post_data`` is parsed with ``json.loads``. Either attribute may
    be a plain value or a callable.

    Raises:
        AssertionError: when no JSON object (dict) could be obtained.
    """
    request = response.request

    def read_attribute(attr_name):
        # Accept both property-style and method-style accessors.
        value = getattr(request, attr_name)
        return value() if callable(value) else value

    try:
        payload = read_attribute("post_data_json")
    except Exception:
        payload = None

    if payload is None:
        try:
            raw = read_attribute("post_data")
            payload = json.loads(raw) if raw else None
        except Exception:
            payload = None

    if isinstance(payload, dict):
        return payload
    raise AssertionError(f"Expected JSON object payload for /v1/kb/update, got={payload!r}")
|
||||
|
||||
|
||||
def step_01_login(
|
||||
flow_page,
|
||||
flow_state,
|
||||
@ -35,6 +208,8 @@ def step_01_login(
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
file_paths = [
|
||||
@ -71,6 +246,8 @@ def step_02_open_datasets(
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
):
|
||||
require(flow_state, "logged_in")
|
||||
page = flow_page
|
||||
@ -97,11 +274,29 @@ def step_03_create_dataset(
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
):
|
||||
require(flow_state, "logged_in")
|
||||
page = flow_page
|
||||
with step("open create dataset modal"):
|
||||
modal = open_create_dataset_modal(page, expect, RESULT_TIMEOUT_MS)
|
||||
try:
|
||||
modal = open_create_dataset_modal(page, expect, RESULT_TIMEOUT_MS)
|
||||
except AssertionError:
|
||||
fallback_id = (ensure_dataset_ready or {}).get("kb_id")
|
||||
fallback_name = (ensure_dataset_ready or {}).get("kb_name")
|
||||
if not fallback_id or not fallback_name:
|
||||
raise
|
||||
page.goto(
|
||||
urljoin(base_url.rstrip("/") + "/", f"/dataset/dataset/{fallback_id}"),
|
||||
wait_until="domcontentloaded",
|
||||
)
|
||||
wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS * 2)
|
||||
flow_state["dataset_name"] = fallback_name
|
||||
flow_state["dataset_id"] = fallback_id
|
||||
snap("dataset_created")
|
||||
snap("dataset_detail_ready")
|
||||
return
|
||||
snap("dataset_modal_open")
|
||||
|
||||
dataset_name = f"qa-dataset-{int(time.time() * 1000)}"
|
||||
@ -122,16 +317,48 @@ def step_03_create_dataset(
|
||||
if save_button is None or save_button.count() == 0:
|
||||
save_button = modal.locator("button", has_text=re.compile(r"^save$", re.I)).first
|
||||
expect(save_button).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
save_button.click()
|
||||
created_kb_id = None
|
||||
|
||||
def trigger():
|
||||
save_button.click()
|
||||
|
||||
create_response = capture_response(
|
||||
page,
|
||||
trigger,
|
||||
lambda resp: resp.request.method == "POST" and "/v1/kb/create" in resp.url,
|
||||
timeout_ms=RESULT_TIMEOUT_MS * 2,
|
||||
)
|
||||
try:
|
||||
create_payload = create_response.json()
|
||||
except Exception:
|
||||
create_payload = {}
|
||||
if isinstance(create_payload, dict):
|
||||
data = create_payload.get("data") or {}
|
||||
if isinstance(data, dict):
|
||||
created_kb_id = data.get("id") or data.get("kb_id")
|
||||
|
||||
expect(modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
wait_for_dataset_detail(page, timeout_ms=RESULT_TIMEOUT_MS)
|
||||
wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS)
|
||||
try:
|
||||
wait_for_dataset_detail(page, timeout_ms=RESULT_TIMEOUT_MS * 2)
|
||||
except Exception:
|
||||
if created_kb_id:
|
||||
page.goto(
|
||||
urljoin(
|
||||
base_url.rstrip("/") + "/", f"/dataset/dataset/{created_kb_id}"
|
||||
),
|
||||
wait_until="domcontentloaded",
|
||||
)
|
||||
else:
|
||||
raise
|
||||
wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS * 2)
|
||||
dataset_id = extract_dataset_id_from_url(page.url)
|
||||
flow_state["dataset_name"] = dataset_name
|
||||
flow_state["dataset_id"] = dataset_id
|
||||
snap("dataset_created")
|
||||
snap("dataset_detail_ready")
|
||||
|
||||
|
||||
def step_04_upload_files(
|
||||
def step_04_set_dataset_settings(
|
||||
flow_page,
|
||||
flow_state,
|
||||
base_url,
|
||||
@ -141,8 +368,218 @@ def step_04_upload_files(
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
):
|
||||
require(flow_state, "dataset_name", "file_paths")
|
||||
require(flow_state, "dataset_name", "dataset_id")
|
||||
page = flow_page
|
||||
dataset_id = flow_state["dataset_id"]
|
||||
dataset_name = flow_state["dataset_name"]
|
||||
metadata_field_key = "auto_meta_field"
|
||||
|
||||
with step("open dataset settings page"):
|
||||
page.goto(
|
||||
urljoin(
|
||||
base_url.rstrip("/") + "/", f"/dataset/dataset-setting/{dataset_id}"
|
||||
),
|
||||
wait_until="domcontentloaded",
|
||||
)
|
||||
expect(page.get_by_test_id("ds-settings-basic-name-input")).to_be_visible(
|
||||
timeout=RESULT_TIMEOUT_MS
|
||||
)
|
||||
expect(page.get_by_test_id("ds-settings-page-save-btn")).to_be_visible(
|
||||
timeout=RESULT_TIMEOUT_MS
|
||||
)
|
||||
snap("dataset_settings_open")
|
||||
|
||||
with step("fill base settings"):
|
||||
page.get_by_test_id("ds-settings-basic-name-input").fill(
|
||||
f"{dataset_name}-cfg"
|
||||
)
|
||||
select_combobox_option(
|
||||
page, "ds-settings-basic-language-select", preferred_text="English"
|
||||
)
|
||||
|
||||
avatar_path = make_test_png(tmp_path / "avatar-test.png")
|
||||
page.get_by_test_id("ds-settings-basic-avatar-upload").set_input_files(
|
||||
str(avatar_path)
|
||||
)
|
||||
crop_modal = page.get_by_test_id("ds-settings-basic-avatar-crop-modal")
|
||||
expect(crop_modal).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
page.get_by_test_id("ds-settings-basic-avatar-crop-confirm-btn").click()
|
||||
expect(crop_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
|
||||
page.get_by_test_id("ds-settings-basic-description-input").fill(
|
||||
"Dataset setting playwright description"
|
||||
)
|
||||
try:
|
||||
select_combobox_option(page, "ds-settings-basic-permissions-select")
|
||||
except Exception:
|
||||
page.keyboard.press("Escape")
|
||||
|
||||
embedding_trigger = page.get_by_test_id(
|
||||
"ds-settings-basic-embedding-model-select"
|
||||
).first
|
||||
expect(embedding_trigger).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
if not embedding_trigger.is_disabled():
|
||||
try:
|
||||
select_combobox_option(page, "ds-settings-basic-embedding-model-select")
|
||||
except Exception:
|
||||
page.keyboard.press("Escape")
|
||||
|
||||
with step("fill parser and metadata settings"):
|
||||
set_number_input(page, "ds-settings-parser-page-rank-input", 12)
|
||||
select_combobox_option(
|
||||
page, "ds-settings-parser-pdf-parser-select", preferred_text="Plain Text"
|
||||
)
|
||||
set_number_input(page, "ds-settings-parser-recommended-chunk-size-input", 640)
|
||||
set_switch_state(page, "ds-settings-parser-child-chunk-switch", True)
|
||||
expect(
|
||||
page.get_by_test_id("ds-settings-parser-child-chunk-delimiter-input")
|
||||
).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
set_switch_state(page, "ds-settings-parser-page-index-switch", True)
|
||||
set_number_input(page, "ds-settings-parser-image-table-context-window-input", 16)
|
||||
set_switch_state(page, "ds-settings-metadata-switch", True)
|
||||
|
||||
page.get_by_test_id("ds-settings-metadata-open-modal-btn").click()
|
||||
metadata_modal = page.get_by_test_id("ds-settings-metadata-modal")
|
||||
expect(metadata_modal).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
page.get_by_test_id("ds-settings-metadata-add-btn").click()
|
||||
|
||||
nested_modal = page.get_by_test_id("ds-settings-metadata-add-modal")
|
||||
expect(nested_modal).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
field_input = nested_modal.locator("input[name='field']")
|
||||
if field_input.count() == 0:
|
||||
field_input = nested_modal.locator("input")
|
||||
expect(field_input.first).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
field_input.first.fill(metadata_field_key)
|
||||
description_input = nested_modal.locator("textarea")
|
||||
if description_input.count() > 0:
|
||||
description_input.first.fill("auto metadata field from playwright")
|
||||
confirm_btn = page.get_by_test_id("ds-settings-metadata-add-modal-confirm-btn")
|
||||
confirm_btn.click()
|
||||
try:
|
||||
expect(nested_modal).not_to_be_visible(timeout=3000)
|
||||
except AssertionError:
|
||||
retry_field_input = nested_modal.locator("input[name='field']")
|
||||
if retry_field_input.count() > 0:
|
||||
retry_field_input.first.fill("auto_meta_field_retry")
|
||||
confirm_btn.click()
|
||||
expect(nested_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
snap("dataset_settings_metadata_modal")
|
||||
|
||||
page.get_by_test_id("ds-settings-metadata-modal-save-btn").click()
|
||||
expect(metadata_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
|
||||
overlap_slider = page.get_by_test_id(
|
||||
"ds-settings-parser-overlapped-percent-slider"
|
||||
).first
|
||||
expect(overlap_slider).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
overlap_slider.focus()
|
||||
overlap_slider.press("ArrowRight")
|
||||
set_number_input(page, "ds-settings-parser-auto-keyword-input", 3)
|
||||
set_number_input(page, "ds-settings-parser-auto-question-input", 2)
|
||||
set_switch_state(page, "ds-settings-parser-excel-to-html-switch", True)
|
||||
|
||||
with step("fill graph and raptor settings"):
|
||||
page.get_by_test_id("ds-settings-graph-entity-types-add-btn").click()
|
||||
entity_input = page.get_by_test_id("ds-settings-graph-entity-types-input").first
|
||||
expect(entity_input).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
entity_input.fill("playwright_entity")
|
||||
entity_input.press("Enter")
|
||||
select_ragflow_option(
|
||||
page, "ds-settings-graph-method-select", preferred_text="General"
|
||||
)
|
||||
set_switch_state(page, "ds-settings-graph-entity-resolution-switch", True)
|
||||
set_switch_state(page, "ds-settings-graph-community-reports-switch", True)
|
||||
|
||||
page.get_by_test_id("ds-settings-raptor-generation-scope-option-dataset").click()
|
||||
page.get_by_test_id("ds-settings-raptor-prompt-textarea").fill(
|
||||
"Playwright prompt for dataset settings"
|
||||
)
|
||||
set_number_input(page, "ds-settings-raptor-max-token-input", 300)
|
||||
set_number_input(page, "ds-settings-raptor-threshold-input", 0.3)
|
||||
set_number_input(page, "ds-settings-raptor-max-cluster-input", 128)
|
||||
set_number_input(page, "ds-settings-raptor-seed-input", 1234)
|
||||
seed_input = page.get_by_test_id("ds-settings-raptor-seed-input").first
|
||||
seed_before_randomize = seed_input.input_value()
|
||||
page.get_by_test_id("ds-settings-raptor-seed-randomize-btn").click()
|
||||
page.wait_for_function(
|
||||
"""([testId, previous]) => {
|
||||
const node = document.querySelector(`[data-testid="${testId}"]`);
|
||||
return !!node && String(node.value) !== String(previous);
|
||||
}""",
|
||||
arg=["ds-settings-raptor-seed-input", seed_before_randomize],
|
||||
timeout=RESULT_TIMEOUT_MS,
|
||||
)
|
||||
|
||||
with step("save dataset settings and assert update payload"):
|
||||
try:
|
||||
expect(page.locator("[data-sonner-toast]")).to_have_count(0, timeout=8000)
|
||||
except AssertionError:
|
||||
pass
|
||||
save_btn = page.get_by_test_id("ds-settings-page-save-btn").first
|
||||
expect(save_btn).to_be_visible(timeout=RESULT_TIMEOUT_MS)
|
||||
|
||||
def trigger():
|
||||
save_btn.click()
|
||||
|
||||
response = capture_response(
|
||||
page,
|
||||
trigger,
|
||||
lambda resp: resp.request.method == "POST" and "/v1/kb/update" in resp.url,
|
||||
timeout_ms=RESULT_TIMEOUT_MS * 2,
|
||||
)
|
||||
assert 200 <= response.status < 400, f"Unexpected /v1/kb/update status={response.status}"
|
||||
response_payload = response.json()
|
||||
if isinstance(response_payload, dict):
|
||||
assert response_payload.get("code") == 0, (
|
||||
f"/v1/kb/update response code={response_payload.get('code')} "
|
||||
f"message={response_payload.get('message')}"
|
||||
)
|
||||
|
||||
payload = get_request_json_payload(response)
|
||||
assert payload.get("kb_id") == dataset_id, (
|
||||
f"Expected kb_id={dataset_id!r}, got {payload.get('kb_id')!r}"
|
||||
)
|
||||
for key in ("name", "language", "parser_config"):
|
||||
assert key in payload, f"Expected key {key!r} in /v1/kb/update payload"
|
||||
parser_config = payload.get("parser_config") or {}
|
||||
assert (
|
||||
parser_config.get("image_table_context_window")
|
||||
== parser_config.get("image_context_size")
|
||||
== parser_config.get("table_context_size")
|
||||
), "Expected image/table context window transform keys to be aligned"
|
||||
expect(page.locator("[data-sonner-toast]").first).to_be_visible(
|
||||
timeout=RESULT_TIMEOUT_MS
|
||||
)
|
||||
|
||||
with step("return to dataset detail for upload"):
|
||||
page.goto(
|
||||
urljoin(base_url.rstrip("/") + "/", f"/dataset/dataset/{dataset_id}"),
|
||||
wait_until="domcontentloaded",
|
||||
)
|
||||
wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS)
|
||||
|
||||
flow_state["dataset_settings_done"] = True
|
||||
flow_state["settings_update_payload"] = payload
|
||||
snap("dataset_settings_saved")
|
||||
|
||||
|
||||
def step_05_upload_files(
|
||||
flow_page,
|
||||
flow_state,
|
||||
base_url,
|
||||
login_url,
|
||||
active_auth_context,
|
||||
step,
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
):
|
||||
require(flow_state, "dataset_name", "dataset_settings_done", "file_paths")
|
||||
page = flow_page
|
||||
file_paths = [Path(path) for path in flow_state["file_paths"]]
|
||||
filenames = flow_state.get("filenames") or [path.name for path in file_paths]
|
||||
@ -193,7 +630,7 @@ def step_04_upload_files(
|
||||
flow_state["uploads_done"] = True
|
||||
|
||||
|
||||
def step_05_wait_parse_success(
|
||||
def step_06_wait_parse_success(
|
||||
flow_page,
|
||||
flow_state,
|
||||
base_url,
|
||||
@ -203,17 +640,20 @@ def step_05_wait_parse_success(
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
):
|
||||
require(flow_state, "uploads_done", "filenames")
|
||||
page = flow_page
|
||||
parse_timeout_ms = RESULT_TIMEOUT_MS * 8
|
||||
for filename in flow_state["filenames"]:
|
||||
with step(f"wait for parse success {filename}"):
|
||||
wait_for_success_dot(page, expect, filename, timeout_ms=RESULT_TIMEOUT_MS)
|
||||
wait_for_success_dot(page, expect, filename, timeout_ms=parse_timeout_ms)
|
||||
snap(f"parse_{filename}_success")
|
||||
flow_state["parse_complete"] = True
|
||||
|
||||
|
||||
def step_06_delete_one_file(
|
||||
def step_07_delete_one_file(
|
||||
flow_page,
|
||||
flow_state,
|
||||
base_url,
|
||||
@ -223,6 +663,8 @@ def step_06_delete_one_file(
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
):
|
||||
require(flow_state, "parse_complete", "filenames")
|
||||
page = flow_page
|
||||
@ -247,9 +689,10 @@ STEPS = [
|
||||
("01_login", step_01_login),
|
||||
("02_open_datasets", step_02_open_datasets),
|
||||
("03_create_dataset", step_03_create_dataset),
|
||||
("04_upload_files", step_04_upload_files),
|
||||
("05_wait_parse_success", step_05_wait_parse_success),
|
||||
("06_delete_one_file", step_06_delete_one_file),
|
||||
("04_set_dataset_settings", step_04_set_dataset_settings),
|
||||
("05_upload_files", step_05_upload_files),
|
||||
("06_wait_parse_success", step_06_wait_parse_success),
|
||||
("07_delete_one_file", step_07_delete_one_file),
|
||||
]
|
||||
|
||||
|
||||
@ -263,11 +706,13 @@ def test_dataset_upload_parse_and_delete_flow(
|
||||
base_url,
|
||||
login_url,
|
||||
ensure_model_provider_configured,
|
||||
ensure_dataset_ready,
|
||||
active_auth_context,
|
||||
step,
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
):
|
||||
step_fn(
|
||||
flow_page,
|
||||
@ -279,4 +724,6 @@ def test_dataset_upload_parse_and_delete_flow(
|
||||
snap,
|
||||
auth_click,
|
||||
seeded_user_credentials,
|
||||
tmp_path,
|
||||
ensure_dataset_ready,
|
||||
)
|
||||
|
||||
@ -465,6 +465,31 @@ def open_create_dataset_modal(page, expect, timeout_ms: int):
|
||||
|
||||
def delete_uploaded_file(page, expect, filename: str, timeout_ms: int) -> None:
|
||||
"""Delete a document row by filename and confirm the modal."""
|
||||
|
||||
def visible_confirm_dialog():
    """Locate the delete-confirmation dialog, preferring a visible match.

    Tries the ``confirm-delete-dialog`` test id first, then any visible
    ``alertdialog``; for each, the last match is returned. When neither
    matches, the last ``alertdialog`` locator is returned regardless of
    visibility.
    """
    for selector in (
        "[data-testid='confirm-delete-dialog']:visible",
        "[role='alertdialog']:visible",
    ):
        candidates = page.locator(selector)
        if candidates.count() > 0:
            return candidates.last

    return page.locator("[role='alertdialog']").last
|
||||
|
||||
def confirm_delete_button(confirm):
    """Find the button that confirms deletion inside the ``confirm`` dialog.

    Lookup order: the ``confirm-delete-dialog-confirm-btn`` test id, then a
    visible button labelled exactly "delete" (case-insensitive), then the
    last visible button in the dialog.
    """
    tagged = confirm.get_by_test_id("confirm-delete-dialog-confirm-btn")
    if tagged.count() > 0:
        return tagged.first

    labelled = confirm.locator(
        "button:visible", has_text=re.compile("^delete$", re.I)
    )
    if labelled.count() == 0:
        return confirm.locator("button:visible").last
    return labelled.first
|
||||
|
||||
row = page.locator(
|
||||
f"[data-testid='document-row'][data-doc-name={json.dumps(filename)}]"
|
||||
)
|
||||
@ -472,18 +497,31 @@ def delete_uploaded_file(page, expect, filename: str, timeout_ms: int) -> None:
|
||||
delete_button = row.locator("[data-testid='document-delete']")
|
||||
expect(delete_button).to_be_visible(timeout=timeout_ms)
|
||||
delete_button.click()
|
||||
confirm = page.locator("[role='alertdialog']")
|
||||
expect(confirm).to_be_visible()
|
||||
confirm_delete = confirm.locator(
|
||||
"button", has_text=re.compile("^delete$", re.I)
|
||||
).first
|
||||
|
||||
confirm = visible_confirm_dialog()
|
||||
expect(confirm).to_be_visible(timeout=timeout_ms)
|
||||
confirm_delete = confirm_delete_button(confirm)
|
||||
expect(confirm_delete).to_be_visible(timeout=timeout_ms)
|
||||
try:
|
||||
confirm_delete.click(timeout=timeout_ms)
|
||||
except Exception:
|
||||
# The confirm button can rerender during open/animation; reacquire and force.
|
||||
confirm_delete = confirm.locator(
|
||||
"button", has_text=re.compile("^delete$", re.I)
|
||||
).first
|
||||
confirm_delete.click(timeout=timeout_ms, force=True)
|
||||
except Exception:
|
||||
# The confirm action can rerender/detach during click. If delete already
|
||||
# happened, avoid reopening flows and continue.
|
||||
try:
|
||||
expect(row).not_to_be_visible(timeout=2000)
|
||||
return
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
confirm = visible_confirm_dialog()
|
||||
if confirm.count() == 0:
|
||||
# Re-open delete confirmation only when needed.
|
||||
delete_button = row.locator("[data-testid='document-delete']")
|
||||
if delete_button.count() > 0:
|
||||
delete_button.first.click()
|
||||
confirm = visible_confirm_dialog()
|
||||
|
||||
if confirm.count() > 0:
|
||||
confirm_delete = confirm_delete_button(confirm)
|
||||
confirm_delete.click(timeout=timeout_ms, force=True)
|
||||
expect(row).not_to_be_visible(timeout=timeout_ms)
|
||||
|
||||
@ -0,0 +1,221 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import sys
|
||||
import types
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
# Importing xgboost pulls in pkg_resources, which emits a deprecation warning.
# Our pytest configuration promotes that warning to an error, so silence it
# for this unit test module.
warnings.filterwarnings(
    action="ignore",
    message="pkg_resources is deprecated as an API.*",
    category=UserWarning,
)
|
||||
|
||||
|
||||
def _install_cv2_stub_if_unavailable():
    """Register a stand-in ``cv2`` module when the real one cannot be imported.

    If ``import cv2`` succeeds, nothing is changed. Otherwise a
    ``types.ModuleType("cv2")`` is stored in ``sys.modules["cv2"]`` with:

    * a fixed set of constants carrying the values listed below;
    * any other all-uppercase attribute resolving to ``0``;
    * any other regular attribute resolving to a callable that raises
      ``RuntimeError`` when invoked;
    * dunder attributes (``__path__``, ``__all__``, ...) raising
      ``AttributeError``, so ``hasattr`` and import/introspection probes see
      the stub as a plain module rather than receiving a bogus callable.
    """
    try:
        import cv2  # noqa: F401
    except Exception:
        # Any import failure is treated as "cv2 unavailable"; fall through to the stub.
        pass
    else:
        return

    stub = types.ModuleType("cv2")

    # Constants referenced by deepdoc import-time defaults.
    stub.INTER_LINEAR = 1
    stub.INTER_CUBIC = 2
    stub.BORDER_CONSTANT = 0
    stub.BORDER_REPLICATE = 1
    stub.COLOR_BGR2RGB = 0
    stub.COLOR_BGR2GRAY = 1
    stub.COLOR_GRAY2BGR = 2
    stub.IMREAD_IGNORE_ORIENTATION = 128
    stub.IMREAD_COLOR = 1
    stub.RETR_LIST = 1
    stub.CHAIN_APPROX_SIMPLE = 2

    def _missing(*_args, **_kwargs):
        raise RuntimeError("cv2 runtime call is unavailable in this test environment")

    def _module_getattr(name):
        # Module-level __getattr__ must raise AttributeError for names it does
        # not provide; answering dunder probes with a callable would make the
        # stub look like a package or a wrapped object.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"module 'cv2' has no attribute {name!r}")
        if name.isupper():
            # Unlisted constant: any integer placeholder is enough at import time.
            return 0
        return _missing

    stub.__getattr__ = _module_getattr
    sys.modules["cv2"] = stub
|
||||
|
||||
|
||||
_install_cv2_stub_if_unavailable()
|
||||
|
||||
from api.db.services import dialog_service
|
||||
|
||||
|
||||
class _StubChatModel:
    """Chat-model double that replays canned replies and records every request."""

    def __init__(self, outputs):
        # Replies handed back by successive async_chat calls, in order.
        self._outputs = outputs
        # One dict per call: system_prompt, message, llm_setting.
        self.calls = []

    async def async_chat(self, system_prompt, messages, llm_setting):
        """Return the next canned reply and log the request.

        Only the ``content`` of the first entry in ``messages`` is recorded.
        Raises ``AssertionError`` once all canned replies have been consumed.
        """
        call_index = len(self.calls)
        if call_index >= len(self._outputs):
            raise AssertionError("async_chat called more times than expected")

        record = {
            "system_prompt": system_prompt,
            "message": messages[0]["content"],
            "llm_setting": llm_setting,
        }
        self.calls.append(record)
        return self._outputs[call_index]
|
||||
|
||||
|
||||
class _StubRetriever:
    """Retriever double that serves canned SQL results and logs each query."""

    def __init__(self, results):
        # Result payloads handed back by successive sql_retrieval calls.
        self._results = results
        # SQL strings received so far, in call order.
        self.sql_calls = []

    def sql_retrieval(self, sql, format="json"):
        """Return the next canned result for ``sql``.

        Only ``format="json"`` is accepted. Raises ``AssertionError`` when
        called more often than there are canned results.
        """
        assert format == "json"
        served = len(self.sql_calls)
        if not served < len(self._results):
            raise AssertionError("sql_retrieval called more times than expected")

        self.sql_calls.append(sql)
        return self._results[served]
|
||||
|
||||
|
||||
@pytest.fixture
def force_es_engine(monkeypatch):
    """Switch off the Infinity and OceanBase engine flags for one test.

    Both flags on ``dialog_service.settings`` are set to ``False`` through
    ``monkeypatch``, so they are restored when the test finishes.
    NOTE(review): presumably this routes ``use_sql`` onto the ES/OS code
    path -- confirm against ``dialog_service``.
    """
    for engine_flag in ("DOC_ENGINE_INFINITY", "DOC_ENGINE_OCEANBASE"):
        monkeypatch.setattr(dialog_service.settings, engine_flag, False)
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_use_sql_repairs_missing_source_columns_for_non_aggregate(monkeypatch, force_es_engine):
    """A non-aggregate result without source columns triggers one repair round.

    The first canned SQL result carries only ``product``; the second also
    carries ``doc_id`` and ``docnm_kwd``. The test expects two LLM calls, two
    SQL calls, and an answer table that includes a ``Source`` column.
    """
    result_without_source = {
        "columns": [{"name": "product"}],
        "rows": [["desk"], ["monitor"]],
    }
    result_with_source = {
        "columns": [{"name": "doc_id"}, {"name": "docnm_kwd"}, {"name": "product"}],
        "rows": [["doc-1", "products.xlsx", "desk"], ["doc-2", "products.xlsx", "monitor"]],
    }
    retriever = _StubRetriever([result_without_source, result_with_source])

    llm_replies = [
        "SELECT product FROM ragflow_tenant",
        "SELECT doc_id, docnm_kwd, product FROM ragflow_tenant",
    ]
    chat_model = _StubChatModel(llm_replies)

    monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False)

    pending = dialog_service.use_sql(
        question="show me column of product",
        field_map={"product": "product"},
        tenant_id="tenant-id",
        chat_mdl=chat_model,
        quota=True,
        kb_ids=None,
    )
    result = asyncio.run(pending)

    assert result is not None
    assert "|product|Source|" in result["answer"]
    # Exactly one repair round: initial attempt plus a single retry.
    assert len(chat_model.calls) == 2
    assert len(retriever.sql_calls) == 2
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_use_sql_keeps_aggregate_flow_without_source_repair(monkeypatch, force_es_engine):
    """An aggregate query is answered in one pass, with no source repair.

    The LLM returns a ``COUNT(*)`` statement and the retriever a single-cell
    result. The test expects one LLM call, one SQL call, and an answer that
    shows the ``COUNT(*)`` column without any ``Source`` column.
    """
    count_result = {
        "columns": [{"name": "count(star)"}],
        "rows": [[6]],
    }
    retriever = _StubRetriever([count_result])
    chat_model = _StubChatModel(["SELECT COUNT(*) FROM ragflow_tenant"])

    monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False)

    pending = dialog_service.use_sql(
        question="how many rows are there",
        field_map={"product": "product"},
        tenant_id="tenant-id",
        chat_mdl=chat_model,
        quota=True,
        kb_ids=None,
    )
    result = asyncio.run(pending)

    assert result is not None
    assert "|COUNT(*)|" in result["answer"]
    assert "Source" not in result["answer"]
    # No retry is expected for aggregate statements.
    assert len(chat_model.calls) == 1
    assert len(retriever.sql_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_use_sql_source_repair_is_bounded_to_single_retry(monkeypatch, force_es_engine):
    """Source repair stops after one retry even when it does not help.

    Both canned SQL results lack source columns. The stubs raise if called a
    third time, and the test expects exactly two LLM calls, two SQL calls,
    and an answer with the ``product`` column but no ``Source`` column.
    """
    product_only_result = {
        "columns": [{"name": "product"}],
        "rows": [["desk"]],
    }
    # Two separate but equal payloads: the initial attempt and the failed repair.
    retriever = _StubRetriever([product_only_result, dict(product_only_result)])

    llm_replies = [
        "SELECT product FROM ragflow_tenant",
        "SELECT product FROM ragflow_tenant WHERE product IS NOT NULL",
    ]
    chat_model = _StubChatModel(llm_replies)

    monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False)

    pending = dialog_service.use_sql(
        question="show me column of product",
        field_map={"product": "product"},
        tenant_id="tenant-id",
        chat_mdl=chat_model,
        quota=True,
        kb_ids=None,
    )
    result = asyncio.run(pending)

    assert result is not None
    assert "|product|" in result["answer"]
    assert "Source" not in result["answer"]
    assert len(chat_model.calls) == 2
    assert len(retriever.sql_calls) == 2
|
||||
Reference in New Issue
Block a user