diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 91c20fddf..1b520ec29 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -226,6 +226,9 @@ async def add_llm(): elif factory == "PaddleOCR": api_key = apikey_json(["api_key", "provider_order"]) + elif factory == "OpenDataLoader": + api_key = apikey_json(["api_key", "provider_order"]) + llm = { "tenant_id": current_user.id, "llm_factory": factory, @@ -390,6 +393,7 @@ async def delete_factory(): def my_llms(): try: TenantLLMService.ensure_mineru_from_env(current_user.id) + TenantLLMService.ensure_opendataloader_from_env(current_user.id) include_details = request.args.get("include_details", "false").lower() == "true" if include_details: diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index a27f1352d..fe99aee49 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -19,7 +19,7 @@ import logging from peewee import IntegrityError from langfuse import Langfuse from common import settings -from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType +from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, OPENDATALOADER_DEFAULT_CONFIG, OPENDATALOADER_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType from api.db.db_models import DB, LLMFactories, TenantLLM from api.db.services.common_service import CommonService from api.db.services.langfuse_service import TenantLangfuseService @@ -364,6 +364,67 @@ class TenantLLMService(CommonService): idx += 1 continue + @classmethod + def _collect_opendataloader_env_config(cls) -> dict | None: + cfg = dict(OPENDATALOADER_DEFAULT_CONFIG) + found = False + for key in OPENDATALOADER_ENV_KEYS: + val = os.environ.get(key) + if val: + found = True + cfg[key] = val + return cfg if found else None + + @classmethod + @DB.connection_context() + def ensure_opendataloader_from_env(cls, tenant_id: str) -> str | None: + """ + Ensure an OpenDataLoader OCR model exists for the tenant if env variables are present. + Return the existing or newly created llm_name, or None if env not set. + """ + cfg = cls._collect_opendataloader_env_config() + if not cfg: + return None + + saved_models = cls.query(tenant_id=tenant_id, llm_factory="OpenDataLoader", model_type=LLMType.OCR.value) + + def _parse_api_key(raw: str) -> dict: + try: + return json.loads(raw or "{}") + except Exception: + return {} + + for item in saved_models: + api_cfg = _parse_api_key(item.api_key) + normalized = {k: api_cfg.get(k, OPENDATALOADER_DEFAULT_CONFIG.get(k)) for k in OPENDATALOADER_ENV_KEYS} + if normalized == cfg: + return item.llm_name + + used_names = {item.llm_name for item in saved_models} + idx = 1 + base_name = "opendataloader-from-env" + while True: + candidate = f"{base_name}-{idx}" + if candidate in used_names: + idx += 1 + continue + try: + cls.save( + tenant_id=tenant_id, + llm_factory="OpenDataLoader", + llm_name=candidate, + model_type=LLMType.OCR.value, + api_key=json.dumps(cfg), + api_base="", + max_tokens=0, + ) + return candidate + except IntegrityError: + logging.warning("OpenDataLoader env model %s already exists for tenant %s, retry with next name", candidate, tenant_id) + used_names.add(candidate) + idx += 1 + continue + @classmethod @DB.connection_context() def delete_by_tenant_id(cls, tenant_id): diff --git a/common/constants.py b/common/constants.py index b02790863..5d5588845 100644 --- a/common/constants.py +++ b/common/constants.py @@ -260,3 +260,8 @@ PADDLEOCR_DEFAULT_CONFIG = { "PADDLEOCR_ACCESS_TOKEN": None, "PADDLEOCR_ALGORITHM": "PaddleOCR-VL", } + +OPENDATALOADER_ENV_KEYS = ["OPENDATALOADER_APISERVER"] +OPENDATALOADER_DEFAULT_CONFIG = { + "OPENDATALOADER_APISERVER": "", +} diff --git a/conf/llm_factories.json b/conf/llm_factories.json index b5f8a46ed..7ac980851 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -6254,6 +6254,14 @@ "rank": "910", "llm": [] }, + { + "name": "OpenDataLoader", + "logo": "", + "tags": "OCR", + "status": "1", + "rank": "920", + "llm": [] + }, { "name": "n1n", "logo": "", diff --git a/deepdoc/parser/opendataloader_parser.py b/deepdoc/parser/opendataloader_parser.py new file mode 100644 index 000000000..c0e5fa50b --- /dev/null +++ b/deepdoc/parser/opendataloader_parser.py @@ -0,0 +1,431 @@ + +from __future__ import annotations + +import logging +import os +import re +from dataclasses import dataclass +from enum import Enum +from io import BytesIO +from os import PathLike +from pathlib import Path +from typing import Any, Callable, Iterable, Optional + +import pdfplumber +import requests +from PIL import Image + +try: + from deepdoc.parser.pdf_parser import RAGFlowPdfParser +except Exception: + class RAGFlowPdfParser: + pass + +from deepdoc.parser.utils import extract_pdf_outlines + + +class OpenDataLoaderContentType(str, Enum): + IMAGE = "image" + TABLE = "table" + TEXT = "text" + EQUATION = "equation" + + +@dataclass +class _BBox: + page_no: int + x0: float + y0: float + x1: float + y1: float + + +_TEXT_TYPES = {"heading", "title", "paragraph", "text", "list", "list_item", "caption"} +_TABLE_TYPES = {"table"} +_IMAGE_TYPES = {"image", "picture", "figure"} +_FORMULA_TYPES = {"formula", "equation"} + + +def _as_float(v) -> Optional[float]: + try: + return float(v) + except Exception: + return None + + +def _bbox_from_element(el: dict) -> Optional[_BBox]: + bb = el.get("bounding box") or el.get("bounding_box") or el.get("bbox") + pn = el.get("page number") + if pn is None: + pn = el.get("page_number") + if pn is None: + pn = el.get("page") + if bb is None or pn is None: + return None + if not isinstance(bb, (list, tuple)) or len(bb) < 4: + return None + coords = [_as_float(x) for x in bb[:4]] + if any(c is None for c in coords): + return None + try: + page_no = int(pn) + except Exception: + return None + # OpenDataLoader emits [left, bottom, right, top] in PDF points. + left, bottom, right, top = coords + x0, x1 = min(left, right), max(left, right) + y0, y1 = min(bottom, top), max(bottom, top) + return _BBox(page_no=page_no, x0=x0, y0=y0, x1=x1, y1=y1) + + +def _iter_elements(node: Any) -> Iterable[dict]: + if isinstance(node, dict): + if "type" in node and ("content" in node or "text" in node or "cells" in node): + yield node + for v in node.values(): + yield from _iter_elements(v) + elif isinstance(node, list): + for item in node: + yield from _iter_elements(item) + + +def _element_text(el: dict) -> str: + content = el.get("content") + if isinstance(content, str): + return content + text = el.get("text") + if isinstance(text, str): + return text + # tables may expose cells; join row-wise if needed + cells = el.get("cells") + if isinstance(cells, list): + rows: dict[int, list[str]] = {} + for c in cells: + if not isinstance(c, dict): + continue + row = c.get("row") or c.get("row_index") or 0 + rows.setdefault(int(row), []).append(str(c.get("content") or c.get("text") or "")) + return "\n".join(" | ".join(v) for _, v in sorted(rows.items())) + return "" + + +def _element_html(el: dict) -> str: + for key in ("html", "html_content"): + v = el.get(key) + if isinstance(v, str) and v.strip(): + return v + return "" + + +class OpenDataLoaderParser(RAGFlowPdfParser): + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + self.page_images: list[Image.Image] = [] + self.page_from = 0 + self.page_to = 10_000 + self.outlines = [] + self.api_url = os.environ.get("OPENDATALOADER_APISERVER", "").rstrip("/") + self.api_key = os.environ.get("OPENDATALOADER_API_KEY", "").strip() + try: + self.timeout = int(os.environ.get("OPENDATALOADER_TIMEOUT", "600") or "600") + except ValueError: + self.logger.warning("[OpenDataLoader] Invalid OPENDATALOADER_TIMEOUT, falling back to 600s") + self.timeout = 600 + + def check_installation(self) -> bool: + """Return True when the OpenDataLoader service is reachable.""" + if not self.api_url: + self.logger.warning( + "[OpenDataLoader] OPENDATALOADER_APISERVER is not set. " + "Start the opendataloader service and set the env var." + ) + return False + try: + headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + resp = requests.get(f"{self.api_url}/health", timeout=5, headers=headers) + if resp.status_code == 200: + return True + self.logger.warning( + f"[OpenDataLoader] Health check returned {resp.status_code}: {resp.text[:200]}" + ) + return False + except Exception as exc: + self.logger.warning(f"[OpenDataLoader] Health check failed: {exc}") + return False + + def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None): + self.page_from = page_from + self.page_to = page_to + bytes_io = None + try: + if not isinstance(fnm, (str, PathLike)): + bytes_io = fnm if isinstance(fnm, BytesIO) else BytesIO(fnm) + opener = pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(bytes_io) + with opener as pdf: + pages = pdf.pages[page_from:page_to] + self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for p in pages] + except Exception as e: + self.page_images = [] + self.logger.exception(e) + finally: + if bytes_io: + bytes_io.close() + + def _make_line_tag(self, bbox: _BBox) -> str: + if bbox is None: + return "" + # Guard: only emit a crop tag when the page was actually rendered. + if not self.page_images or bbox.page_no <= 0 or len(self.page_images) < bbox.page_no: + return "" + x0, x1 = bbox.x0, bbox.x1 + # OpenDataLoader bbox uses PDF coordinate space (origin bottom-left). + # Convert to image-space (origin top-left) by subtracting from page height. + _, page_height = self.page_images[bbox.page_no - 1].size + top = page_height - bbox.y1 + bott = page_height - bbox.y0 + return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format( + bbox.page_no, x0, x1, top, bott + ) + + @staticmethod + def extract_positions(txt: str) -> list[tuple[list[int], float, float, float, float]]: + poss = [] + for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt): + pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t") + left, right, top, bottom = float(left), float(right), float(top), float(bottom) + poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom)) + return poss + + def crop(self, text: str, ZM: int = 1, need_position: bool = False): + if not self.page_images: + return (None, None) if need_position else None + imgs = [] + poss = self.extract_positions(text) + if not poss: + return (None, None) if need_position else None + # Drop positions whose page indices fall outside the rendered range. + max_page = len(self.page_images) - 1 + poss = [p for p in poss if all(0 <= pn <= max_page for pn in p[0])] + if not poss: + return (None, None) if need_position else None + GAP = 6 + pos = poss[0] + poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0))) + pos = poss[-1] + poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1], pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1], pos[4] + 120))) + positions = [] + for ii, (pns, left, right, top, bottom) in enumerate(poss): + if bottom <= top: + bottom = top + 4 + img0 = self.page_images[pns[0]] + x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1])) + crop0 = img0.crop((x0, y0, x1, y1)) + imgs.append(crop0) + if 0 < ii < len(poss) - 1: + positions.append((pns[0] + self.page_from, x0, x1, y0, y1)) + remain_bottom = bottom - img0.size[1] + for pn in pns[1:]: + if remain_bottom <= 0: + break + page = self.page_images[pn] + x0, y0, x1, y1 = int(left), 0, int(right), int(min(remain_bottom, page.size[1])) + cimgp = page.crop((x0, y0, x1, y1)) + imgs.append(cimgp) + if 0 < ii < len(poss) - 1: + positions.append((pn + self.page_from, x0, x1, y0, y1)) + remain_bottom -= page.size[1] + if not imgs: + return (None, None) if need_position else None + height = sum(i.size[1] + GAP for i in imgs) + width = max(i.size[0] for i in imgs) + pic = Image.new("RGB", (width, int(height)), (245, 245, 245)) + h = 0 + for ii, img in enumerate(imgs): + if ii == 0 or ii + 1 == len(imgs): + img = img.convert("RGBA") + overlay = Image.new("RGBA", img.size, (0, 0, 0, 0)) + overlay.putalpha(128) + img = Image.alpha_composite(img, overlay).convert("RGB") + pic.paste(img, (0, int(h))) + h += img.size[1] + GAP + return (pic, positions) if need_position else pic + + def _cropout_region(self, bbox: _BBox, zoomin: int = 1): + if not self.page_images: + return None, "" + idx = (bbox.page_no - 1) - self.page_from + if idx < 0 or idx >= len(self.page_images): + return None, "" + page_img = self.page_images[idx] + W, H = page_img.size + x0 = max(0.0, min(float(bbox.x0), W - 1)) + y0 = max(0.0, min(float(H - bbox.y1), H - 1)) + x1 = max(x0 + 1.0, min(float(bbox.x1), W)) + y1 = max(y0 + 1.0, min(float(H - bbox.y0), H)) + try: + crop = page_img.crop((int(x0), int(y0), int(x1), int(y1))).convert("RGB") + except Exception: + return None, "" + pos = (bbox.page_no - 1 if bbox.page_no > 0 else 0, x0, x1, y0, y1) + return crop, [pos] + + def _classify(self, el_type: str) -> str: + t = (el_type or "").lower() + if t in _TABLE_TYPES: + return OpenDataLoaderContentType.TABLE.value + if t in _IMAGE_TYPES: + return OpenDataLoaderContentType.IMAGE.value + if t in _FORMULA_TYPES: + return OpenDataLoaderContentType.EQUATION.value + # Preserve the original structural type (heading, title, paragraph, + # list, caption, …) so downstream parsers can apply heading/title heuristics. + return t if t else OpenDataLoaderContentType.TEXT.value + + def _transfer_from_json(self, root: Any, parse_method: str): + sections: list[tuple[str, ...]] = [] + tables: list = [] + for el in _iter_elements(root): + el_type = self._classify(el.get("type", "")) + bbox = _bbox_from_element(el) + tag = self._make_line_tag(bbox) if bbox else "" + + if el_type == OpenDataLoaderContentType.TABLE.value: + html = _element_html(el) or _element_text(el) + img = None + positions = "" + if bbox: + img, positions = self._cropout_region(bbox) + tables.append(((img, html), positions if positions else "")) + continue + + if el_type == OpenDataLoaderContentType.IMAGE.value: + img = None + positions = "" + if bbox: + img, positions = self._cropout_region(bbox) + caption = _element_text(el) + tables.append(((img, [caption] if caption else [""]), positions if positions else "")) + continue + + text = _element_text(el).strip() + if not text: + continue + if parse_method in {"manual", "pipeline"}: + sections.append((text, el_type, tag)) + elif parse_method == "paper": + sections.append((text + tag, el_type)) + else: + sections.append((text, tag)) + return sections, tables + + @staticmethod + def _sections_from_markdown(md: str, parse_method: str) -> list[tuple[str, ...]]: + txt = (md or "").strip() + if not txt: + return [] + if parse_method in {"manual", "pipeline"}: + return [(txt, OpenDataLoaderContentType.TEXT.value, "")] + if parse_method == "paper": + return [(txt, OpenDataLoaderContentType.TEXT.value)] + return [(txt, "")] + + def parse_pdf( + self, + filepath: str | PathLike[str], + binary: BytesIO | bytes | None = None, + callback: Optional[Callable] = None, + *, + parse_method: str = "raw", + hybrid: Optional[str] = None, + image_output: Optional[str] = None, + sanitize: Optional[bool] = None, + ): + self.outlines = extract_pdf_outlines(binary if binary is not None else filepath) + + if not self.api_url: + raise RuntimeError( + "[OpenDataLoader] OPENDATALOADER_APISERVER is not configured. " + "Please start the opendataloader service and set the env var." + ) + + # Render page images locally — used by _make_line_tag() and crop(). + # The image rendering stays on the RAGFlow host; only the Java conversion + # runs inside the opendataloader service container. + try: + if binary is not None: + src = BytesIO(binary) if isinstance(binary, (bytes, bytearray)) else binary + self.__images__(src, zoomin=1) + else: + self.__images__(str(filepath), zoomin=1) + except Exception as e: + self.logger.warning(f"[OpenDataLoader] render pages failed: {e}") + + # Read PDF bytes for the multipart upload + if binary is not None: + pdf_bytes = binary if isinstance(binary, (bytes, bytearray)) else binary.getvalue() + else: + with open(filepath, "rb") as fh: + pdf_bytes = fh.read() + + filename = Path(str(filepath)).name or "input.pdf" + + if callback: + callback(0.1, f"[OpenDataLoader] Sending '{filename}' to service") + + form_data: dict[str, str] = {} + if hybrid: + form_data["hybrid"] = hybrid + if image_output: + form_data["image_output"] = image_output + if sanitize is not None: + form_data["sanitize"] = "true" if sanitize else "false" + + headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + last_exc: Exception | None = None + for attempt in range(1, 4): + try: + self.logger.info(f"[OpenDataLoader] POST {self.api_url}/file_parse for '{filename}' (attempt {attempt})") + resp = requests.post( + url=f"{self.api_url}/file_parse", + files={"file": (filename, pdf_bytes, "application/pdf")}, + data=form_data, + headers=headers, + timeout=self.timeout, + ) + resp.raise_for_status() + result = resp.json() + break + except Exception as exc: + last_exc = exc + self.logger.warning(f"[OpenDataLoader] attempt {attempt} failed: {exc}") + else: + raise RuntimeError(f"[OpenDataLoader] service call failed after 3 attempts: {last_exc}") from last_exc + + if callback: + callback(0.7, "[OpenDataLoader] Processing response") + + # Service response structure: + # { + # "json_doc": {...} | null, # structured parse tree (preferred) + # "md_text": "..." | null # markdown fallback when json_doc is absent + # } + json_doc = result.get("json_doc") + md_text = result.get("md_text") + + sections: list[tuple[str, ...]] = [] + tables: list = [] + if json_doc is not None: + sections, tables = self._transfer_from_json(json_doc, parse_method=parse_method) + if not sections and md_text: + sections = self._sections_from_markdown(md_text, parse_method=parse_method) + + if callback: + callback(1.0, f"[OpenDataLoader] Done. Sections: {len(sections)}, Tables: {len(tables)}") + + return sections, tables + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + parser = OpenDataLoaderParser() + print("OpenDataLoader service reachable:", parser.check_installation()) diff --git a/docs/guides/dataset/select_pdf_parser.md b/docs/guides/dataset/select_pdf_parser.md index d96992f5a..57eb8b3a6 100644 --- a/docs/guides/dataset/select_pdf_parser.md +++ b/docs/guides/dataset/select_pdf_parser.md @@ -39,6 +39,7 @@ RAGFlow isn't one-size-fits-all. It is built for flexibility and supports deeper - Naive: Skip OCR, TSR, and DLR tasks if _all_ your PDFs are plain text. - [MinerU](https://github.com/opendatalab/MinerU): (Experimental) An open-source tool that converts PDF into machine-readable formats. - [Docling](https://github.com/docling-project/docling): (Experimental) An open-source document processing tool for gen AI. +- [OpenDataLoader](https://github.com/opendataloader-project/opendataloader-pdf): (Experimental) A deterministic, local-first PDF parser with structured JSON + Markdown output. Runs as a standalone service container so no Java runtime is needed on the RAGFlow host. - A third-party visual model from a specific model provider. :::danger IMPORTANT diff --git a/rag/app/naive.py b/rag/app/naive.py index 25b715b6e..b022ec17c 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -169,6 +169,54 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese return sections, tables, pdf_parser +def by_opendataloader( + filename, + binary=None, + from_page=0, + to_page=100000, + lang="Chinese", + callback=None, + pdf_cls=None, + parse_method: str = "raw", + opendataloader_llm_name: str | None = None, + tenant_id: str | None = None, + **kwargs, +): + if tenant_id: + if not opendataloader_llm_name: + try: + from api.db.services.tenant_llm_service import TenantLLMService + + env_name = TenantLLMService.ensure_opendataloader_from_env(tenant_id) + candidates = TenantLLMService.query(tenant_id=tenant_id, llm_factory="OpenDataLoader", model_type=LLMType.OCR) + if candidates: + opendataloader_llm_name = candidates[0].llm_name + elif env_name: + opendataloader_llm_name = env_name + except Exception as e: + logging.warning(f"fallback to env opendataloader: {e}") + + if opendataloader_llm_name: + try: + ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, opendataloader_llm_name) + ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang) + pdf_parser = ocr_model.mdl + sections, tables = pdf_parser.parse_pdf( + filepath=filename, + binary=binary, + callback=callback, + parse_method=parse_method, + **kwargs, + ) + return sections, tables, pdf_parser + except Exception as e: + logging.error(f"Failed to parse pdf via LLMBundle OpenDataLoader ({opendataloader_llm_name}): {e}") + + if callback: + callback(-1, "OpenDataLoader not found.") + return None, None, None + + def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None, **kwargs): tcadp_parser = TCADPParser() @@ -255,6 +303,7 @@ PARSERS = { "deepdoc": by_deepdoc, "mineru": by_mineru, "docling": by_docling, + "opendataloader": by_opendataloader, "tcadp parser": by_tcadp, "paddleocr": by_paddleocr, "plaintext": by_plaintext, # default @@ -849,7 +898,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca if table_context_size or image_context_size: tables = append_context2table_image4pdf(sections, tables, image_context_size) - if name in ["tcadp", "docling", "mineru", "paddleocr"]: + if name in ["tcadp", "docling", "mineru", "paddleocr", "opendataloader"]: if int(parser_config.get("chunk_token_num", 0)) <= 0: parser_config["chunk_token_num"] = 0 diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index 4583b5226..069ac9b82 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -240,7 +240,7 @@ class ParserParam(ProcessParamBase): pdf_parse_method = pdf_config.get("parse_method", "") self.check_empty(pdf_parse_method, "Parse method abnormal.") - if pdf_parse_method.lower() not in ["deepdoc", "plain_text", "mineru", "docling", "tcadp parser", "paddleocr"]: + if pdf_parse_method.lower() not in ["deepdoc", "plain_text", "mineru", "docling", "opendataloader", "tcadp parser", "paddleocr"]: self.check_empty(pdf_config.get("lang", ""), "PDF VLM language") pdf_output_format = pdf_config.get("output_format", "") @@ -434,6 +434,70 @@ class Parser(ProcessBase): box["image"] = image bboxes.append(box) + elif parse_method.lower() == "opendataloader": + + def resolve_opendataloader_llm_name(): + configured = parser_model_name or conf.get("opendataloader_llm_name") + if configured: + return configured + tenant_id = self._canvas._tenant_id + if not tenant_id: + return None + from api.db.services.tenant_llm_service import TenantLLMService + env_name = TenantLLMService.ensure_opendataloader_from_env(tenant_id) + candidates = TenantLLMService.query(tenant_id=tenant_id, llm_factory="OpenDataLoader", model_type=LLMType.OCR.value) + if candidates: + return candidates[0].llm_name + return env_name + + parser_model_name = resolve_opendataloader_llm_name() + if not parser_model_name: + raise RuntimeError("OpenDataLoader model not configured. Please add OpenDataLoader in Model Providers.") + + tenant_id = self._canvas._tenant_id + ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, parser_model_name) + ocr_model = LLMBundle(tenant_id, ocr_model_config) + pdf_parser = ocr_model.mdl + + lines, odl_tables = pdf_parser.parse_pdf( + filepath=name, + binary=blob, + callback=self.callback, + parse_method="pipeline", + ) + bboxes = [] + for item in lines or []: + if not isinstance(item, tuple) or len(item) < 3: + continue + text, layout_type, poss = item[0], item[1], item[2] + box = { + "text": text, + "layout_type": layout_type or "text", + } + if isinstance(poss, str) and poss: + positions = [[pos[0][-1] + 1, *pos[1:]] for pos in pdf_parser.extract_positions(poss)] + if positions: + box["positions"] = positions + image = pdf_parser.crop(poss, 1) + if image is not None: + box["image"] = image + bboxes.append(box) + # Merge tables and images from the second return value. + for (img, html_or_caption), positions in odl_tables or []: + box = {"layout_type": "table" if not isinstance(html_or_caption, list) else "figure"} + if isinstance(html_or_caption, str): + box["text"] = html_or_caption + elif isinstance(html_or_caption, list): + box["text"] = html_or_caption[0] if html_or_caption else "" + if img is not None: + box["image"] = img + if positions: + try: + box["positions"] = [[p[0] + 1, p[1], p[2], p[3], p[4]] for p in positions] + except Exception: + pass + bboxes.append(box) + elif parse_method.lower() == "tcadp parser": # ADP is a document parsing tool using Tencent Cloud API table_result_type = conf.get("table_result_type", "1") diff --git a/rag/llm/ocr_model.py b/rag/llm/ocr_model.py index 800935467..5a76fe090 100644 --- a/rag/llm/ocr_model.py +++ b/rag/llm/ocr_model.py @@ -19,6 +19,7 @@ import os from typing import Any, Optional from deepdoc.parser.mineru_parser import MinerUParser +from deepdoc.parser.opendataloader_parser import OpenDataLoaderParser from deepdoc.parser.paddleocr_parser import PaddleOCRParser @@ -146,3 +147,59 @@ class PaddleOCROcrModel(Base, PaddleOCRParser): sections, tables = PaddleOCRParser.parse_pdf(self, filepath=filepath, binary=binary, callback=callback, parse_method=parse_method, **kwargs) return sections, tables + + +class OpenDataLoaderOcrModel(Base, OpenDataLoaderParser): + _FACTORY_NAME = "OpenDataLoader" + + def __init__(self, key: str | dict, model_name: str, **kwargs): + Base.__init__(self, key, model_name, **kwargs) + raw_config = {} + if key: + try: + raw_config = json.loads(key) + except Exception: + raw_config = {} + + config = raw_config.get("api_key", raw_config) + if not isinstance(config, dict): + config = {} + + def _resolve_config(key: str, env_key: str, default=""): + return config.get(key, config.get(env_key, os.environ.get(env_key, default))) + + redacted_config = {} + for k, v in config.items(): + if any(s in k.lower() for s in ("key", "password", "token", "secret")): + redacted_config[k] = "[REDACTED]" + else: + redacted_config[k] = v + logging.info(f"Parsed OpenDataLoader config (sensitive fields redacted): {redacted_config}") + + OpenDataLoaderParser.__init__(self) + self.api_url = _resolve_config("opendataloader_apiserver", "OPENDATALOADER_APISERVER", "").rstrip("/") + self.api_key = _resolve_config("opendataloader_api_key", "OPENDATALOADER_API_KEY", "").strip() + timeout_val = _resolve_config("opendataloader_timeout", "OPENDATALOADER_TIMEOUT", "600") or "600" + try: + self.timeout = int(timeout_val) + except (TypeError, ValueError): + self.timeout = 600 + + def check_available(self) -> tuple[bool, str]: + ok = self.check_installation() + return ok, "" if ok else "OpenDataLoader service not reachable" + + def parse_pdf(self, filepath: str, binary=None, callback=None, parse_method: str = "raw", **kwargs): + ok, reason = self.check_available() + if not ok: + raise RuntimeError(f"OpenDataLoader service not accessible: {reason}") + + sections, tables = OpenDataLoaderParser.parse_pdf( + self, + filepath=filepath, + binary=binary, + callback=callback, + parse_method=parse_method, + **kwargs, + ) + return sections, tables diff --git a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py index dea30e68e..8bf9227a5 100644 --- a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py +++ b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py @@ -150,6 +150,10 @@ def _load_llm_app(monkeypatch): def ensure_mineru_from_env(_tenant_id): return None + @staticmethod + def ensure_opendataloader_from_env(_tenant_id): + return None + @staticmethod def query(**_kwargs): return [] @@ -846,6 +850,7 @@ def test_my_llms_include_details_and_exception_unit(monkeypatch): monkeypatch.setattr(module, "request", SimpleNamespace(args={"include_details": "true"})) ensure_calls = [] monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id)) + monkeypatch.setattr(module.TenantLLMService, "ensure_opendataloader_from_env", lambda _tenant_id: None) monkeypatch.setattr( module.TenantLLMService, "query", diff --git a/test/unit_test/deepdoc/parser/test_opendataloader_parser.py b/test/unit_test/deepdoc/parser/test_opendataloader_parser.py new file mode 100644 index 000000000..98416a77c --- /dev/null +++ b/test/unit_test/deepdoc/parser/test_opendataloader_parser.py @@ -0,0 +1,326 @@ +""" +Unit tests for deepdoc/parser/opendataloader_parser.py + +Tests cover the HTTP-client refactoring: check_installation(), parse_pdf(), +and the crop() bounds guard — without requiring a live OpenDataLoader service, +opendataloader_pdf package, or Java runtime. +""" + +from __future__ import annotations + +import importlib.util +import io +import sys +from pathlib import Path +from unittest import mock + +import pytest +import requests + +# --------------------------------------------------------------------------- +# Bootstrap: stub out heavy imports the module pulls in so tests run anywhere +# --------------------------------------------------------------------------- +import types as _types + +# PIL — used only at runtime for image ops, mock the whole package +for _m in ("pdfplumber", "PIL", "PIL.Image"): + if _m not in sys.modules: + sys.modules[_m] = mock.MagicMock() + +# deepdoc.parser.pdf_parser — provide a real base class so OpenDataLoaderParser +# inherits a proper Python class, not a MagicMock (which breaks __init__). +_pdf_parser_mod = _types.ModuleType("deepdoc.parser.pdf_parser") +class _RAGFlowPdfParserStub: # noqa: E302 + pass +_pdf_parser_mod.RAGFlowPdfParser = _RAGFlowPdfParserStub +sys.modules.setdefault("deepdoc.parser.pdf_parser", _pdf_parser_mod) +sys.modules.setdefault("deepdoc", mock.MagicMock()) +sys.modules.setdefault("deepdoc.parser", mock.MagicMock()) + +# deepdoc.parser.utils — extract_pdf_outlines must be a real callable +_utils_mod = _types.ModuleType("deepdoc.parser.utils") +_utils_mod.extract_pdf_outlines = mock.MagicMock(return_value=[]) +sys.modules.setdefault("deepdoc.parser.utils", _utils_mod) + +# Load the module under test +_REPO = Path(__file__).parents[4] +_spec = importlib.util.spec_from_file_location( + "opendataloader_parser", + _REPO / "deepdoc" / "parser" / "opendataloader_parser.py", +) +_mod = importlib.util.module_from_spec(_spec) +# Register before exec so @dataclass can resolve __module__ +sys.modules["opendataloader_parser"] = _mod +_spec.loader.exec_module(_mod) + +OpenDataLoaderParser = _mod.OpenDataLoaderParser + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_parser(api_url: str = "http://odl:9383") -> OpenDataLoaderParser: + p = OpenDataLoaderParser() + p.api_url = api_url + return p + + +def _fake_page_image(width: int = 600, height: int = 800): + img = mock.MagicMock() + img.size = (width, height) + img.crop = mock.MagicMock(return_value=img) + img.convert = mock.MagicMock(return_value=img) + return img + + +# --------------------------------------------------------------------------- +# check_installation() +# --------------------------------------------------------------------------- + +class TestCheckInstallation: + def test_no_api_url_returns_false(self): + p = OpenDataLoaderParser() + p.api_url = "" + assert p.check_installation() is False + + def test_health_200_returns_true(self): + p = _make_parser() + resp = mock.MagicMock(status_code=200) + with mock.patch("requests.get", return_value=resp): + assert p.check_installation() is True + + def test_health_503_returns_false(self): + p = _make_parser() + resp = mock.MagicMock(status_code=503, text="unavailable") + with mock.patch("requests.get", return_value=resp): + assert p.check_installation() is False + + def test_connection_error_returns_false(self): + p = _make_parser() + with mock.patch("requests.get", side_effect=requests.ConnectionError("refused")): + assert p.check_installation() is False + + +# --------------------------------------------------------------------------- +# parse_pdf() +# --------------------------------------------------------------------------- + +class TestParsePdf: + def _mock_response(self, json_doc=None, md_text=None) -> mock.MagicMock: + resp = mock.MagicMock() + resp.raise_for_status = mock.MagicMock() + resp.json.return_value = {"json_doc": json_doc, "md_text": md_text} + return resp + + def test_raises_when_api_url_not_set(self, tmp_path): + p = OpenDataLoaderParser() + p.api_url = "" + pdf = tmp_path / "doc.pdf" + pdf.write_bytes(b"%PDF-dummy") + with pytest.raises(RuntimeError, match="OPENDATALOADER_APISERVER"): + p.parse_pdf(filepath=str(pdf)) + + def test_posts_to_file_parse_endpoint(self, tmp_path): + p = _make_parser() + pdf = tmp_path / "doc.pdf" + pdf.write_bytes(b"%PDF-dummy") + resp = self._mock_response(md_text="hello world") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath=str(pdf)) + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + assert "/file_parse" in call_kwargs.kwargs.get("url", call_kwargs.args[0] if call_kwargs.args else "") + + def test_binary_bytes_sent_as_multipart(self, tmp_path): + p = _make_parser() + pdf_bytes = b"%PDF-binary" + resp = self._mock_response(md_text="section text") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath="file.pdf", binary=pdf_bytes) + + files_arg = mock_post.call_args.kwargs.get("files", {}) + assert "file" in files_arg + _, sent_bytes, mime = files_arg["file"] + assert sent_bytes == pdf_bytes + assert mime == "application/pdf" + + def test_bytesio_binary_sent_correctly(self, tmp_path): + p = _make_parser() + pdf_bytes = b"%PDF-bytesio" + resp = self._mock_response(md_text="text from bytesio") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath="file.pdf", binary=io.BytesIO(pdf_bytes)) + + files_arg = mock_post.call_args.kwargs.get("files", {}) + _, sent_bytes, _ = files_arg["file"] + assert sent_bytes == pdf_bytes + + def test_json_doc_response_returns_sections(self, tmp_path): + p = _make_parser() + json_doc = { + "type": "paragraph", + "content": "Hello from JSON", + "page_number": 1, + "bounding_box": [0, 0, 100, 20], + } + resp = self._mock_response(json_doc=json_doc) + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp): + sections, tables = p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", parse_method="pipeline") + + assert any("Hello from JSON" in s[0] for s in sections) + + def test_md_text_fallback_when_no_json(self, tmp_path): + p = _make_parser() + resp = self._mock_response(json_doc=None, md_text="# Markdown heading\n\nBody text.") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp): + sections, tables = p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", parse_method="pipeline") + + assert len(sections) > 0 + assert tables == [] + + def test_sanitize_true_sends_string_true(self): + p = _make_parser() + resp = self._mock_response(md_text="ok") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", sanitize=True) + + data_arg = mock_post.call_args.kwargs.get("data", {}) + assert data_arg.get("sanitize") == "true" + + def test_sanitize_false_sends_string_false(self): + p = _make_parser() + resp = self._mock_response(md_text="ok") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", sanitize=False) + + data_arg = mock_post.call_args.kwargs.get("data", {}) + assert data_arg.get("sanitize") == "false" + + def test_hybrid_and_image_output_forwarded(self): + p = _make_parser() + resp = self._mock_response(md_text="ok") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", + hybrid="docling-fast", image_output="embedded") + + data_arg = mock_post.call_args.kwargs.get("data", {}) + assert data_arg.get("hybrid") == "docling-fast" + assert data_arg.get("image_output") == "embedded" + + def test_optional_params_omitted_when_none(self): + p = _make_parser() + resp = self._mock_response(md_text="ok") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF") + + data_arg = mock_post.call_args.kwargs.get("data", {}) + assert "hybrid" not in data_arg + assert "image_output" not in data_arg + assert "sanitize" not in data_arg + + def test_callback_called_at_progress_points(self): + p = _make_parser() + resp = self._mock_response(md_text="text") + cb = mock.MagicMock() + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp): + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", callback=cb) + + progress_values = [call.args[0] for call in cb.call_args_list] + assert 0.1 in progress_values + assert 1.0 in progress_values + + def test_http_error_raises_runtime_error(self): + p = _make_parser() + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", side_effect=requests.ConnectionError("down")): + with pytest.raises(RuntimeError, match="service call failed"): + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF") + + def test_non_200_status_raises_runtime_error(self): + p = _make_parser() + resp = mock.MagicMock() + resp.raise_for_status.side_effect = requests.HTTPError("500 Server Error") + + with mock.patch.object(p, "__images__"), \ + mock.patch("requests.post", return_value=resp): + with pytest.raises(RuntimeError, match="service call failed"): + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF") + + +# --------------------------------------------------------------------------- +# crop() — bounds guard +# --------------------------------------------------------------------------- + +class TestCrop: + def test_returns_none_when_no_page_images(self): + p = _make_parser() + p.page_images = [] + result = p.crop("@@1\t10.0\t100.0\t20.0\t80.0##") + assert result is None + + def test_returns_none_when_no_position_tags(self): + p = _make_parser() + p.page_images = [_fake_page_image()] + result = p.crop("no tags here") + assert result is None + + def test_out_of_range_page_index_filtered_returns_none(self): + p = _make_parser() + # Only 1 page rendered (index 0), but tag references page 5 (index 4) + p.page_images = [_fake_page_image()] + # Tag: page 5 → extract_positions returns pn=[4] + tag = "@@5\t10.0\t100.0\t20.0\t80.0##" + result = p.crop(tag) + assert result is None + + def test_valid_page_index_does_not_raise(self): + p = _make_parser() + img = _fake_page_image(width=200, height=300) + p.page_images = [img, img, img] + # Tag references page 2 (index 1) — within rendered range. + # Patch Image.new and alpha_composite at the module level to avoid + # real ImagingCore requirements from mocked PIL images. + tag = "@@2\t10.0\t100.0\t20.0\t80.0##" + canvas = mock.MagicMock() + canvas.paste = mock.MagicMock() + try: + with mock.patch.object(_mod.Image, "new", return_value=canvas), \ + mock.patch.object(_mod.Image, "alpha_composite", return_value=img): + p.crop(tag) + except IndexError: + pytest.fail("crop() raised IndexError for a valid page index") + + def test_need_position_false_returns_image_or_none(self): + p = _make_parser() + p.page_images = [] + result = p.crop("@@1\t10.0\t100.0\t20.0\t80.0##", need_position=False) + assert result is None + + def test_need_position_true_returns_tuple_when_no_images(self): + p = _make_parser() + p.page_images = [] + result = p.crop("@@1\t10.0\t100.0\t20.0\t80.0##", need_position=True) + assert result == (None, None) diff --git a/web/src/components/layout-recognize-form-field.tsx b/web/src/components/layout-recognize-form-field.tsx index 7b6a077fb..8ab908917 100644 --- a/web/src/components/layout-recognize-form-field.tsx +++ b/web/src/components/layout-recognize-form-field.tsx @@ -20,6 +20,7 @@ export const enum ParseDocumentType { DeepDOC = 'DeepDOC', PlainText = 'Plain Text', Docling = 'Docling', + OpenDataLoader = 'OpenDataLoader', TCADPParser = 'TCADP Parser', } @@ -52,6 +53,7 @@ export function LayoutRecognizeFormField({ ParseDocumentType.DeepDOC, ParseDocumentType.PlainText, ParseDocumentType.Docling, + ParseDocumentType.OpenDataLoader, ParseDocumentType.TCADPParser, ].map((x) => ({ label: x === ParseDocumentType.PlainText ? t(camelCase(x)) : x, diff --git a/web/src/constants/llm.ts b/web/src/constants/llm.ts index 52c1a1d7d..17fcc0620 100644 --- a/web/src/constants/llm.ts +++ b/web/src/constants/llm.ts @@ -62,6 +62,7 @@ export enum LLMFactory { Builtin = 'Builtin', MinerU = 'MinerU', PaddleOCR = 'PaddleOCR', + OpenDataLoader = 'OpenDataLoader', N1n = 'n1n', Avian = 'Avian', RAGcon = 'RAGcon', diff --git a/web/src/pages/user-setting/setting-model/hooks.tsx b/web/src/pages/user-setting/setting-model/hooks.tsx index fe233e057..47cfaa37c 100644 --- a/web/src/pages/user-setting/setting-model/hooks.tsx +++ b/web/src/pages/user-setting/setting-model/hooks.tsx @@ -807,6 +807,56 @@ export const useSubmitPaddleOCR = () => { }; }; +export const useSubmitOpenDataLoader = () => { + const [saveLoading, setSaveLoading] = useState(false); + const { addLlm } = useAddLlm(); + const { + visible: opendataloaderVisible, + hideModal: hideOpenDataLoaderModal, + showModal: showOpenDataLoaderModal, + } = useSetModalState(); + + const onOpenDataLoaderOk = useCallback( + async (payload: any, isVerify = false) => { + if (!isVerify) { + setSaveLoading(true); + } + const req: IAddLlmRequestBody = { + llm_factory: LLMFactory.OpenDataLoader, + llm_name: payload.llm_name, + model_type: 'ocr', + api_key: { ...payload }, + api_base: '', + max_tokens: 0, + }; + const ret = await addLlm({ ...req, verify: isVerify }); + if (!isVerify) { + setSaveLoading(false); + if (ret.code === 0) { + hideOpenDataLoaderModal(); + return true; + } + } + if (isVerify) { + return { + isValid: !!ret.data?.success, + logs: ret.data?.message, + } as VerifyResult; + } + return false; + }, + [addLlm, hideOpenDataLoaderModal, setSaveLoading], + ); + + return { + opendataloaderVisible, + hideOpenDataLoaderModal, + showOpenDataLoaderModal, + onOpenDataLoaderOk, + opendataloaderLoading: saveLoading, + }; +}; + export const useVerifySettings = ({ onVerify, }: { diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index 0ca84b142..39f490feb 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -14,6 +14,7 @@ import { useSubmitGoogle, useSubmitMinerU, useSubmitOllama, + useSubmitOpenDataLoader, useSubmitPaddleOCR, useSubmitSpark, useSubmitSystemModelSetting, @@ -30,6 +31,7 @@ import GoogleModal from './modal/google-modal'; import MinerUModal from './modal/mineru-modal'; import TencentCloudModal from './modal/next-tencent-modal'; import OllamaModal from './modal/ollama-modal'; +import OpenDataLoaderModal from './modal/opendataloader-modal'; import PaddleOCRModal from './modal/paddleocr-modal'; import SparkModal from './modal/spark-modal'; import VolcEngineModal from './modal/volcengine-modal'; @@ -139,6 +141,14 @@ const ModelProviders = () => { paddleocrLoading, } = useSubmitPaddleOCR(); + const { + opendataloaderVisible, + hideOpenDataLoaderModal, + showOpenDataLoaderModal, + onOpenDataLoaderOk, + opendataloaderLoading, + } = useSubmitOpenDataLoader(); + const ModalMap = useMemo( () => ({ [LLMFactory.Bedrock]: showBedrockAddingModal, @@ -151,6 +161,7 @@ const ModelProviders = () => { [LLMFactory.AzureOpenAI]: showAzureAddingModal, [LLMFactory.MinerU]: showMineruModal, [LLMFactory.PaddleOCR]: showPaddleOCRModal, + [LLMFactory.OpenDataLoader]: showOpenDataLoaderModal, }), [ showBedrockAddingModal, @@ -163,6 +174,7 @@ const ModelProviders = () => { showAzureAddingModal, showMineruModal, showPaddleOCRModal, + showOpenDataLoaderModal, ], ); @@ -240,6 +252,9 @@ const ModelProviders = () => { if (paddleocrVisible) { return onPaddleOCROk; } + if (opendataloaderVisible) { + return onOpenDataLoaderOk; + } if (GoogleAddingVisible) { return onGoogleAddingOk; } @@ -269,6 +284,8 @@ const ModelProviders = () => { onMineruOk, paddleocrVisible, onPaddleOCROk, + opendataloaderVisible, + onOpenDataLoaderOk, ]); const { onApiKeyVerifying } = useVerifySettings({ @@ -391,6 +408,13 @@ const ModelProviders = () => { loading={paddleocrLoading} onVerify={onApiKeyVerifying} > + ); }; diff --git a/web/src/pages/user-setting/setting-model/modal/opendataloader-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/opendataloader-modal/index.tsx new file mode 100644 index 000000000..8d9421917 --- /dev/null +++ b/web/src/pages/user-setting/setting-model/modal/opendataloader-modal/index.tsx @@ -0,0 +1,137 @@ +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Button, ButtonLoading } from '@/components/ui/button'; +import { + Dialog, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog'; +import { Form } from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { LLMFactory } from '@/constants/llm'; +import { VerifyResult } from '@/pages/user-setting/setting-model/hooks'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { memo, useMemo } from 'react'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; +import { LLMHeader } from '../../components/llm-header'; +import VerifyButton from '../verify-button'; + +export type OpenDataLoaderFormValues = { + llm_name: string; + opendataloader_apiserver: string; + opendataloader_api_key?: string; +}; + +export interface IModalProps { + visible: boolean; + hideModal: () => void; + onOk?: (data: T) => Promise; + onVerify?: ( + postBody: any, + ) => Promise; + loading?: boolean; +} + +const OpenDataLoaderModal = ({ + visible, + hideModal, + onOk, + onVerify, + loading, +}: IModalProps) => { + const { t } = useTranslation(); + + const FormSchema = useMemo( + () => + z.object({ + llm_name: z.string().min(1, { + message: t('setting.modelNameMessage'), + }), + opendataloader_apiserver: z.string().min(1, { + message: t('setting.apiServerMessage'), + }), + opendataloader_api_key: z.string().optional(), + }), + [t], + ); + + const form = useForm({ + resolver: zodResolver(FormSchema), + defaultValues: { + opendataloader_apiserver: '', + opendataloader_api_key: '', + }, + }); + + const handleOk = async (values: OpenDataLoaderFormValues) => { + const ret = await onOk?.(values as any); + if (ret) { + hideModal?.(); + } + }; + + return ( + + + + + + + +
+ + + + + + + + + + + {onVerify && ( + Promise} + /> + )} + + + + + + {t('common.add')} + + +
+
+ ); +}; + +export default memo(OpenDataLoaderModal);