mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-03-16 04:17:49 +08:00
### What problem does this PR solve? This PR extends the list of available providers. It adds a new provider, "RAGcon", within the Ollama modal. RAGcon provides all model types except OCR via OpenAI-compatible endpoints. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Jakob <16180662+hauberj@users.noreply.github.com>
552 lines
19 KiB
Python
552 lines
19 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import json
|
|
from abc import ABC
|
|
from urllib.parse import urljoin
|
|
|
|
import httpx
|
|
import numpy as np
|
|
import requests
|
|
from yarl import URL
|
|
|
|
from common.log_utils import log_exception
|
|
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
|
|
|
|
class Base(ABC):
    """Common interface for rerank model providers.

    Subclasses implement ``similarity`` and may use ``_normalize_rank`` to
    rescale raw provider scores.
    """

    def __init__(self, key, model_name, **kwargs):
        """
        Abstract base class constructor.

        Parameters are not stored; initialization is left to subclasses.
        """
        pass

    def similarity(self, query: str, texts: list):
        """Score each entry of ``texts`` against ``query``.

        Subclasses in this module return a ``(scores, token_count)`` tuple.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("Please implement similarity method!")

    @staticmethod
    def _normalize_rank(rank: np.ndarray) -> np.ndarray:
        """
        Normalize rank values to the range 0 to 1 (min-max scaling).

        Avoids division by zero: if all ranks are identical (within an
        absolute tolerance of 1e-3) every score becomes 0. An empty array is
        returned unchanged.
        """
        if np.size(rank) == 0:
            # np.min / np.max raise ValueError on zero-size arrays.
            return rank

        min_rank = np.min(rank)
        max_rank = np.max(rank)

        if not np.isclose(min_rank, max_rank, atol=1e-3):
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)

        return rank
|
|
|
|
|
|
class JinaRerank(Base):
    """Reranker for the Jina ``/v1/rerank`` API and Jina-compatible endpoints.

    Several providers in this module subclass it and pass their own endpoint
    as ``base_url``.
    """

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        # Honour the caller-supplied endpoint; fall back to the public Jina API
        # only when none (or a blank string) is given.
        self.base_url = (base_url or "").strip() or "https://api.jina.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Rerank ``texts`` against ``query``.

        Returns:
            (scores, token_count): ``scores`` is a float array aligned with
            ``texts`` (0 for any document missing from the response);
            ``token_count`` comes from the response's usage data.
        """
        texts = [truncate(t, 8196) for t in texts]
        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            # Malformed/error responses are logged; unscored documents stay 0.
            log_exception(_e, res)
        return rank, total_token_count_from_response(res)
|
|
|
|
|
|
class XInferenceRerank(Base):
    """Reranker backed by an Xinference server's ``/v1/rerank`` route."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        # urljoin with an absolute path ("/v1/rerank") replaces whatever path
        # base_url had; this happens when "/v1" or "/rerank" is missing.
        if base_url.find("/v1") < 0:
            base_url = urljoin(base_url, "/v1/rerank")
        if base_url.find("/rerank") < 0:
            base_url = urljoin(base_url, "/v1/rerank")

        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
        # The default key "x" is treated as "no key": no Authorization header.
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def similarity(self, query: str, texts: list):
        """Return ``(scores, token_count)`` for ``texts`` ranked against ``query``.

        The token count is estimated locally from each document truncated via
        ``truncate(doc, 4096)``; the request itself carries the full documents.
        """
        if not texts:
            return np.array([]), 0

        token_count = sum(num_tokens_from_string(truncate(doc, 4096)) for doc in texts)

        payload = {
            "model": self.model_name,
            "query": query,
            "return_documents": "true",
            "return_len": "true",
            "documents": texts,
        }
        reply = requests.post(self.base_url, headers=self.headers, json=payload).json()

        scores = np.zeros(len(texts), dtype=float)
        try:
            for item in reply["results"]:
                scores[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, reply)
        return scores, token_count
|
|
|
|
|
|
class LocalAIRerank(Base):
    """Reranker for a LocalAI server; scores are min-max normalised to [0, 1]."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        # When "/rerank" is absent, urljoin with the absolute path "/rerank"
        # replaces any existing path on base_url.
        self.base_url = base_url if base_url.find("/rerank") != -1 else urljoin(base_url, "/rerank")
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Only the part of the model name before the first "___" is sent.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """Return ``(normalised_scores, token_count)`` for ``texts`` vs ``query``."""
        # noway to config Ragflow , use fix setting: fixed truncation limit of 500.
        clipped = [truncate(doc, 500) for doc in texts]
        token_count = sum(num_tokens_from_string(doc) for doc in clipped)

        body = {
            "model": self.model_name,
            "query": query,
            "documents": clipped,
            "top_n": len(clipped),
        }
        reply = requests.post(self.base_url, headers=self.headers, json=body).json()

        scores = np.zeros(len(clipped), dtype=float)
        try:
            for item in reply["results"]:
                scores[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, reply)

        return Base._normalize_rank(scores), token_count
|
|
|
|
|
|
class NvidiaRerank(Base):
    """Reranker for NVIDIA's hosted retrieval API (``.../reranking`` routes)."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        # urljoin drops the last path segment unless the base ends with "/",
        # so normalise to exactly one trailing slash.
        base_url = f"{base_url.rstrip('/')}/"
        self.model_name = model_name

        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
        elif self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(base_url, "reranking")
            self.model_name = "nv-rerank-qa-mistral-4b:1"
        else:
            # Without this branch no endpoint was set and similarity() raised
            # AttributeError. NOTE(review): assumes NVIDIA's per-model route
            # convention "<model-slug>/reranking" with "." written as "_"
            # (as for the v3 model above) — confirm for each new model.
            slug = self.model_name.split("/")[-1].replace(".", "_")
            self.base_url = urljoin(base_url, f"{slug}/reranking")

        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """Return ``(logits, token_count)`` for ``texts`` ranked against ``query``.

        Scores are the raw ``logit`` values from the response (not normalised);
        the token count is estimated locally from the query and all texts.
        """
        token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        data = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": [{"text": text} for text in texts],
            "truncate": "END",
            "top_n": len(texts),
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["rankings"]:
                rank[d["index"]] = d["logit"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
|
|
|
|
|
|
class LmStudioRerank(Base):
    """Placeholder for LM Studio: reranking is not available yet."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        # Base.__init__ stores nothing; there is no client to configure.
        super().__init__(key, model_name, base_url=base_url, **kwargs)

    def similarity(self, query: str, texts: list):
        """Always raises ``NotImplementedError``."""
        message = "The LmStudioRerank has not been implement"
        raise NotImplementedError(message)
|
|
|
|
|
|
class OpenAI_APIRerank(Base):
    """Reranker for generic OpenAI-API-compatible ``/rerank`` endpoints.

    Scores are min-max normalised to [0, 1].
    """

    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        endpoint = (base_url or "").strip()
        if "/rerank" not in endpoint:
            # A trailing "/" makes urljoin append "rerank" as a child path
            # instead of replacing the last segment.
            endpoint = urljoin(endpoint.rstrip("/") + "/", "rerank")
        self.base_url = endpoint.rstrip("/")
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Only the part of the model name before the first "___" is sent.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """Return ``(normalised_scores, token_count)`` for ``texts`` vs ``query``."""
        # noway to config Ragflow , use fix setting: fixed truncation limit of 500.
        clipped = [truncate(doc, 500) for doc in texts]
        token_count = sum(num_tokens_from_string(doc) for doc in clipped)

        body = {
            "model": self.model_name,
            "query": query,
            "documents": clipped,
            "top_n": len(clipped),
        }
        reply = requests.post(self.base_url, headers=self.headers, json=body).json()

        scores = np.zeros(len(clipped), dtype=float)
        try:
            for item in reply["results"]:
                scores[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, reply)

        return Base._normalize_rank(scores), token_count
|
|
|
|
|
|
class CoHereRerank(Base):
    """Reranker using the Cohere SDK client (also registered for VLLM)."""

    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

        # Pass base_url only when it is a non-blank string; otherwise the SDK
        # falls back to its default Cohere endpoint.
        if base_url and base_url.strip():
            self.client = Client(api_key=key, base_url=base_url)
        else:
            self.client = Client(api_key=key)
        # Only the part of the model name before the first "___" is sent.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """Return ``(scores, token_count)``; tokens are estimated locally."""
        token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(doc) for doc in texts)
        response = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        scores = np.zeros(len(texts), dtype=float)
        try:
            for hit in response.results:
                scores[hit.index] = hit.relevance_score
        except Exception as _e:
            log_exception(_e, response)
        return scores, token_count
|
|
|
|
|
|
class TogetherAIRerank(Base):
    """Placeholder for TogetherAI: reranking is not available yet."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        # Base.__init__ stores nothing; there is no client to configure.
        super().__init__(key, model_name, base_url=base_url, **kwargs)

    def similarity(self, query: str, texts: list):
        """Always raises ``NotImplementedError``."""
        message = "The api has not been implement"
        raise NotImplementedError(message)
|
|
|
|
|
|
class SILICONFLOWRerank(Base):
    """Reranker for the SiliconFlow ``/v1/rerank`` API."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        # Blank/None falls back to the public endpoint; a URL without a
        # "/rerank" segment gets one appended as a child path.
        endpoint = (base_url or "").strip() or "https://api.siliconflow.cn/v1/rerank"
        if "/rerank" not in endpoint:
            endpoint = urljoin(endpoint.rstrip("/") + "/", "rerank").rstrip("/")

        self.model_name = model_name
        self.base_url = endpoint
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """Return ``(scores, token_count)``; tokens come from the response usage."""
        body = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        reply = requests.post(self.base_url, json=body, headers=self.headers).json()

        scores = np.zeros(len(texts), dtype=float)
        try:
            for item in reply["results"]:
                scores[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, reply)

        used_tokens = total_token_count_from_response(reply)
        return scores, used_tokens
|
|
|
|
|
|
class BaiduYiyanRerank(Base):
    """Reranker using Baidu Qianfan's ``Reranker`` client."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker

        # The key is a JSON string holding "yiyan_ak" and "yiyan_sk"
        # (each defaults to "" when absent).
        credentials = json.loads(key)
        self.client = Reranker(
            ak=credentials.get("yiyan_ak", ""),
            sk=credentials.get("yiyan_sk", ""),
        )
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Return ``(scores, token_count)``; tokens come from the response usage."""
        reply = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        ).body

        scores = np.zeros(len(texts), dtype=float)
        try:
            for item in reply["results"]:
                scores[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, reply)

        used_tokens = total_token_count_from_response(reply)
        return scores, used_tokens
|
|
|
|
|
|
class VoyageRerank(Base):
    """Reranker using the Voyage AI SDK client."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Return ``(scores, total_tokens)``; an empty input short-circuits to ``([], 0)``."""
        if not texts:
            return np.array([]), 0

        scores = np.zeros(len(texts), dtype=float)
        outcome = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        try:
            for hit in outcome.results:
                scores[hit.index] = hit.relevance_score
        except Exception as _e:
            log_exception(_e, outcome)

        return scores, outcome.total_tokens
|
|
|
|
|
|
class QWenRerank(Base):
    """Reranker backed by DashScope's ``TextReRank`` API (Tongyi-Qianwen)."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
        import dashscope

        self.api_key = key
        # An explicit None selects DashScope's built-in gte-rerank model constant.
        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

    def similarity(self, query: str, texts: list):
        """Return ``(scores, token_count)`` for ``texts`` ranked against ``query``.

        Raises:
            ValueError: if DashScope answers with a non-OK HTTP status.
        """
        from http import HTTPStatus

        import dashscope

        resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
        if resp.status_code != HTTPStatus.OK:
            # DashScope responses report failures via their `code` and
            # `message` fields; `text` is not a documented response field.
            raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.code}: {resp.message}")

        rank = np.zeros(len(texts), dtype=float)
        try:
            for r in resp.output.results:
                rank[r.index] = r.relevance_score
        except Exception as _e:
            log_exception(_e, resp)
        return rank, total_token_count_from_response(resp)
|
|
|
|
|
|
class HuggingfaceRerank(Base):
    """Reranker for a server exposing ``POST /rerank`` with a
    ``{query, texts, raw_scores, truncate}`` payload (presumably HuggingFace
    text-embeddings-inference — confirm against deployment)."""

    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1"):
        """Score ``texts`` against ``query`` in batches of 8 via ``POST <url>/rerank``.

        ``url`` may be a bare ``host[:port]`` (``http://`` is prepended) or a
        full URL that already carries an ``http://``/``https://`` scheme.

        Returns:
            np.ndarray of scores aligned with ``texts``.

        Raises:
            The last exception raised by any batch. Later batches are still
            attempted before it is re-raised.
        """
        # The class default base_url ("http://127.0.0.1") already has a scheme;
        # prepending another one unconditionally produced "http://http://...".
        endpoint = str(url).strip().rstrip("/")
        if not endpoint.lower().startswith(("http://", "https://")):
            endpoint = f"http://{endpoint}"
        endpoint = f"{endpoint}/rerank"

        exc = None
        scores = [0 for _ in range(len(texts))]
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    endpoint, headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
                )

                for o in res.json():
                    # Indices in the response are relative to the batch.
                    scores[o["index"] + i] = o["score"]
            except Exception as e:
                exc = e

        if exc:
            raise exc
        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        # Only the part of the model name before the first "___" is kept.
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url

    def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
        """Return ``(scores, token_count)``; tokens are estimated locally from ``texts``."""
        if not texts:
            return np.array([]), 0
        token_count = sum(num_tokens_from_string(t) for t in texts)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
|
|
|
|
|
|
class GPUStackRerank(Base):
    """Reranker for a GPUStack server's ``/v1/rerank`` route."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")

        self.model_name = model_name
        self.base_url = str(URL(base_url) / "v1" / "rerank")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """Return ``(scores, token_count)``; tokens are estimated locally from ``texts``.

        Raises:
            ValueError: if the server answers with an HTTP error status.
        """
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }

        try:
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            # The request is made with `requests`, so its HTTPError (not httpx's
            # HTTPStatusError) is what raise_for_status() raises.
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") from e

        response_json = response.json()
        rank = np.zeros(len(texts), dtype=float)
        token_count = sum(num_tokens_from_string(t) for t in texts)
        try:
            for result in response_json["results"]:
                rank[result["index"]] = result["relevance_score"]
        except Exception as _e:
            log_exception(_e, response)

        return rank, token_count
|
|
|
|
|
|
class NovitaRerank(JinaRerank):
    """Reranker for Novita AI's rerank endpoint, reusing JinaRerank's request logic."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai/rerank"
        super().__init__(key, model_name, base_url)
        # Pin the endpoint here instead of relying on JinaRerank.__init__ to use
        # base_url; otherwise Novita credentials could be sent to api.jina.ai.
        self.base_url = base_url
|
|
|
|
|
|
class GiteeRerank(JinaRerank):
    """Reranker for Gitee AI's rerank endpoint, reusing JinaRerank's request logic."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/rerank"
        super().__init__(key, model_name, base_url)
        # Pin the endpoint here instead of relying on JinaRerank.__init__ to use
        # base_url; otherwise Gitee credentials could be sent to api.jina.ai.
        self.base_url = base_url
|
|
|
|
|
|
class Ai302Rerank(JinaRerank):
    """Reranker for 302.AI's rerank endpoint, reusing JinaRerank's request logic.

    NOTE(review): assumes 302.AI returns Jina-style ``results[].index`` /
    ``relevance_score`` like the other JinaRerank-based providers — confirm.
    """

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
        if not base_url:
            base_url = "https://api.302.ai/v1/rerank"
        # Deriving from Base left this class without a similarity()
        # implementation (Base.similarity always raises); JinaRerank supplies it.
        super().__init__(key, model_name, base_url)
        # Pin the endpoint here instead of relying on JinaRerank.__init__ to use
        # base_url; otherwise 302.AI credentials could be sent to api.jina.ai.
        self.base_url = base_url
|
|
|
|
|
|
class JiekouAIRerank(JinaRerank):
    """Reranker for Jiekou.AI's rerank endpoint, reusing JinaRerank's request logic."""

    _FACTORY_NAME = "Jiekou.AI"

    def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/rerank"):
        if not base_url:
            base_url = "https://api.jiekou.ai/openai/v1/rerank"
        super().__init__(key, model_name, base_url)
        # Pin the endpoint here instead of relying on JinaRerank.__init__ to use
        # base_url; otherwise Jiekou.AI credentials could be sent to api.jina.ai.
        self.base_url = base_url
|
|
|
|
class RAGconRerank(Base):
    """
    RAGcon Rerank Provider - routes through LiteLLM proxy

    Assumes LiteLLM proxy supports /rerank endpoint.
    Default Base URL: https://connect.ragcon.com/v1

    Scores are min-max normalised to [0, 1] via ``Base._normalize_rank``.
    """

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        # Strip whitespace and trailing "/" so the endpoint never becomes
        # ".../v1//rerank"; blank/None falls back to the default base URL.
        self._base_url = (base_url or "").strip().rstrip("/") or "https://connect.ragcon.com/v1"
        self._api_key = key

        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Return ``(normalised_scores, token_count)`` for ``texts`` vs ``query``."""
        # noway to config Ragflow , use fix setting: fixed truncation limit of 500.
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)

        # Accept a base URL that already points at the rerank route.
        url = self._base_url if self._base_url.endswith("/rerank") else f"{self._base_url}/rerank"
        res = requests.post(url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)

        rank = Base._normalize_rank(rank)

        return rank, token_count