Files
ragflow/rag/llm/rerank_model.py
07heco 8dc5b1b42d fix: optimize reranking module robustness and bug fixes (#14264)
## Description
This PR fixes critical bugs and improves the robustness of the RAG
reranking module while maintaining **100% backward compatibility** with
all existing functionality and providers.

## Key Changes
2. **Network Stability**: Added a 30-second timeout to all API requests so
that an unresponsive provider cannot block the service
2. **Boundary Protection**: Added empty query/text validation for all
rerank models
3. **Response Fault Tolerance**: Replaced hardcoded key access with
`.get()` to avoid KeyError crashes
4. **Bug Fixes**:
   - Fixed `Ai302Rerank` (completely non-functional before)
   - Fixed `GPUStackRerank` incorrect exception catching
   - Fixed `_normalize_rank` empty array crash
5. **Code Standards**: Added type annotations and standardized the error
messages raised by unimplemented classes

## Compatibility
- No changes to any class or method names
- All rerank providers (Jina, Cohere, NVIDIA, HuggingFace, etc.) work as
before
- No breaking changes and no impact on existing workflows

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-05-14 11:56:09 +08:00

612 lines
22 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
from abc import ABC
from urllib.parse import urljoin
from typing import Tuple, List
from http import HTTPStatus
import numpy as np
import requests
from yarl import URL
from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC):
    """Common interface for rerank model clients.

    Subclasses implement :meth:`similarity`, which scores every candidate text
    against a query and reports the number of tokens the call consumed.
    """

    def __init__(self, key, model_name, **kwargs):
        """Accept provider credentials and a model name; the base class stores nothing."""
        pass

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Returns:
            A ``(scores, token_count)`` tuple where ``scores[i]`` is the
            relevance score of ``texts[i]``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("Please implement similarity method!")

    @staticmethod
    def _normalize_rank(rank: np.ndarray) -> np.ndarray:
        """Min-max scale ``rank`` into ``[0, 1]``.

        An empty array is returned unchanged. When the minimum and maximum are
        equal according to ``np.isclose(..., atol=1e-3)``, the scores carry no
        usable ordering. In that case an all-zero array of the same shape and
        dtype is returned, which also avoids dividing by (nearly) zero.
        """
        if rank.size == 0:
            return rank
        min_rank = np.min(rank)
        max_rank = np.max(rank)
        if np.isclose(min_rank, max_rank, atol=1e-3):
            return np.zeros_like(rank)
        return (rank - min_rank) / (max_rank - min_rank)
class JinaRerank(Base):
    """Rerank client for the Jina AI HTTP rerank endpoint.

    Other providers with a Jina-compatible rerank API subclass this class and
    only supply a different default ``base_url``.
    """

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        # An empty or None base_url falls back to the public Jina endpoint.
        if not base_url:
            base_url = "https://api.jina.ai/v1/rerank"
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Each text is passed through ``truncate(t, 8196)`` before it is sent.
        Scores are placed at the ``index`` reported by the API. Malformed result
        entries are reported through the project helper ``log_exception``.

        Returns:
            ``(scores, token_count)``. The token count comes from
            ``total_token_count_from_response`` applied to the JSON body. An
            empty query or empty ``texts`` returns all-zero scores and 0 tokens
            without making a request.
        """
        if not (query and texts):
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        documents = [truncate(doc, 8196) for doc in texts]
        payload = {"model": self.model_name, "query": query, "documents": documents, "top_n": len(documents)}
        response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
        response.raise_for_status()
        body = response.json()
        scores = np.zeros(len(documents), dtype=float)
        try:
            for item in body.get("results", []):
                scores[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, body)
        return scores, total_token_count_from_response(body)
class XInferenceRerank(Base):
    """Rerank client for an Xinference server's ``/v1/rerank`` endpoint."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        # urljoin with an absolute path ("/v1/rerank") replaces the whole path
        # component of base_url, keeping only scheme, host and port.
        if base_url.find("/v1") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        if base_url.find("/rerank") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
        # A key of "x" (the default) is treated as "no API key".
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Each text is passed through ``truncate(t, 4096)``. Those truncated
        documents are both sent to the server and used for the token count, so
        the reported count matches what was actually sent.

        Returns:
            ``(scores, token_count)``. ``token_count`` is the sum of
            ``num_tokens_from_string`` over the truncated documents. An empty
            query or empty/None ``texts`` returns zeros and 0 without making a
            request.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        documents = [truncate(t, 4096) for t in texts]
        token_count = sum(num_tokens_from_string(t) for t in documents)
        data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": documents}
        response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(documents), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
class LocalAIRerank(Base):
    """Rerank client for a LocalAI server's rerank endpoint; scores are min-max normalized."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        # urljoin with "/rerank" replaces any path already present in base_url.
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Strip a "___<suffix>" qualifier from the configured model name.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query`` and min-max normalize the result.

        Each text is passed through ``truncate(t, 500)`` before it is sent.

        Returns:
            ``(scores, token_count)``. Scores are normalized with
            ``Base._normalize_rank``. ``token_count`` sums
            ``num_tokens_from_string`` over the truncated texts. An empty query
            or empty/None ``texts`` returns zeros and 0 without making a request.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        rank = Base._normalize_rank(rank)
        return rank, token_count
class NvidiaRerank(Base):
    """Rerank client for NVIDIA's hosted retrieval reranking API."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name
        # Models without a known path mapping post to base_url as given, so
        # self.base_url is always defined. Leaving it unset made similarity()
        # fail with AttributeError.
        self.base_url = base_url
        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
        elif self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(base_url, "reranking")
            # This endpoint expects a different model identifier in the payload.
            self.model_name = "nv-rerank-qa-mistral-4b:1"
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Scores come from the ``logit`` field of each ``rankings`` entry and are
        placed at that entry's ``index``.

        Returns:
            ``(scores, token_count)``. ``token_count`` is counted locally over
            the query plus all texts. An empty query or empty/None ``texts``
            returns zeros and 0 without making a request.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        data = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": [{"text": text} for text in texts],
            "truncate": "END",
            "top_n": len(texts),
        }
        response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("rankings", []):
                rank[d["index"]] = d["logit"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
class LmStudioRerank(Base):
    """Placeholder for the LM Studio provider; reranking is not implemented."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        # Nothing to configure: this provider has no rerank implementation.
        return None

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError("The LmStudioRerank has not been implemented")
class OpenAI_APIRerank(Base):
    """Rerank client for OpenAI-API-compatible ``/rerank`` endpoints; scores are min-max normalized."""

    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        normalized_base_url = (base_url or "").strip()
        # Use the URL as-is when it already names a rerank endpoint; otherwise
        # append "rerank" as a child path segment.
        if "/rerank" in normalized_base_url:
            self.base_url = normalized_base_url.rstrip("/")
        else:
            self.base_url = urljoin(f"{normalized_base_url.rstrip('/')}/", "rerank").rstrip("/")
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Strip a "___<suffix>" qualifier from the configured model name.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query`` and min-max normalize the result.

        Each text is passed through ``truncate(t, 500)`` before it is sent.

        Returns:
            ``(scores, token_count)``. Scores are normalized with
            ``Base._normalize_rank``. ``token_count`` sums
            ``num_tokens_from_string`` over the truncated texts. An empty query
            or empty/None ``texts`` returns zeros and 0 without making a request.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        rank = Base._normalize_rank(rank)
        return rank, token_count
class CoHereRerank(Base):
    """Rerank client built on the ``cohere`` SDK; also used for the VLLM factory."""

    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

        client_kwargs = {"api_key": key, "timeout": 30.0}
        # Only override the SDK's default endpoint when a non-blank base_url is given.
        if base_url and base_url.strip():
            client_kwargs["base_url"] = base_url
        self.client = Client(**client_kwargs)
        # Strip a "___<suffix>" qualifier from the configured model name.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query`` via ``Client.rerank``.

        Returns:
            ``(scores, token_count)``. ``token_count`` is counted locally over
            the query plus all texts. An empty query or empty/None ``texts``
            returns zeros and 0 without calling the API.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        res = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.results:
                rank[d.index] = d.relevance_score
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
class TogetherAIRerank(Base):
    """Placeholder for the TogetherAI provider; reranking is not implemented."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        # Nothing to configure: this provider has no rerank implementation.
        return None

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError("The api has not been implemented")
class SILICONFLOWRerank(Base):
    """Rerank client for the SiliconFlow rerank HTTP API."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        normalized_base_url = (base_url or "").strip()
        if not normalized_base_url:
            normalized_base_url = "https://api.siliconflow.cn/v1/rerank"
        # Append "rerank" as a child path segment unless the URL already names it.
        if "/rerank" not in normalized_base_url:
            normalized_base_url = urljoin(f"{normalized_base_url.rstrip('/')}/", "rerank").rstrip("/")
        self.model_name = model_name
        self.base_url = normalized_base_url
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Returns:
            ``(scores, token_count)``. The token count comes from
            ``total_token_count_from_response`` applied to the JSON body. An
            empty query or empty/None ``texts`` returns zeros and 0 without
            making a request.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, response)
        return rank, total_token_count_from_response(res)
class BaiduYiyanRerank(Base):
    """Rerank client for Baidu Yiyan (Qianfan) via ``qianfan.resources.Reranker``."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker

        # ``key`` is a JSON string carrying the access/secret key pair.
        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = Reranker(ak=ak, sk=sk, request_timeout=30)
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Returns:
            ``(scores, token_count)``. The token count comes from
            ``total_token_count_from_response`` applied to the response body.
            An empty query or empty/None ``texts`` returns zeros and 0 without
            calling the API.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        res = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        ).body
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, total_token_count_from_response(res)
class VoyageRerank(Base):
    """Rerank client built on the ``voyageai`` SDK (``base_url`` is accepted but unused)."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

        self.client = voyageai.Client(api_key=key, timeout=30.0)
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query`` via ``Client.rerank``.

        Returns:
            ``(scores, token_count)``. ``token_count`` is the SDK result's
            ``total_tokens``. An empty query or empty/None ``texts`` returns
            zeros and 0 without calling the API.
        """
        if not (query and texts):
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        scores = np.zeros(len(texts), dtype=float)
        outcome = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        try:
            for item in outcome.results:
                scores[item.index] = item.relevance_score
        except Exception as _e:
            log_exception(_e, outcome)
        return scores, outcome.total_tokens
class QWenRerank(Base):
    """Rerank client for Tongyi-Qianwen models via ``dashscope.TextReRank``."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", **kwargs):
        import dashscope

        self.api_key = key
        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
        # Per-request timeout (seconds) passed to every TextReRank.call.
        self.request_timeout = 30.0

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Returns:
            ``(scores, token_count)``. The token count comes from
            ``total_token_count_from_response`` applied to the SDK response. An
            empty query or empty/None ``texts`` returns zeros and 0 without
            calling the API.

        Raises:
            ValueError: when the response status is not ``HTTPStatus.OK``. The
                message includes the response's ``code`` and ``message``.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        import dashscope

        call_kwargs = {
            "api_key": self.api_key,
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "request_timeout": self.request_timeout,
        }
        # qwen3-rerank models are called without return_documents; other models
        # explicitly disable echoing the documents back.
        if not self.model_name.startswith("qwen3-rerank"):
            call_kwargs["return_documents"] = False
        resp = dashscope.TextReRank.call(**call_kwargs)
        if resp.status_code != HTTPStatus.OK:
            # dashscope responses report failures through ``code``/``message``
            # rather than a ``text`` attribute.
            raise ValueError(
                f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - "
                f"{getattr(resp, 'code', '')}: {getattr(resp, 'message', '')}"
            )
        rank = np.zeros(len(texts), dtype=float)
        try:
            for r in resp.output.results:
                rank[r.index] = r.relevance_score
        except Exception as _e:
            log_exception(_e, resp)
        return rank, total_token_count_from_response(resp)
class HuggingfaceRerank(Base):
    """Rerank client for a self-hosted HTTP service exposing ``<base_url>/rerank``.

    Requests send ``query``, a batch of ``texts``, ``raw_scores`` and
    ``truncate``. The response is expected to be a JSON list of objects with
    ``index`` and ``score`` keys.
    """

    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url: str = "http://127.0.0.1"):
        """Score ``texts`` against ``query`` in batches of 8.

        Every batch is attempted. If any batch fails, the exception from the
        last failing batch is raised after all batches have been tried.

        Returns:
            ``np.ndarray`` of scores aligned with ``texts``.
        """
        exc = None
        scores = [0] * len(texts)
        batch_size = 8
        # Default to http:// when no scheme is given, and append "/rerank" only
        # if it is not already the final path segment.
        base_url = url.rstrip("/")
        if not base_url.startswith(("http://", "https://")):
            base_url = f"http://{base_url}"
        endpoint = base_url if base_url.endswith("/rerank") else f"{base_url}/rerank"
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    endpoint,
                    headers={"Content-Type": "application/json"},
                    json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True},
                    timeout=30,
                )
                res.raise_for_status()
                # Returned indices are relative to the batch; shift by the batch offset.
                for o in res.json():
                    scores[o["index"] + i] = o["score"]
            except Exception as e:
                exc = e
        if exc:
            raise exc
        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        # Strip a "___<suffix>" qualifier from the configured model name.
        self.model_name = model_name.split("___")[0]
        # Fall back to the local default when no base_url is configured, so
        # post() never receives None (which would fail on url.rstrip).
        self.base_url = base_url or "http://127.0.0.1"

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Returns:
            ``(scores, token_count)``. ``token_count`` sums
            ``num_tokens_from_string`` over the texts. An empty query or
            empty/None ``texts`` returns zeros and 0 without making a request.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        token_count = sum(num_tokens_from_string(t) for t in texts)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
class GPUStackRerank(Base):
    """Rerank client for a GPUStack server's ``/v1/rerank`` endpoint."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        self.model_name = model_name
        self.base_url = str(URL(base_url) / "v1" / "rerank")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Returns:
            ``(scores, token_count)``. ``token_count`` sums
            ``num_tokens_from_string`` over the texts. An empty query or
            empty/None ``texts`` returns zeros and 0 without making a request.

        Raises:
            ValueError: wrapping any ``requests.exceptions.RequestException``
                raised while sending the request, checking its status, or
                reading its JSON body.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        # Keep the try limited to the transport calls whose failures are wrapped.
        try:
            response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
            response.raise_for_status()
            response_json = response.json()
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {str(e)}") from e
        rank = np.zeros(len(texts), dtype=float)
        token_count = sum(num_tokens_from_string(t) for t in texts)
        try:
            for result in response_json.get("results", []):
                rank[result["index"]] = result["relevance_score"]
        except Exception as _e:
            log_exception(_e, response)
        return rank, token_count
class NovitaRerank(JinaRerank):
    """Novita AI rerank client: Jina-compatible API with a Novita default endpoint."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        # An empty or None base_url falls back to the public Novita endpoint.
        super().__init__(key, model_name, base_url or "https://api.novita.ai/v3/openai/rerank")
class GiteeRerank(JinaRerank):
    """Gitee AI rerank client: Jina-compatible API with a Gitee default endpoint."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        # An empty or None base_url falls back to the public Gitee endpoint.
        super().__init__(key, model_name, base_url or "https://ai.gitee.com/v1/rerank")
class Ai302Rerank(Base):
    """Rerank client for the 302.AI rerank HTTP API."""

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
        # An empty or None base_url falls back to the public 302.AI endpoint.
        self.base_url = base_url or "https://api.302.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query``.

        Each text is passed through ``truncate(t, 500)`` before it is sent.

        Returns:
            ``(scores, token_count)``. The token count comes from
            ``total_token_count_from_response`` applied to the JSON body. An
            empty query or empty/None ``texts`` returns zeros and 0 without
            making a request.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        texts = [truncate(t, 500) for t in texts]
        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
        response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, total_token_count_from_response(res)
class JiekouAIRerank(JinaRerank):
    """Jiekou.AI rerank client: Jina-compatible API with a Jiekou default endpoint."""

    _FACTORY_NAME = "Jiekou.AI"

    def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/rerank"):
        # An empty or None base_url falls back to the public Jiekou endpoint.
        super().__init__(key, model_name, base_url or "https://api.jiekou.ai/openai/v1/rerank")
class FuturMixRerank(OpenAI_APIRerank):
    """FuturMix rerank client: OpenAI-API-compatible rerank with a FuturMix default endpoint."""

    _FACTORY_NAME = "FuturMix"

    def __init__(self, key, model_name, base_url="https://futurmix.ai/v1/rerank"):
        # An empty or None base_url falls back to the public FuturMix endpoint.
        super().__init__(key, model_name, base_url or "https://futurmix.ai/v1/rerank")
        logging.info("[FuturMix] Rerank initialized with model %s", model_name)
class RAGconRerank(Base):
    """Rerank client for the RAGcon ``<base_url>/rerank`` endpoint; scores are min-max normalized."""

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        if not base_url:
            base_url = "https://connect.ragcon.com/v1"
        self._api_key = key
        self._base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def _endpoint(self) -> str:
        """Return the rerank URL, tolerating a trailing slash or an existing "/rerank" suffix."""
        endpoint = self._base_url.rstrip("/")
        if not endpoint.endswith("/rerank"):
            endpoint = f"{endpoint}/rerank"
        return endpoint

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score ``texts`` against ``query`` and min-max normalize the result.

        Each text is passed through ``truncate(t, 500)`` before it is sent.

        Returns:
            ``(scores, token_count)``. Scores are normalized with
            ``Base._normalize_rank``. ``token_count`` sums
            ``num_tokens_from_string`` over the truncated texts. An empty query
            or empty/None ``texts`` returns zeros and 0 without making a request.
        """
        if not query or not texts:
            # texts may be None here, so don't call len() on it unconditionally.
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        response = requests.post(self._endpoint(), headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        rank = Base._normalize_rank(rank)
        return rank, token_count