mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-01 05:17:51 +08:00
## Summary

This PR fully addresses all CodeRabbit review feedback and enhances the robustness of the reranking module while keeping 100% backward compatibility.

## Key Fixes

1. Fixed the hardcoded `base_url` in `JinaRerank` so subclasses can override the endpoint.
2. Corrected `GPUStackRerank` exception handling to catch the proper `requests` exceptions and preserve stack traces.
3. Added a 30 s timeout to all API calls to prevent the service from hanging.
4. Added empty-input validation to all rerank providers.
5. Replaced direct dict key access with `.get()` to eliminate `KeyError` crashes.
6. Fixed the `_normalize_rank` edge case for empty arrays.
7. Implemented the missing functionality for `Ai302Rerank`.
8. Standardized type hints and fixed typos.

## Compatibility

- No breaking changes to any existing functionality.
- All rerank providers work as originally intended.
- Fully compatible with existing configurations and workflows.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
611 lines
22 KiB
Python
611 lines
22 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import json
|
|
import logging
|
|
from abc import ABC
|
|
from urllib.parse import urljoin
|
|
from typing import Tuple, List
|
|
from http import HTTPStatus
|
|
|
|
import numpy as np
|
|
import requests
|
|
from yarl import URL
|
|
|
|
from common.log_utils import log_exception
|
|
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
|
|
|
|
class Base(ABC):
    """Common interface and shared helpers for rerank model providers.

    Subclasses override :meth:`similarity` to score each text against a query.
    """

    def __init__(self, key, model_name, **kwargs):
        # The base class holds no state; each provider stores what it needs.
        pass

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score every text in ``texts`` against ``query``.

        Returns:
            A ``(scores, token_count)`` tuple with one score per input text.

        Raises:
            NotImplementedError: always; providers must override this method.
        """
        raise NotImplementedError("Please implement similarity method!")

    @staticmethod
    def _normalize_rank(rank: np.ndarray) -> np.ndarray:
        """Min-max scale ``rank`` into the range [0, 1].

        An empty array is returned unchanged. When the minimum and maximum are
        (nearly) equal according to ``np.isclose(..., atol=1e-3)``, an all-zero
        array of the same shape is returned instead of dividing by ~0.
        """
        if rank.size == 0:
            return rank
        lowest = np.min(rank)
        highest = np.max(rank)
        if np.isclose(lowest, highest, atol=1e-3):
            # No meaningful spread between scores: avoid a near-zero division.
            return np.zeros_like(rank)
        return (rank - lowest) / (highest - lowest)
|
|
|
|
|
|
class JinaRerank(Base):
    """Reranker that POSTs to a Jina-style rerank endpoint at ``base_url``."""

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        # A falsy base_url falls back to the public Jina endpoint.
        self.base_url = base_url if base_url else "https://api.jina.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Texts are truncated with ``truncate(..., 8196)`` before sending. Returns
        ``(scores, token_count)``: ``scores[i]`` is the ``relevance_score`` reported for
        ``texts[i]`` (0.0 if absent), and ``token_count`` comes from
        ``total_token_count_from_response``. An empty query or empty texts returns
        zeros and 0 without making a request.
        """
        if not (query and texts):
            size = len(texts) if texts else 0
            return np.zeros(size, dtype=float), 0

        clipped = [truncate(doc, 8196) for doc in texts]
        payload = {"model": self.model_name, "query": query, "documents": clipped, "top_n": len(clipped)}
        reply = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
        reply.raise_for_status()
        body = reply.json()

        scores = np.zeros(len(clipped), dtype=float)
        try:
            for item in body.get("results", []):
                scores[item["index"]] = item["relevance_score"]
        except Exception as err:
            log_exception(err, body)
        return scores, total_token_count_from_response(body)
|
|
|
|
|
|
class XInferenceRerank(Base):
    """Reranker backed by an Xinference server's ``/v1/rerank`` endpoint."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        # Accept None as well as "" so construction does not fail on .find().
        base_url = base_url or ""
        # urljoin with the absolute path "/v1/rerank" replaces any existing path on base_url.
        if base_url.find("/v1") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        if base_url.find("/rerank") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
        # The default key "x" is treated as "no key": no Authorization header is sent.
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Documents are truncated to 4096 tokens, and the truncated documents are both
        sent and counted, so the returned token count matches the request. Returns
        ``(scores, token_count)``. An empty query or empty texts returns zeros and 0
        without making a request.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        docs = [truncate(t, 4096) for t in texts]
        token_count = sum(num_tokens_from_string(t) for t in docs)
        data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": docs}
        response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(docs), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
|
|
|
|
|
|
class LocalAIRerank(Base):
    """Reranker backed by a LocalAI server's ``/rerank`` endpoint; scores are min-max normalized."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        if base_url.find("/rerank") != -1:
            self.base_url = base_url
        else:
            # Absolute path: any path already on base_url is replaced by "/rerank".
            self.base_url = urljoin(base_url, "/rerank")
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Anything after the first "___" in the configured name is dropped.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score texts (truncated to 500 tokens) against ``query``.

        Returns ``(scores, token_count)``, where the scores are passed through
        ``Base._normalize_rank`` and the token count covers the truncated texts only.
        """
        if not query or not texts:
            return np.zeros(len(texts), dtype=float), 0
        docs = [truncate(t, 500) for t in texts]
        payload = {"model": self.model_name, "query": query, "documents": docs, "top_n": len(docs)}
        token_count = sum(num_tokens_from_string(doc) for doc in docs)

        reply = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
        reply.raise_for_status()
        body = reply.json()

        scores = np.zeros(len(docs), dtype=float)
        try:
            for hit in body.get("results", []):
                scores[hit["index"]] = hit["relevance_score"]
        except Exception as err:
            log_exception(err, body)

        return Base._normalize_rank(scores), token_count
|
|
|
|
|
|
class NvidiaRerank(Base):
    """Reranker backed by NVIDIA's hosted retrieval reranking endpoints."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name

        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
        elif self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(base_url, "reranking")
            self.model_name = "nv-rerank-qa-mistral-4b:1"
        else:
            # Previously base_url was left unset here, so similarity() raised AttributeError.
            # NOTE(review): falls back to the generic "reranking" route used for
            # rerank-qa-mistral-4b; confirm it serves other NVIDIA rerank models.
            self.base_url = urljoin(base_url, "reranking")

        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Returns ``(scores, token_count)``. The scores are the ``logit`` values returned
        under ``rankings`` and are not normalized here. The token count covers the query
        plus all texts. An empty query or empty/None texts returns zeros and 0 without a
        request.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        data = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": [{"text": text} for text in texts],
            "truncate": "END",
            "top_n": len(texts),
        }
        response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("rankings", []):
                rank[d["index"]] = d["logit"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
|
|
|
|
|
|
class LmStudioRerank(Base):
    """Placeholder provider for LM Studio; reranking is not available yet."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        # Nothing to configure until the provider is implemented.
        return

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Always raises ``NotImplementedError``."""
        reason = "The LmStudioRerank has not been implemented"
        raise NotImplementedError(reason)
|
|
|
|
|
|
class OpenAI_APIRerank(Base):
    """Reranker for OpenAI-API-compatible servers exposing a ``/rerank`` route; scores are min-max normalized."""

    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        url = (base_url or "").strip()
        # Append "rerank" as a child path unless the URL already targets a rerank route.
        if "/rerank" not in url:
            url = urljoin(f"{url.rstrip('/')}/", "rerank")
        self.base_url = url.rstrip("/")
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Anything after the first "___" in the configured name is dropped.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score texts (truncated to 500 tokens) against ``query``.

        Returns ``(normalized_scores, token_count)``; the token count covers only the
        truncated texts.
        """
        if not query or not texts:
            return np.zeros(len(texts), dtype=float), 0
        docs = [truncate(t, 500) for t in texts]
        payload = {"model": self.model_name, "query": query, "documents": docs, "top_n": len(docs)}
        token_count = sum(num_tokens_from_string(doc) for doc in docs)

        reply = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
        reply.raise_for_status()
        body = reply.json()

        scores = np.zeros(len(docs), dtype=float)
        try:
            for hit in body.get("results", []):
                scores[hit["index"]] = hit["relevance_score"]
        except Exception as err:
            log_exception(err, body)

        return Base._normalize_rank(scores), token_count
|
|
|
|
|
|
class CoHereRerank(Base):
    """Reranker using the ``cohere`` SDK client (registered for both Cohere and VLLM)."""

    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

        options = {"api_key": key, "timeout": 30.0}
        # Only override the SDK endpoint when a non-blank URL is configured.
        if base_url and base_url.strip():
            options["base_url"] = base_url
        self.client = Client(**options)
        # Anything after the first "___" in the configured name is dropped.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query`` via ``client.rerank``.

        Returns ``(scores, token_count)``; the token count covers the query plus all texts.
        """
        if not query or not texts:
            return np.zeros(len(texts), dtype=float), 0
        token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        response = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        scores = np.zeros(len(texts), dtype=float)
        try:
            for result in response.results:
                scores[result.index] = result.relevance_score
        except Exception as err:
            log_exception(err, response)
        return scores, token_count
|
|
|
|
|
|
class TogetherAIRerank(Base):
    """Placeholder provider for TogetherAI; reranking is not available yet."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        # Nothing to configure until the provider is implemented.
        return

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Always raises ``NotImplementedError``."""
        reason = "The api has not been implemented"
        raise NotImplementedError(reason)
|
|
|
|
|
|
class SILICONFLOWRerank(Base):
    """Reranker backed by the SiliconFlow rerank HTTP API."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        # Blank or missing URL -> public endpoint; otherwise make sure it targets a rerank route.
        url = (base_url or "").strip() or "https://api.siliconflow.cn/v1/rerank"
        if "/rerank" not in url:
            url = urljoin(f"{url.rstrip('/')}/", "rerank").rstrip("/")
        self.model_name = model_name
        self.base_url = url
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Returns ``(scores, token_count)``, with the token count taken from the response.
        """
        if not query or not texts:
            return np.zeros(len(texts), dtype=float), 0
        request_body = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        reply = requests.post(self.base_url, json=request_body, headers=self.headers, timeout=30)
        reply.raise_for_status()
        body = reply.json()

        scores = np.zeros(len(texts), dtype=float)
        try:
            for hit in body.get("results", []):
                scores[hit["index"]] = hit["relevance_score"]
        except Exception as err:
            # The raw HTTP response object (not the parsed body) is handed to the logger here.
            log_exception(err, reply)
        return scores, total_token_count_from_response(body)
|
|
|
|
|
|
class BaiduYiyanRerank(Base):
    """Reranker using the ``qianfan`` SDK ``Reranker`` resource."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker

        # ``key`` is a JSON string holding "yiyan_ak" / "yiyan_sk"; missing entries become "".
        # base_url is accepted for interface compatibility but not used.
        creds = json.loads(key)
        self.client = Reranker(ak=creds.get("yiyan_ak", ""), sk=creds.get("yiyan_sk", ""), request_timeout=30)
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Returns ``(scores, token_count)``, with the token count taken from the response body.
        """
        if not query or not texts:
            return np.zeros(len(texts), dtype=float), 0
        body = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        ).body
        scores = np.zeros(len(texts), dtype=float)
        try:
            for hit in body.get("results", []):
                scores[hit["index"]] = hit["relevance_score"]
        except Exception as err:
            log_exception(err, body)
        return scores, total_token_count_from_response(body)
|
|
|
|
|
|
class VoyageRerank(Base):
    """Reranker using the ``voyageai`` SDK client."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

        # base_url is accepted for interface compatibility but not used.
        self.client = voyageai.Client(api_key=key, timeout=30.0)
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Returns ``(scores, total_tokens)``, with the token total reported by the SDK response.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        scores = np.zeros(len(texts), dtype=float)

        response = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        try:
            for item in response.results:
                scores[item.index] = item.relevance_score
        except Exception as err:
            log_exception(err, response)
        return scores, response.total_tokens
|
|
|
|
|
|
class QWenRerank(Base):
    """Reranker using DashScope ``TextReRank`` (Tongyi-Qianwen)."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", **kwargs):
        import dashscope

        self.api_key = key
        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
        # Per-request timeout (seconds) passed to every TextReRank.call.
        self.request_timeout = 30.0

    @staticmethod
    def _error_detail(resp) -> str:
        """Return a best-effort error description from a response without ever raising.

        DashScope responses expose ``code``/``message``; a ``text`` attribute is only
        tried as a last resort. Each lookup is guarded so that a missing field cannot
        replace the real error with a lookup error.
        """
        for attr in ("message", "code", "text"):
            try:
                value = getattr(resp, attr, None)
            except Exception:
                value = None
            if value:
                return str(value)
        return ""

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Returns ``(scores, token_count)``, with the token count taken from the response.

        Raises:
            ValueError: when DashScope answers with a non-200 status.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0

        import dashscope

        call_kwargs = {
            "api_key": self.api_key,
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "request_timeout": self.request_timeout,
        }
        # qwen3-rerank models are called without return_documents; other models explicitly disable it.
        if not self.model_name.startswith("qwen3-rerank"):
            call_kwargs["return_documents"] = False
        resp = dashscope.TextReRank.call(**call_kwargs)

        if resp.status_code != HTTPStatus.OK:
            detail = self._error_detail(resp)
            raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {detail}")

        rank = np.zeros(len(texts), dtype=float)
        try:
            for r in resp.output.results:
                rank[r.index] = r.relevance_score
        except Exception as _e:
            log_exception(_e, resp)
        return rank, total_token_count_from_response(resp)
|
|
|
|
|
|
class HuggingfaceRerank(Base):
    """Reranker that calls an HTTP ``/rerank`` endpoint accepting ``query``/``texts`` batches.

    NOTE(review): the request shape matches HuggingFace text-embeddings-inference; confirm the target server.
    """

    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url: str = "http://127.0.0.1"):
        """Score ``texts`` against ``query`` in batches of 8 and return them as a NumPy array.

        ``url`` may omit the scheme (``http://`` is assumed) and may already end in
        ``/rerank``. A falsy ``url`` falls back to ``http://127.0.0.1``.

        Raises:
            The first request, HTTP-status, or response-parsing error encountered.
            Remaining batches are not sent once a batch fails.
        """
        batch_size = 8
        scores = [0 for _ in range(len(texts))]

        base_url = (url or "http://127.0.0.1").rstrip("/")
        if not base_url.startswith(("http://", "https://")):
            base_url = f"http://{base_url}"
        # Only append "/rerank" when the endpoint does not already end with it.
        endpoint = base_url if base_url.endswith("/rerank") else f"{base_url}/rerank"

        for start in range(0, len(texts), batch_size):
            res = requests.post(
                endpoint,
                headers={"Content-Type": "application/json"},
                json={"query": query, "texts": texts[start:start + batch_size], "raw_scores": False, "truncate": True},
                timeout=30,
            )
            res.raise_for_status()
            # Returned indices are relative to the batch, so offset them by its start.
            for o in res.json():
                scores[o["index"] + start] = o["score"]

        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        # Anything after the first "___" in the configured name is dropped.
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Returns ``(scores, token_count)``; the token count covers the texts only.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0
        token_count = sum(num_tokens_from_string(t) for t in texts)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
|
|
|
|
|
|
class GPUStackRerank(Base):
    """Reranker backed by a GPUStack server's ``v1/rerank`` endpoint."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")

        self.model_name = model_name
        # Build "<base_url>/v1/rerank" using yarl path joining.
        self.base_url = str(URL(base_url) / "v1" / "rerank")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score each text against ``query``.

        Returns ``(scores, token_count)``; the token count covers the texts only.

        Raises:
            ValueError: wrapping any ``requests.exceptions.RequestException`` raised
                while calling or decoding the endpoint (original exception chained).
        """
        if not query or not texts:
            return np.zeros(len(texts), dtype=float), 0

        request_body = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}

        try:
            reply = requests.post(self.base_url, json=request_body, headers=self.headers, timeout=30)
            reply.raise_for_status()
            parsed = reply.json()

            scores = np.zeros(len(texts), dtype=float)
            tokens = sum(num_tokens_from_string(t) for t in texts)
            try:
                for entry in parsed.get("results", []):
                    scores[entry["index"]] = entry["relevance_score"]
            except Exception as err:
                log_exception(err, reply)

            return (scores, tokens)

        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {str(e)}") from e
|
|
|
|
|
|
class NovitaRerank(JinaRerank):
    """Novita AI reranker; reuses the Jina-style request/response handling."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        # A falsy base_url falls back to Novita's public rerank endpoint.
        endpoint = base_url if base_url else "https://api.novita.ai/v3/openai/rerank"
        super().__init__(key, model_name, endpoint)
|
|
|
|
|
|
class GiteeRerank(JinaRerank):
    """Gitee AI reranker; reuses the Jina-style request/response handling."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        # A falsy base_url falls back to Gitee AI's public rerank endpoint.
        endpoint = base_url if base_url else "https://ai.gitee.com/v1/rerank"
        super().__init__(key, model_name, endpoint)
|
|
|
|
|
|
class Ai302Rerank(Base):
    """Reranker backed by the 302.AI rerank HTTP API."""

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
        # A falsy base_url falls back to the public 302.AI endpoint.
        self.base_url = base_url if base_url else "https://api.302.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score texts (truncated to 500 tokens) against ``query``.

        Returns ``(scores, token_count)``, with the token count taken from the response.
        """
        if not query or not texts:
            return np.zeros(len(texts), dtype=float), 0
        docs = [truncate(t, 500) for t in texts]
        payload = {"model": self.model_name, "query": query, "documents": docs, "top_n": len(docs)}
        reply = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
        reply.raise_for_status()
        body = reply.json()

        scores = np.zeros(len(docs), dtype=float)
        try:
            for hit in body.get("results", []):
                scores[hit["index"]] = hit["relevance_score"]
        except Exception as err:
            log_exception(err, body)
        return scores, total_token_count_from_response(body)
|
|
|
|
|
|
class JiekouAIRerank(JinaRerank):
    """Jiekou.AI reranker; reuses the Jina-style request/response handling."""

    _FACTORY_NAME = "Jiekou.AI"

    def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/rerank"):
        # A falsy base_url falls back to Jiekou.AI's public rerank endpoint.
        endpoint = base_url if base_url else "https://api.jiekou.ai/openai/v1/rerank"
        super().__init__(key, model_name, endpoint)
|
|
|
|
|
|
class FuturMixRerank(OpenAI_APIRerank):
    """FuturMix reranker; reuses the OpenAI-API-compatible rerank handling."""

    _FACTORY_NAME = "FuturMix"

    def __init__(self, key, model_name, base_url="https://futurmix.ai/v1/rerank"):
        # A falsy base_url falls back to FuturMix's public rerank endpoint.
        endpoint = base_url if base_url else "https://futurmix.ai/v1/rerank"
        super().__init__(key, model_name, endpoint)
        logging.info("[FuturMix] Rerank initialized with model %s", model_name)
|
|
|
|
|
|
class RAGconRerank(Base):
    """Reranker backed by the RAGcon ``/rerank`` HTTP API; scores are min-max normalized."""

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        if not base_url:
            base_url = "https://connect.ragcon.com/v1"

        self._api_key = key
        # Stored as given; the rerank endpoint is derived from it in _rerank_url().
        self._base_url = base_url

        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def _rerank_url(self) -> str:
        """Return the rerank endpoint without producing '//rerank' or '/rerank/rerank'."""
        root = self._base_url.rstrip("/")
        return root if root.endswith("/rerank") else f"{root}/rerank"

    def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]:
        """Score texts (truncated to 500 tokens) against ``query``.

        Returns ``(normalized_scores, token_count)``; the token count covers the
        truncated texts only. An empty query or empty/None texts returns zeros and 0
        without a request.
        """
        if not query or not texts:
            return np.zeros(len(texts) if texts else 0, dtype=float), 0

        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        response = requests.post(self._rerank_url(), headers=self.headers, json=data, timeout=30)
        response.raise_for_status()
        res = response.json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res.get("results", []):
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)

        rank = Base._normalize_rank(rank)
        return rank, token_count
|