From 13d0df15627e08de372a00bfc55a8d3d9c833b60 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Fri, 20 Mar 2026 02:47:48 +0000 Subject: [PATCH] feat: add Perplexity contextualized embeddings API as a new model provider (#13709) ### What problem does this PR solve? Adds Perplexity contextualized embeddings API as a new model provider, as requested in #13610. - `PerplexityEmbed` provider in `rag/llm/embedding_model.py` supporting both standard (`/v1/embeddings`) and contextualized (`/v1/contextualizedembeddings`) endpoints - All 4 Perplexity embedding models registered in `conf/llm_factories.json`: `pplx-embed-v1-0.6b`, `pplx-embed-v1-4b`, `pplx-embed-context-v1-0.6b`, `pplx-embed-context-v1-4b` - Frontend entries (enum, icon mapping, API key URL) in `web/src/constants/llm.ts` - Updated `docs/guides/models/supported_models.mdx` - 22 unit tests in `test/unit_test/rag/llm/test_perplexity_embed.py` Perplexity's API returns `base64_int8` encoded embeddings (not OpenAI-compatible), so this uses a custom `requests`-based implementation. Contextualized vs standard model is auto-detected from the model name. Closes #13610 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Documentation Update --- conf/llm_factories.json | 32 +++ docs/guides/models/supported_models.mdx | 1 + rag/llm/embedding_model.py | 109 ++++++-- test/unit_test/rag/llm/__init__.py | 0 test/unit_test/rag/llm/conftest.py | 61 +++++ .../rag/llm/test_perplexity_embed.py | 250 ++++++++++++++++++ web/src/constants/llm.ts | 4 + 7 files changed, 441 insertions(+), 16 deletions(-) create mode 100644 test/unit_test/rag/llm/__init__.py create mode 100644 test/unit_test/rag/llm/conftest.py create mode 100644 test/unit_test/rag/llm/test_perplexity_embed.py diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 69db0f796..3e6a1fce7 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -6317,6 +6317,38 @@ "status": "1", "rank": "100", "llm": [] + }, + { + "name": "Perplexity", + "logo": "", + "tags": "TEXT EMBEDDING", + "status": "1", + "llm": [ + { + "llm_name": "pplx-embed-v1-0.6b", + "tags": "TEXT EMBEDDING,32000", + "max_tokens": 32000, + "model_type": "embedding" + }, + { + "llm_name": "pplx-embed-v1-4b", + "tags": "TEXT EMBEDDING,32000", + "max_tokens": 32000, + "model_type": "embedding" + }, + { + "llm_name": "pplx-embed-context-v1-0.6b", + "tags": "TEXT EMBEDDING,32000", + "max_tokens": 32000, + "model_type": "embedding" + }, + { + "llm_name": "pplx-embed-context-v1-4b", + "tags": "TEXT EMBEDDING,32000", + "max_tokens": 32000, + "model_type": "embedding" + } + ] } ] } diff --git a/docs/guides/models/supported_models.mdx b/docs/guides/models/supported_models.mdx index 3dcfeed0d..7c1131028 100644 --- a/docs/guides/models/supported_models.mdx +++ b/docs/guides/models/supported_models.mdx @@ -46,6 +46,7 @@ A complete list of models supported by RAGFlow, which will continue to expand. | OpenAI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | | OpenAI-API-Compatible | :heavy_check_mark: | :heavy_check_mark: | | | :heavy_check_mark: | :heavy_check_mark: | | | OpenRouter | :heavy_check_mark: | :heavy_check_mark: | | | | | | +| Perplexity | | :heavy_check_mark: | | | | | | | Replicate | :heavy_check_mark: | | | | :heavy_check_mark: | | | | PPIO | :heavy_check_mark: | | | | | | | | SILICONFLOW | :heavy_check_mark: | :heavy_check_mark: | | | :heavy_check_mark: | :heavy_check_mark: | | diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index f4b58619b..28ab2e262 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -113,7 +113,7 @@ class OpenAIEmbed(Base): return np.array(ress), total_tokens def encode_queries(self, text): - res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True}) + res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) try: return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: @@ -358,7 +358,7 @@ class JinaMultiVecEmbed(Base): self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name - def encode(self, texts: list[str|bytes], task="retrieval.passage"): + def encode(self, texts: list[str | bytes], task="retrieval.passage"): batch_size = 16 ress = [] token_count = 0 @@ -370,9 +370,9 @@ class JinaMultiVecEmbed(Base): img_b64s = None try: base64.b64decode(text, validate=True) - img_b64s = text.decode('utf8') + img_b64s = text.decode("utf8") except Exception: - img_b64s = base64.b64encode(text).decode('utf8') + img_b64s = base64.b64encode(text).decode("utf8") input.append({"image": img_b64s}) # base64 encoded image for i in range(0, len(texts), batch_size): data = {"model": self.model_name, "input": input[i : i + batch_size]} @@ -380,20 +380,20 @@ class JinaMultiVecEmbed(Base): data["return_multivector"] = True if "v3" in self.model_name or "v4" in self.model_name: - data['task'] = task - data['truncate'] = True + data["task"] = task + data["truncate"] = True response = requests.post(self.base_url, headers=self.headers, json=data) try: res = response.json() - for d in res['data']: - if data.get("return_multivector", False): # v4 - token_embs = np.asarray(d['embeddings'], dtype=np.float32) + for d in res["data"]: + if data.get("return_multivector", False): # v4 + token_embs = np.asarray(d["embeddings"], dtype=np.float32) chunk_emb = token_embs.mean(axis=0) else: # v2/v3 - chunk_emb = np.asarray(d['embedding'], dtype=np.float32) + chunk_emb = np.asarray(d["embedding"], dtype=np.float32) ress.append(chunk_emb) @@ -444,6 +444,7 @@ class MistralEmbed(Base): def encode_queries(self, text): import time import random + retry_max = 5 while retry_max > 0: try: @@ -462,6 +463,7 @@ class BedrockEmbed(Base): def __init__(self, key, model_name, **kwargs): import boto3 + # `key` protocol (backend stores as JSON string in `api_key`): # - Must decode into a dict. # - Required: `auth_mode`, `bedrock_region`. @@ -497,10 +499,9 @@ class BedrockEmbed(Base): aws_secret_access_key=creds["SecretAccessKey"], aws_session_token=creds["SessionToken"], ) - else: # assume_role + else: # assume_role self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region) - def encode(self, texts: list): texts = [truncate(t, 8196) for t in texts] embeddings = [] @@ -1038,6 +1039,7 @@ class GiteeEmbed(SILICONFLOWEmbed): base_url = "https://ai.gitee.com/v1/embeddings" super().__init__(key, model_name, base_url) + class DeepInfraEmbed(OpenAIEmbed): _FACTORY_NAME = "DeepInfra" @@ -1064,6 +1066,7 @@ class CometAPIEmbed(OpenAIEmbed): base_url = "https://api.cometapi.com/v1" super().__init__(key, model_name, base_url) + class DeerAPIEmbed(OpenAIEmbed): _FACTORY_NAME = "DeerAPI" @@ -1081,16 +1084,90 @@ class JiekouAIEmbed(OpenAIEmbed): base_url = "https://api.jiekou.ai/openai/v1/embeddings" super().__init__(key, model_name, base_url) + class RAGconEmbed(OpenAIEmbed): """ RAGcon Embedding Provider - routes through LiteLLM proxy - + Default Base URL: https://connect.ragcon.ai/v1 """ + _FACTORY_NAME = "RAGcon" - + def __init__(self, key, model_name="text-embedding-3-small", base_url=None): if not base_url: base_url = "https://connect.ragcon.com/v1" - - super().__init__(key, model_name, base_url) \ No newline at end of file + + super().__init__(key, model_name, base_url) + + +class PerplexityEmbed(Base): + _FACTORY_NAME = "Perplexity" + + def __init__(self, key, model_name="pplx-embed-v1-0.6b", base_url="https://api.perplexity.ai"): + if not base_url: + base_url = "https://api.perplexity.ai" + self.base_url = base_url.rstrip("/") + self.api_key = key + self.model_name = model_name + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + + @staticmethod + def _decode_base64_int8(b64_str): + raw = base64.b64decode(b64_str) + return np.frombuffer(raw, dtype=np.int8).astype(np.float32) + + def _is_contextualized(self): + return "context" in self.model_name + + def encode(self, texts: list): + batch_size = 512 + ress = [] + token_count = 0 + + if self._is_contextualized(): + url = f"{self.base_url}/v1/contextualizedembeddings" + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + payload = { + "model": self.model_name, + "input": [[chunk] for chunk in batch], + "encoding_format": "base64_int8", + } + response = requests.post(url, headers=self.headers, json=payload) + try: + res = response.json() + for doc in res["data"]: + for chunk_emb in doc["data"]: + ress.append(self._decode_base64_int8(chunk_emb["embedding"])) + token_count += res.get("usage", {}).get("total_tokens", 0) + except Exception as _e: + log_exception(_e, response) + raise Exception(f"Error: {response.text}") + else: + url = f"{self.base_url}/v1/embeddings" + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + payload = { + "model": self.model_name, + "input": batch, + "encoding_format": "base64_int8", + } + response = requests.post(url, headers=self.headers, json=payload) + try: + res = response.json() + for d in res["data"]: + ress.append(self._decode_base64_int8(d["embedding"])) + token_count += res.get("usage", {}).get("total_tokens", 0) + except Exception as _e: + log_exception(_e, response) + raise Exception(f"Error: {response.text}") + + return np.array(ress), token_count + + def encode_queries(self, text): + embds, cnt = self.encode([text]) + return np.array(embds[0]), cnt diff --git a/test/unit_test/rag/llm/__init__.py b/test/unit_test/rag/llm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/unit_test/rag/llm/conftest.py b/test/unit_test/rag/llm/conftest.py new file mode 100644 index 000000000..3d9bf31ca --- /dev/null +++ b/test/unit_test/rag/llm/conftest.py @@ -0,0 +1,61 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Prevent rag.llm.__init__ from running its heavy auto-discovery loop. + +The __init__.py dynamically imports ALL model modules (chat_model, +cv_model, ocr_model, etc.), which pull in deepdoc, xgboost, torch, +and other heavy native deps. We pre-install a lightweight stub for +the rag.llm package so that `from rag.llm.embedding_model import X` +works without triggering the full init. +""" + +import os +import sys +import types + +# Resolve the real path to rag/llm/ so sub-module imports can find files +_RAGFLOW_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) +_RAG_LLM_DIR = os.path.join(_RAGFLOW_ROOT, "rag", "llm") + + +def _install_rag_llm_stub(): + """Replace rag.llm with a minimal package stub if not yet loaded. + + The stub has __path__ pointing to the real rag/llm/ directory so that + `from rag.llm.embedding_model import X` resolves to the actual file, + but the __init__.py auto-discovery loop is skipped. + """ + if "rag.llm" in sys.modules: + return + + # Create a stub rag.llm package that does NOT run the real __init__ + llm_pkg = types.ModuleType("rag.llm") + llm_pkg.__path__ = [_RAG_LLM_DIR] + llm_pkg.__package__ = "rag.llm" + # Provide empty dicts for the mappings the real __init__ would build + llm_pkg.EmbeddingModel = {} + llm_pkg.ChatModel = {} + llm_pkg.CvModel = {} + llm_pkg.RerankModel = {} + llm_pkg.Seq2txtModel = {} + llm_pkg.TTSModel = {} + llm_pkg.OcrModel = {} + sys.modules["rag.llm"] = llm_pkg + + +_install_rag_llm_stub() diff --git a/test/unit_test/rag/llm/test_perplexity_embed.py b/test/unit_test/rag/llm/test_perplexity_embed.py new file mode 100644 index 000000000..9edef6736 --- /dev/null +++ b/test/unit_test/rag/llm/test_perplexity_embed.py @@ -0,0 +1,250 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import base64 +from unittest.mock import patch, MagicMock + +import numpy as np +import pytest + +from rag.llm.embedding_model import PerplexityEmbed + + +def _make_b64_int8(values): + """Helper: encode a list of int8 values to base64 string.""" + arr = np.array(values, dtype=np.int8) + return base64.b64encode(arr.tobytes()).decode() + + +def _mock_standard_response(embeddings_b64, total_tokens=10): + """Build a mock JSON response for the standard embeddings endpoint.""" + return { + "object": "list", + "data": [{"object": "embedding", "index": i, "embedding": emb} for i, emb in enumerate(embeddings_b64)], + "model": "pplx-embed-v1-0.6b", + "usage": {"total_tokens": total_tokens}, + } + + +def _mock_contextualized_response(docs_embeddings_b64, total_tokens=20): + """Build a mock JSON response for the contextualized embeddings endpoint.""" + data = [] + for doc_idx, chunks in enumerate(docs_embeddings_b64): + data.append( + { + "index": doc_idx, + "data": [{"object": "embedding", "index": chunk_idx, "embedding": emb} for chunk_idx, emb in enumerate(chunks)], + } + ) + return { + "object": "list", + "data": data, + "model": "pplx-embed-context-v1-0.6b", + "usage": {"total_tokens": total_tokens}, + } + + +class TestPerplexityEmbedInit: + def test_default_base_url(self): + embed = PerplexityEmbed("test-key", "pplx-embed-v1-0.6b") + assert embed.base_url == "https://api.perplexity.ai" + assert embed.api_key == "test-key" + assert embed.model_name == "pplx-embed-v1-0.6b" + + def test_custom_base_url(self): + embed = PerplexityEmbed("key", "pplx-embed-v1-4b", base_url="https://custom.api.com/") + assert embed.base_url == "https://custom.api.com" + + def test_empty_base_url_uses_default(self): + embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b", base_url="") + assert embed.base_url == "https://api.perplexity.ai" + + def test_auth_header(self): + embed = PerplexityEmbed("my-secret-key", "pplx-embed-v1-0.6b") + assert embed.headers["Authorization"] == "Bearer my-secret-key" + + +class TestPerplexityEmbedModelDetection: + def test_standard_model_not_contextualized(self): + embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") + assert not embed._is_contextualized() + + def test_standard_4b_not_contextualized(self): + embed = PerplexityEmbed("key", "pplx-embed-v1-4b") + assert not embed._is_contextualized() + + def test_contextualized_0_6b(self): + embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b") + assert embed._is_contextualized() + + def test_contextualized_4b(self): + embed = PerplexityEmbed("key", "pplx-embed-context-v1-4b") + assert embed._is_contextualized() + + +class TestDecodeBase64Int8: + def test_basic_decode(self): + values = [-1, 0, 1, 127] + b64 = _make_b64_int8(values) + result = PerplexityEmbed._decode_base64_int8(b64) + expected = np.array(values, dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_empty_decode(self): + b64 = base64.b64encode(b"").decode() + result = PerplexityEmbed._decode_base64_int8(b64) + assert len(result) == 0 + + def test_full_range(self): + values = list(range(-128, 128)) + b64 = _make_b64_int8(values) + result = PerplexityEmbed._decode_base64_int8(b64) + expected = np.array(values, dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_output_dtype_is_float32(self): + b64 = _make_b64_int8([1, 2, 3]) + result = PerplexityEmbed._decode_base64_int8(b64) + assert result.dtype == np.float32 + + +class TestPerplexityEmbedStandardEncode: + @patch("rag.llm.embedding_model.requests.post") + def test_encode_single_text(self, mock_post): + emb_b64 = _make_b64_int8([10, 20, 30]) + mock_resp = MagicMock() + mock_resp.json.return_value = _mock_standard_response([emb_b64], total_tokens=5) + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") + result, tokens = embed.encode(["hello"]) + + assert result.shape == (1, 3) + np.testing.assert_array_equal(result[0], np.array([10, 20, 30], dtype=np.float32)) + assert tokens == 5 + mock_post.assert_called_once() + call_url = mock_post.call_args[0][0] + assert call_url == "https://api.perplexity.ai/v1/embeddings" + + @patch("rag.llm.embedding_model.requests.post") + def test_encode_multiple_texts(self, mock_post): + emb1 = _make_b64_int8([1, 2]) + emb2 = _make_b64_int8([3, 4]) + emb3 = _make_b64_int8([5, 6]) + mock_resp = MagicMock() + mock_resp.json.return_value = _mock_standard_response([emb1, emb2, emb3], total_tokens=15) + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") + result, tokens = embed.encode(["a", "b", "c"]) + + assert result.shape == (3, 2) + assert tokens == 15 + + @patch("rag.llm.embedding_model.requests.post") + def test_encode_sends_correct_payload(self, mock_post): + mock_resp = MagicMock() + mock_resp.json.return_value = _mock_standard_response([_make_b64_int8([1])], total_tokens=1) + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-v1-4b") + embed.encode(["test text"]) + + call_kwargs = mock_post.call_args + payload = call_kwargs[1]["json"] + assert payload["model"] == "pplx-embed-v1-4b" + assert payload["input"] == ["test text"] + assert payload["encoding_format"] == "base64_int8" + + @patch("rag.llm.embedding_model.requests.post") + def test_encode_api_error_raises(self, mock_post): + mock_resp = MagicMock() + mock_resp.json.side_effect = Exception("Invalid JSON") + mock_resp.text = "Internal Server Error" + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") + with pytest.raises(Exception, match="Error"): + embed.encode(["hello"]) + + +class TestPerplexityEmbedContextualizedEncode: + @patch("rag.llm.embedding_model.requests.post") + def test_contextualized_encode(self, mock_post): + emb1 = _make_b64_int8([10, 20]) + emb2 = _make_b64_int8([30, 40]) + mock_resp = MagicMock() + mock_resp.json.return_value = _mock_contextualized_response([[emb1], [emb2]], total_tokens=12) + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b") + result, tokens = embed.encode(["chunk1", "chunk2"]) + + assert result.shape == (2, 2) + np.testing.assert_array_equal(result[0], np.array([10, 20], dtype=np.float32)) + np.testing.assert_array_equal(result[1], np.array([30, 40], dtype=np.float32)) + assert tokens == 12 + + @patch("rag.llm.embedding_model.requests.post") + def test_contextualized_uses_correct_endpoint(self, mock_post): + mock_resp = MagicMock() + mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1) + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-context-v1-4b") + embed.encode(["chunk"]) + + call_url = mock_post.call_args[0][0] + assert call_url == "https://api.perplexity.ai/v1/contextualizedembeddings" + + @patch("rag.llm.embedding_model.requests.post") + def test_contextualized_sends_nested_input(self, mock_post): + mock_resp = MagicMock() + mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1) + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b") + embed.encode(["text1"]) + + payload = mock_post.call_args[1]["json"] + assert payload["input"] == [["text1"]] + assert payload["model"] == "pplx-embed-context-v1-0.6b" + + +class TestPerplexityEmbedEncodeQueries: + @patch("rag.llm.embedding_model.requests.post") + def test_encode_queries_returns_single_vector(self, mock_post): + emb = _make_b64_int8([5, 10, 15, 20]) + mock_resp = MagicMock() + mock_resp.json.return_value = _mock_standard_response([emb], total_tokens=3) + mock_post.return_value = mock_resp + + embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") + result, tokens = embed.encode_queries("search query") + + assert result.shape == (4,) + np.testing.assert_array_equal(result, np.array([5, 10, 15, 20], dtype=np.float32)) + assert tokens == 3 + + +class TestPerplexityEmbedFactoryRegistration: + def test_factory_name(self): + assert PerplexityEmbed._FACTORY_NAME == "Perplexity" + + def test_is_subclass_of_base(self): + from rag.llm.embedding_model import Base + + assert issubclass(PerplexityEmbed, Base) diff --git a/web/src/constants/llm.ts b/web/src/constants/llm.ts index e3f1dffba..52c1a1d7d 100644 --- a/web/src/constants/llm.ts +++ b/web/src/constants/llm.ts @@ -65,6 +65,7 @@ export enum LLMFactory { N1n = 'n1n', Avian = 'Avian', RAGcon = 'RAGcon', + Perplexity = 'Perplexity', } // Please lowercase the file name @@ -135,6 +136,7 @@ export const IconMap = { [LLMFactory.N1n]: 'n1n', [LLMFactory.Avian]: 'avian', [LLMFactory.RAGcon]: 'ragcon', + [LLMFactory.Perplexity]: 'perplexity', }; export const APIMapUrl = { @@ -188,4 +190,6 @@ export const APIMapUrl = { [LLMFactory.PaddleOCR]: 'https://www.paddleocr.ai/latest/', [LLMFactory.N1n]: 'https://docs.n1n.ai', [LLMFactory.Avian]: 'https://avian.io', + [LLMFactory.Perplexity]: + 'https://docs.perplexity.ai/docs/embeddings/quickstart', };