mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-19 11:45:10 +08:00)

Compare commits 59f4c51222...59075a0b58 (5 commits):

  59075a0b58
  30bd25716b
  99dae3c64c
  045314a1aa
  2b20d0b3bb
@@ -16,21 +16,23 @@
 import logging
 import os
 import sys
+import time
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
-from quart import Blueprint, Quart, request, g, current_app, session
+from quart import Blueprint, Quart, request, g, current_app, session, jsonify
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from quart_cors import cors
-from common.constants import StatusEnum
+from common.constants import StatusEnum, RetCode
 from api.db.db_models import close_connection, APIToken
 from api.db.services import UserService
 from api.utils.json_encode import CustomJSONEncoder
 from api.utils import commands

-from quart_auth import Unauthorized
+from quart_auth import Unauthorized as QuartAuthUnauthorized
+from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
 from quart_schema import QuartSchema
 from common import settings
-from api.utils.api_utils import server_error_response
+from api.utils.api_utils import server_error_response, get_json_result
 from api.constants import API_VERSION
 from common.misc_utils import get_uuid

@@ -38,6 +40,22 @@ settings.init_settings()

 __all__ = ["app"]

+UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
+
+
+def _unauthorized_message(error):
+    if error is None:
+        return UNAUTHORIZED_MESSAGE
+    try:
+        msg = repr(error)
+    except Exception:
+        return UNAUTHORIZED_MESSAGE
+    if msg == UNAUTHORIZED_MESSAGE:
+        return msg
+    if "Unauthorized" in msg and "401" in msg:
+        return msg
+    return UNAUTHORIZED_MESSAGE
+
 app = Quart(__name__)
 app = cors(app, allow_origin="*")

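The normalization above is easy to check in isolation. A standalone sketch follows; FakeUnauthorized is a made-up stand-in for quart_auth's exception, and its repr may differ from the real one, which is exactly the case the fallback covers:

    # Standalone sketch of the normalization in _unauthorized_message above.
    UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"

    def _unauthorized_message(error):
        if error is None:
            return UNAUTHORIZED_MESSAGE
        try:
            msg = repr(error)
        except Exception:
            return UNAUTHORIZED_MESSAGE
        if msg == UNAUTHORIZED_MESSAGE:
            return msg
        if "Unauthorized" in msg and "401" in msg:
            return msg
        return UNAUTHORIZED_MESSAGE

    class FakeUnauthorized(Exception):
        def __repr__(self):
            return "Unauthorized('401: token expired')"

    assert _unauthorized_message(None) == UNAUTHORIZED_MESSAGE
    assert _unauthorized_message(FakeUnauthorized()) == "Unauthorized('401: token expired')"
    assert _unauthorized_message(ValueError("boom")) == UNAUTHORIZED_MESSAGE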
@@ -145,10 +163,18 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]

     @wraps(func)
     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
-        if not current_user:  # or not session.get("_user_id"):
-            raise Unauthorized()
-        else:
-            return await current_app.ensure_async(func)(*args, **kwargs)
+        timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+        t_start = time.perf_counter() if timing_enabled else None
+        user = current_user
+        if timing_enabled:
+            logging.info(
+                "api_timing login_required auth_ms=%.2f path=%s",
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        if not user:  # or not session.get("_user_id"):
+            raise QuartAuthUnauthorized()
+        return await current_app.ensure_async(func)(*args, **kwargs)

     return wrapper

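The timing hook is gated on an environment variable and costs a single os.getenv when disabled. A small sketch of how it would be switched on; the variable name is taken from the diff, and the log sink depends on the process's logging setup:

    import os

    # Opt in to per-request auth timing before the app starts (sketch).
    # Any non-empty value enables it, since the decorator only checks truthiness.
    os.environ["RAGFLOW_API_TIMING"] = "1"

    # Expected log shape, per the format string in login_required
    # (path here is illustrative):
    #   api_timing login_required auth_ms=0.12 path=/v1/user/info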
@@ -258,12 +284,33 @@ client_urls_prefix = [


 @app.errorhandler(404)
 async def not_found(error):
-    error_msg: str = f"The requested URL {request.path} was not found"
-    logging.error(error_msg)
-    return {
-        "error": "Not Found",
-        "message": error_msg,
-    }, 404
+    logging.error(f"The requested URL {request.path} was not found")
+    message = f"Not Found: {request.path}"
+    response = {
+        "code": RetCode.NOT_FOUND,
+        "message": message,
+        "data": None,
+    }
+    return jsonify(response), RetCode.NOT_FOUND
+
+
+@app.errorhandler(401)
+async def unauthorized(error):
+    logging.warning("Unauthorized request")
+    return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
+
+
+@app.errorhandler(QuartAuthUnauthorized)
+async def unauthorized_quart_auth(error):
+    logging.warning("Unauthorized request (quart_auth)")
+    return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
+
+
+@app.errorhandler(WerkzeugUnauthorized)
+async def unauthorized_werkzeug(error):
+    logging.warning("Unauthorized request (werkzeug)")
+    return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
+

 @app.teardown_request
 def _db_close(exception):
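A minimal sketch of the envelope the new 404 handler produces, reproduced with a bare Quart app and its test client. It assumes RetCode.NOT_FOUND is the integer 404, so the same value can double as the HTTP status code:

    # Sketch: exercising a JSON 404 envelope with Quart's test client.
    import asyncio
    from quart import Quart, jsonify, request

    app = Quart(__name__)
    NOT_FOUND = 404  # stand-in for RetCode.NOT_FOUND

    @app.errorhandler(404)
    async def not_found(error):
        return jsonify({"code": NOT_FOUND,
                        "message": f"Not Found: {request.path}",
                        "data": None}), NOT_FOUND

    async def main():
        client = app.test_client()
        resp = await client.get("/missing")
        body = await resp.get_json()
        assert resp.status_code == 404
        assert body["message"] == "Not Found: /missing"

    asyncio.run(main())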
@@ -126,10 +126,15 @@ def get():
 @validate_request("doc_id", "chunk_id", "content_with_weight")
 async def set():
     req = await get_request_json()
+    content_with_weight = req["content_with_weight"]
+    if not isinstance(content_with_weight, (str, bytes)):
+        raise TypeError("expected string or bytes-like object")
+    if isinstance(content_with_weight, bytes):
+        content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
     d = {
         "id": req["chunk_id"],
-        "content_with_weight": req["content_with_weight"]}
-    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
+        "content_with_weight": content_with_weight}
+    d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     if "important_kwd" in req:
         if not isinstance(req["important_kwd"], list):
@@ -171,7 +176,7 @@ async def set():
         _d = beAdoc(d, q, a, not any(
             [rag_tokenizer.is_chinese(t) for t in q + a]))

-    v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
+    v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
     v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
     _d["q_%d_vec" % len(v)] = v.tolist()
     settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
@@ -223,14 +228,27 @@ async def rm():
     e, doc = DocumentService.get_by_id(req["doc_id"])
     if not e:
         return get_data_error_result(message="Document not found!")
     # Include doc_id in condition to properly scope the delete
     condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
-    if not settings.docStoreConn.delete(condition,
-                                        search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                        doc.kb_id):
+    try:
+        deleted_count = settings.docStoreConn.delete(condition,
+                                                     search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                     doc.kb_id)
+    except Exception:
         return get_data_error_result(message="Chunk deleting failure")
+    deleted_chunk_ids = req["chunk_ids"]
+    chunk_number = len(deleted_chunk_ids)
+    if isinstance(deleted_chunk_ids, list):
+        unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
+        has_ids = len(unique_chunk_ids) > 0
+    else:
+        unique_chunk_ids = [deleted_chunk_ids]
+        has_ids = deleted_chunk_ids not in (None, "")
+    if has_ids and deleted_count == 0:
+        return get_data_error_result(message="Index updating failure")
+    if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
+        deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
+                                                      search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                      doc.kb_id)
+        chunk_number = deleted_count
     DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
     for cid in deleted_chunk_ids:
         if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
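The bookkeeping in rm() is easier to follow in isolation. A simplified sketch of the counting rules with the doc store stubbed out; the real code additionally retries with a doc-wide delete on partial failure:

    def reconcile(chunk_ids, deleted_count):
        """Sketch of the has_ids / deleted_count checks in rm() above."""
        if isinstance(chunk_ids, list):
            unique_ids = list(dict.fromkeys(chunk_ids))  # dedupe, keep order
            has_ids = len(unique_ids) > 0
        else:
            unique_ids = [chunk_ids]
            has_ids = chunk_ids not in (None, "")
        if has_ids and deleted_count == 0:
            return "Index updating failure"
        if 0 < deleted_count < len(unique_ids):
            return "partial: fall back to doc-wide delete"
        return "ok"

    assert reconcile(["c1", "c2", "c1"], 2) == "ok"        # duplicates collapse
    assert reconcile(["c1"], 0) == "Index updating failure"
    assert reconcile(["c1", "c2"], 1).startswith("partial")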
@@ -42,13 +42,18 @@ async def set_dialog():
     if len(name.encode("utf-8")) > 255:
         return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")

-    if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
-        name = name.strip()
-        name = duplicate_name(
-            DialogService.query,
-            name=name,
-            tenant_id=current_user.id,
-            status=StatusEnum.VALID.value)
+    name = name.strip()
+    if is_create:
+        existing_names = {
+            d.name.casefold()
+            for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
+            if d.name
+        }
+        if name.casefold() in existing_names:
+            def _name_exists(name: str, **_kwargs) -> bool:
+                return name.casefold() in existing_names
+
+            name = duplicate_name(_name_exists, name=name)

     description = req.get("description", "A helpful dialog")
     icon = req.get("icon", "")
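The dialog path now feeds duplicate_name a local predicate instead of a DB query. The diff relies on duplicate_name probing candidate names until the callable reports no collision; a minimal stand-in illustrating that contract (this re-implements the helper for the sketch, it is not RAGFlow's actual code):

    # Hypothetical stand-in for the duplicate_name contract the diff relies on.
    def duplicate_name(exists, name, **kwargs):
        candidate, i = name, 0
        while exists(name=candidate, **kwargs):
            i += 1
            candidate = f"{name}({i})"
        return candidate

    existing = {"sales bot"}

    def _name_exists(name, **_kwargs):
        return name.casefold() in existing

    assert duplicate_name(_name_exists, name="Sales Bot") == "Sales Bot(1)"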
@@ -63,16 +68,15 @@ async def set_dialog():
     meta_data_filter = req.get("meta_data_filter", {})
     prompt_config = req["prompt_config"]

-    if not is_create:
-        if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
-            return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
+    if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
+        return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")

-    for p in prompt_config["parameters"]:
-        if p["optional"]:
-            continue
-        if prompt_config["system"].find("{%s}" % p["key"]) < 0:
-            return get_data_error_result(
-                message="Parameter '{}' is not used".format(p["key"]))
+    for p in prompt_config.get("parameters", []):
+        if p["optional"]:
+            continue
+        if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
+            return get_data_error_result(
+                message="Parameter '{}' is not used".format(p["key"]))

     try:
         e, tenant = TenantService.get_by_id(current_user.id)
@@ -62,10 +62,21 @@ async def upload():
         return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)

     file_objs = files.getlist("file")
+    def _close_file_objs(objs):
+        for obj in objs:
+            try:
+                obj.close()
+            except Exception:
+                try:
+                    obj.stream.close()
+                except Exception:
+                    pass
     for file_obj in file_objs:
         if file_obj.filename == "":
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
         if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)

     e, kb = KnowledgebaseService.get_by_id(kb_id)
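_close_file_objs gives every early return a cleanup path. The same guarantee can also be written with contextlib.ExitStack, shown here as a design-note sketch rather than what the diff does:

    # Sketch: the same "close everything on early exit" guarantee via ExitStack.
    from contextlib import ExitStack

    def validate_uploads(file_objs):
        with ExitStack() as stack:
            for obj in file_objs:
                stack.callback(obj.close)  # runs on any exit path
            for obj in file_objs:
                if obj.filename == "":
                    return "No file selected!"  # files still closed by the stack
            stack.pop_all()  # validation passed; keep files open for the upload
        return None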
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 import logging
+import os
+import time

 from quart import request
 from api.apps import login_required, current_user
@@ -35,22 +37,56 @@ from common.constants import MemoryType, RetCode, ForgettingPolicy
 @login_required
 @validate_request("name", "memory_type", "embd_id", "llm_id")
 async def create_memory():
+    timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+    t_start = time.perf_counter() if timing_enabled else None
     req = await get_request_json()
+    t_parsed = time.perf_counter() if timing_enabled else None
     # check name length
     name = req["name"]
     memory_name = name.strip()
     if len(memory_name) == 0:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result("Memory name cannot be empty or whitespace.")
     if len(memory_name) > MEMORY_NAME_LIMIT:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
     # check memory_type valid
     if not isinstance(req["memory_type"], list):
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result("Memory type must be a list.")
     memory_type = set(req["memory_type"])
     invalid_type = memory_type - {e.name.lower() for e in MemoryType}
     if invalid_type:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
     memory_type = list(memory_type)

     try:
+        t_before_db = time.perf_counter() if timing_enabled else None
         res, memory = MemoryService.create_memory(
             tenant_id=current_user.id,
             name=memory_name,
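The four inline timing blocks in create_memory share one format string apart from the tag. A hypothetical helper that would collapse them; the name and signature are illustrative, not from the codebase:

    # Hypothetical refactor sketch: one helper instead of four inline blocks.
    import logging
    import os
    import time

    def log_api_timing(tag, t_start, t_parsed, path):
        if not os.getenv("RAGFLOW_API_TIMING"):
            return
        now = time.perf_counter()
        logging.info(
            "api_timing create_memory %s parse_ms=%.2f total_ms=%.2f path=%s",
            tag, (t_parsed - t_start) * 1000, (now - t_start) * 1000, path,
        )

    # usage inside the handler would become:
    #   log_api_timing("invalid_name", t_start, t_parsed, request.path)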
@@ -58,6 +94,15 @@ async def create_memory():
             embd_id=req["embd_id"],
             llm_id=req["llm_id"]
         )
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (t_before_db - t_parsed) * 1000,
+                (time.perf_counter() - t_before_db) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )

     if res:
         return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@@ -445,6 +445,7 @@ class DocumentService(CommonService):
             .where(
                 cls.model.status == StatusEnum.VALID.value,
                 ~(cls.model.type == FileType.VIRTUAL.value),
+                ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
                 (((cls.model.progress < 1) & (cls.model.progress > 0)) |
                  (cls.model.id.in_(unfinished_task_query))))  # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
         return list(docs.dicts())
@@ -936,6 +937,8 @@ class DocumentService(CommonService):
                 bad = 0
                 e, doc = DocumentService.get_by_id(d["id"])
                 status = doc.run  # TaskStatus.RUNNING.value
+                if status == TaskStatus.CANCEL.value:
+                    continue
                 doc_progress = doc.progress if doc and doc.progress else 0.0
                 special_task_running = False
                 priority = 0
@@ -979,7 +982,16 @@ class DocumentService(CommonService):
                     info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
                 else:
                     info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
-                cls.update_by_id(d["id"], info)
+                info["update_time"] = current_timestamp()
+                info["update_date"] = get_format_time()
+                (
+                    cls.model.update(info)
+                    .where(
+                        (cls.model.id == d["id"])
+                        & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
+                    )
+                    .execute()
+                )
             except Exception as e:
                 if str(e).find("'0'") < 0:
                     logging.exception("fetch task exception")
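Replacing update_by_id with an explicit UPDATE ... WHERE makes the write conditional, so a document cancelled between the read and the write is not resurrected. The generic peewee shape, as a runnable sketch with a hypothetical model:

    # Sketch: compare-and-set style update in peewee, with a hypothetical model.
    # Only rows still in a non-cancelled state accept the progress write.
    from peewee import CharField, FloatField, Model, SqliteDatabase

    db = SqliteDatabase(":memory:")

    class Document(Model):
        run = CharField(null=True)
        progress = FloatField(default=0.0)

        class Meta:
            database = db

    db.create_tables([Document])
    doc = Document.create(run=None, progress=0.1)

    CANCEL = "2"  # stand-in for TaskStatus.CANCEL.value
    n = (Document
         .update({Document.progress: 0.5})
         .where((Document.id == doc.id)
                & ((Document.run.is_null(True)) | (Document.run != CANCEL)))
         .execute())
    assert n == 1  # 0 would mean the doc was cancelled in the meantime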
@@ -1012,7 +1024,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
-        # cancelled: run == "2" but progress can vary
+        # cancelled: run == "2"
         cancelled = (
             cls.model.select(fn.COUNT(1))
             .where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
@@ -397,7 +397,7 @@ class KnowledgebaseService(CommonService):
         if dataset_name == "":
             return False, get_data_error_result(message="Dataset name can't be empty.")
         if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
-            return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
+            return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}")

         # Deduplicate name within tenant
         dataset_name = duplicate_name(
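Note that the limit above is enforced on UTF-8 bytes, not characters, which matters for non-ASCII names. A one-line check:

    # Sketch: the limit compares encoded byte length, not character count.
    name = "知识库" * 50
    assert len(name) == 150                  # characters
    assert len(name.encode("utf-8")) == 450  # bytes, what the check uses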
@@ -31,6 +31,12 @@ from quart import (
     jsonify,
     request
 )
+from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
+
+try:
+    from quart.exceptions import BadRequest as QuartBadRequest
+except ImportError:  # pragma: no cover - optional dependency
+    QuartBadRequest = None

 from peewee import OperationalError

@@ -48,35 +54,33 @@ requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)

 async def _coerce_request_data() -> dict:
     """Fetch JSON body with sane defaults; fallback to form data."""
+    if hasattr(request, "_cached_payload"):
+        return request._cached_payload
     payload: Any = None
-    last_error: Exception | None = None

-    try:
-        payload = await request.get_json(force=True, silent=True)
-    except Exception as e:
-        last_error = e
-        payload = None
+    body_bytes = await request.get_data()
+    has_body = bool(body_bytes)
+    content_type = (request.content_type or "").lower()
+    is_json = content_type.startswith("application/json")

-    if payload is None:
-        try:
-            form = await request.form
-            payload = form.to_dict()
-        except Exception as e:
-            last_error = e
-            payload = None
+    if not has_body:
+        payload = {}
+    elif is_json:
+        payload = await request.get_json(force=False, silent=False)
+        if isinstance(payload, dict):
+            payload = payload or {}
+        elif isinstance(payload, str):
+            raise AttributeError("'str' object has no attribute 'get'")
+        else:
+            raise TypeError("JSON payload must be an object.")
+    else:
+        form = await request.form
+        payload = form.to_dict() if form else None
+        if payload is None:
+            raise TypeError("Request body is not a valid form payload.")

-    if payload is None:
-        if last_error is not None:
-            raise last_error
-        raise ValueError("No JSON body or form data found in request.")
-
-    if isinstance(payload, dict):
-        return payload or {}
-
-    if isinstance(payload, str):
-        raise AttributeError("'str' object has no attribute 'get'")
-
-    raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
+    request._cached_payload = payload
+    return payload

 async def get_request_json():
     return await _coerce_request_data()
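Two things changed above: parsing now branches on the Content-Type instead of chained try/except, and the result is memoized on the request object so validators and handlers share one parse. The memoization trick in isolation, as a runnable sketch with a fake request:

    import asyncio

    class FakeRequest:
        """Stand-in for quart.request; any attribute can be set on it."""

    calls = 0

    async def parse_body():
        global calls
        calls += 1
        return {"doc_id": "d1"}

    async def coerce(req):
        # Same shape as _coerce_request_data's cache check above.
        if hasattr(req, "_cached_payload"):
            return req._cached_payload
        req._cached_payload = await parse_body()
        return req._cached_payload

    async def main():
        req = FakeRequest()
        await coerce(req)
        await coerce(req)
        assert calls == 1  # second call is served from the cache

    asyncio.run(main())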
@@ -124,16 +128,12 @@ def server_error_response(e):
     try:
         msg = repr(e).lower()
         if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
-            return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
+            resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized")
+            resp.status_code = RetCode.UNAUTHORIZED
+            return resp
     except Exception as ex:
         logging.warning(f"error checking authorization: {ex}")

-    if len(e.args) > 1:
-        try:
-            serialized_data = serialize_for_json(e.args[1])
-            return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
-        except Exception:
-            return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
     if repr(e).find("index_not_found_exception") >= 0:
         return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
@@ -168,7 +168,17 @@ def validate_request(*args, **kwargs):
     def wrapper(func):
         @wraps(func)
         async def decorated_function(*_args, **_kwargs):
-            errs = process_args(await _coerce_request_data())
+            exception_types = (AttributeError, TypeError, WerkzeugBadRequest)
+            if QuartBadRequest is not None:
+                exception_types = exception_types + (QuartBadRequest,)
+            if args or kwargs:
+                try:
+                    input_arguments = await _coerce_request_data()
+                except exception_types:
+                    input_arguments = {}
+            else:
+                input_arguments = await _coerce_request_data()
+            errs = process_args(input_arguments)
             if errs:
                 return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
             if inspect.iscoroutinefunction(func):
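With the wider exception tuple, a route that declares required fields now reports them as missing instead of bubbling a parse error into a 500, while routes without declared fields still surface the original error. A sketch of that policy, with synchronous stand-ins for the async pieces:

    def resolve_payload(required_fields, parse):
        """Sketch of the branch added to validate_request above."""
        if required_fields:
            try:
                return parse()
            except (AttributeError, TypeError):  # parse failures
                return {}  # validation will now flag the missing fields
        return parse()  # no declared fields: let the error propagate

    def bad_parse():
        raise TypeError("JSON payload must be an object.")

    assert resolve_payload(("doc_id",), bad_parse) == {}
    try:
        resolve_payload((), bad_parse)
        assert False, "should have raised"
    except TypeError:
        pass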
@@ -476,7 +476,7 @@ class RAGFlowPdfParser:
         self.boxes = bxs

     def _naive_vertical_merge(self, zoomin=3):
-        #bxs = self._assign_column(self.boxes, zoomin)
+        # bxs = self._assign_column(self.boxes, zoomin)
         bxs = self.boxes

         grouped = defaultdict(list)
@@ -553,7 +553,8 @@ class RAGFlowPdfParser:

         merged_boxes.extend(bxs)

-        #self.boxes = sorted(merged_boxes, key=lambda x: (x["page_number"], x.get("col_id", 0), x["top"]))
+        # self.boxes = sorted(merged_boxes, key=lambda x: (x["page_number"], x.get("col_id", 0), x["top"]))
         self.boxes = merged_boxes

     def _final_reading_order_merge(self, zoomin=3):
         if not self.boxes:
@@ -60,6 +60,12 @@ class Chat(Base):
         super().__init__(rag, res_dict)

     def update(self, update_message: dict):
+        if not isinstance(update_message, dict):
+            raise Exception("ValueError('`update_message` must be a dict')")
+        if update_message.get("llm") == {}:
+            raise Exception("ValueError('`llm` cannot be empty')")
+        if update_message.get("prompt") == {}:
+            raise Exception("ValueError('`prompt` cannot be empty')")
         res = self.put(f"/chats/{self.id}", update_message)
         res = res.json()
         if res.get("code") != 0:
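From the SDK side, the new guards fail fast on obviously empty payloads before any HTTP round trip. A usage sketch, assuming a reachable server, a valid key, and at least one existing chat assistant:

    from ragflow_sdk import RAGFlow

    rag = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")
    chat = rag.list_chats()[0]

    chat.update({"name": "renamed assistant"})  # plain field update still works
    try:
        chat.update({"llm": {}})  # rejected locally by the new empty-dict checks
    except Exception as exc:
        print(exc)  # ValueError('`llm` cannot be empty')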
@@ -318,6 +318,8 @@ class RAGFlow:
         for data in res["data"]["memory_list"]:
             result_list.append(Memory(self, data))
         return {
+            "code": res.get("code", 0),
+            "message": res.get("message"),
             "memory_list": result_list,
             "total_count": res["data"]["total_count"]
         }
@@ -99,7 +99,7 @@ def batch_create_datasets(auth, num):


 # DOCUMENT APP
-def upload_documents(auth, payload=None, files_path=None):
+def upload_documents(auth, payload=None, files_path=None, *, filename_override=None):
     url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"

     if files_path is None:
@@ -115,7 +115,8 @@ def upload_documents(auth, payload=None, files_path=None):
     for fp in files_path:
         p = Path(fp)
         f = p.open("rb")
-        fields.append(("file", (p.name, f)))
+        filename = filename_override if filename_override is not None else p.name
+        fields.append(("file", (filename, f)))
         file_objects.append(f)
     m = MultipartEncoder(fields=fields)

@@ -14,7 +14,8 @@
 # limitations under the License.
 #
 from time import sleep

+from ragflow_sdk import RAGFlow
+from configs import HOST_ADDRESS, VERSION
 import pytest
 from common import (
     batch_add_chunks,
@@ -81,7 +82,9 @@ def generate_test_files(request: FixtureRequest, tmp_path):
 def ragflow_tmp_dir(request, tmp_path_factory):
     class_name = request.cls.__name__
     return tmp_path_factory.mktemp(class_name)

+@pytest.fixture(scope="session")
+def client(token: str) -> RAGFlow:
+    return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)
+
 @pytest.fixture(scope="session")
 def WebApiAuth(auth):
@@ -265,11 +265,11 @@ class TestChunksRetrieval:
     @pytest.mark.parametrize(
         "payload, expected_code, expected_highlight, expected_message",
         [
-            ({"highlight": True}, 0, True, ""),
-            ({"highlight": "True"}, 0, True, ""),
-            pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
-            pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
-            pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
+            pytest.param({"highlight": True}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
+            pytest.param({"highlight": "True"}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
+            ({"highlight": False}, 0, False, ""),
+            ({"highlight": "False"}, 0, False, ""),
+            ({"highlight": None}, 0, False, "")
         ],
     )
     def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):
@@ -17,11 +17,9 @@ import string
 from concurrent.futures import ThreadPoolExecutor, as_completed

 import pytest
-import requests
-from common import DOCUMENT_APP_URL, list_kbs, upload_documents
-from configs import DOCUMENT_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN
+from common import list_kbs, upload_documents
+from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
 from libs.auth import RAGFlowWebApiAuth
-from requests_toolbelt import MultipartEncoder
 from utils.file_utils import create_txt_file

@@ -111,17 +109,9 @@ class TestDocumentsUpload:
         kb_id = add_dataset_func

         fp = create_txt_file(tmp_path / "ragflow_test.txt")
-        url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"
-        fields = [("file", ("", fp.open("rb"))), ("kb_id", kb_id)]
-        m = MultipartEncoder(fields=fields)
-        res = requests.post(
-            url=url,
-            headers={"Content-Type": m.content_type},
-            auth=WebApiAuth,
-            data=m,
-        )
-        assert res.json()["code"] == 101, res
-        assert res.json()["message"] == "No file selected!", res
+        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp], filename_override="")
+        assert res["code"] == 101, res
+        assert res["message"] == "No file selected!", res

     @pytest.mark.p2
     def test_filename_exceeds_max_length(self, WebApiAuth, add_dataset_func, tmp_path):
@@ -53,10 +53,7 @@ module.exports = {
       ],
     },
   ],
-  'react-refresh/only-export-components': [
-    'warn',
-    { allowConstantExport: true },
-  ],
+  'react-refresh/only-export-components': 'off',
   'no-console': ['warn', { allow: ['warn', 'error'] }],
   'check-file/filename-naming-convention': [
     'error',
@@ -1016,10 +1016,10 @@ export const initialPDFGeneratorValues = {
   watermark_text: '',
   enable_toc: false,
   outputs: {
-    file_path: { type: 'string', value: '' },
-    pdf_base64: { type: 'string', value: '' },
-    download: { type: 'string', value: '' },
-    success: { type: 'boolean', value: false },
+    file_path: { type: 'string' },
+    pdf_base64: { type: 'string' },
+    download: { type: 'string' },
+    success: { type: 'boolean' },
   },
 };
@@ -1075,3 +1075,13 @@ export enum WebhookStatus {
   Live = 'live',
   Stopped = 'stopped',
 }
+
+// Map BeginQueryType to TypesWithArray
+export const BeginQueryTypeMap = {
+  [BeginQueryType.Line]: TypesWithArray.String,
+  [BeginQueryType.Paragraph]: TypesWithArray.String,
+  [BeginQueryType.Options]: TypesWithArray.ArrayString,
+  [BeginQueryType.File]: 'File',
+  [BeginQueryType.Integer]: TypesWithArray.Number,
+  [BeginQueryType.Boolean]: TypesWithArray.Boolean,
+};
@@ -96,7 +96,7 @@ function ParameterForm({
       },
       [],
     );
-  }, []);
+  }, [t]);

   const type = useWatch({
     control: form.control,
@@ -14,7 +14,10 @@ type OutputProps = {
   isFormRequired?: boolean;
 } & PropsWithChildren;

-export function transferOutputs(outputs: Record<string, any>) {
+export function transferOutputs(outputs: Record<string, any> | undefined) {
+  if (!outputs) {
+    return [];
+  }
   return Object.entries(outputs).map(([key, value]) => ({
     title: key,
     type: value?.type,
@@ -35,7 +38,7 @@ export function Output({
       <div className="text-sm flex items-center justify-between">
         {t('flow.output')} <span>{children}</span>
       </div>
-      <ul>
+      <ul className="space-y-1">
        {list.map((x, idx) => (
          <li
            key={idx}
@@ -64,13 +64,12 @@ function PDFGeneratorForm({ node }: INextOperatorForm) {
    add_timestamp: z.boolean(),
    watermark_text: z.string().optional(),
    enable_toc: z.boolean(),
-    outputs: z
-      .object({
-        file_path: z.object({ type: z.string() }),
-        pdf_base64: z.object({ type: z.string() }),
-        success: z.object({ type: z.string() }),
-      })
-      .optional(),
+    outputs: z.object({
+      file_path: z.object({ type: z.string() }),
+      pdf_base64: z.object({ type: z.string() }),
+      download: z.object({ type: z.string() }),
+      success: z.object({ type: z.string() }),
+    }),
  });

  const form = useForm<z.infer<typeof FormSchema>>({
@@ -78,9 +77,11 @@ function PDFGeneratorForm({ node }: INextOperatorForm) {
     resolver: zodResolver(FormSchema),
   });

+  const formOutputs = form.watch('outputs');
+
   const outputList = useMemo(() => {
-    return transferOutputs(values.outputs);
-  }, [values.outputs]);
+    return transferOutputs(formOutputs ?? values.outputs);
+  }, [formOutputs, values.outputs]);

   useWatchFormChange(node?.id, form);

@@ -48,6 +48,7 @@ import {
   initialVariableAssignerValues,
   initialWaitingDialogueValues,
   initialWenCaiValues,
+  initialPDFGeneratorValues,
   initialWikipediaValues,
   initialYahooFinanceValues,
 } from '../constant';
@@ -179,7 +180,7 @@ export const useInitializeOperatorParams = () => {
       [Operator.Loop]: initialLoopValues,
       [Operator.LoopStart]: {},
       [Operator.ExitLoop]: {},
-      [Operator.PDFGenerator]: {},
+      [Operator.PDFGenerator]: initialPDFGeneratorValues,
       [Operator.ExcelProcessor]: {},
     };
   }, [llmId]);
@@ -18,6 +18,7 @@ import {
   AgentVariableType,
   BeginId,
   BeginQueryType,
+  BeginQueryTypeMap,
   JsonSchemaDataType,
   Operator,
   VariableType,
@@ -463,7 +464,14 @@ export function useGetVariableLabelOrTypeByValue({

   const getType = useCallback(
     (val?: string) => {
-      return getItem(val)?.type || findAgentStructuredOutputTypeByValue(val);
+      const currentType =
+        getItem(val)?.type || findAgentStructuredOutputTypeByValue(val);
+
+      if (currentType && currentType in BeginQueryTypeMap) {
+        return BeginQueryTypeMap[currentType as BeginQueryType];
+      }
+
+      return currentType;
     },
     [findAgentStructuredOutputTypeByValue, getItem],
   );