mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-03-27 09:19:39 +08:00
feat: add Perplexity contextualized embeddings API as a new model provider (#13709)
### What problem does this PR solve? Adds Perplexity contextualized embeddings API as a new model provider, as requested in #13610. - `PerplexityEmbed` provider in `rag/llm/embedding_model.py` supporting both standard (`/v1/embeddings`) and contextualized (`/v1/contextualizedembeddings`) endpoints - All 4 Perplexity embedding models registered in `conf/llm_factories.json`: `pplx-embed-v1-0.6b`, `pplx-embed-v1-4b`, `pplx-embed-context-v1-0.6b`, `pplx-embed-context-v1-4b` - Frontend entries (enum, icon mapping, API key URL) in `web/src/constants/llm.ts` - Updated `docs/guides/models/supported_models.mdx` - 22 unit tests in `test/unit_test/rag/llm/test_perplexity_embed.py` Perplexity's API returns `base64_int8` encoded embeddings (not OpenAI-compatible), so this uses a custom `requests`-based implementation. Contextualized vs standard model is auto-detected from the model name. Closes #13610 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Documentation Update
This commit is contained in:
@ -6317,6 +6317,38 @@
|
||||
"status": "1",
|
||||
"rank": "100",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
"name": "Perplexity",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "pplx-embed-v1-0.6b",
|
||||
"tags": "TEXT EMBEDDING,32000",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "pplx-embed-v1-4b",
|
||||
"tags": "TEXT EMBEDDING,32000",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "pplx-embed-context-v1-0.6b",
|
||||
"tags": "TEXT EMBEDDING,32000",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "pplx-embed-context-v1-4b",
|
||||
"tags": "TEXT EMBEDDING,32000",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "embedding"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -46,6 +46,7 @@ A complete list of models supported by RAGFlow, which will continue to expand.
|
||||
| OpenAI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| OpenAI-API-Compatible | :heavy_check_mark: | :heavy_check_mark: | | | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
| OpenRouter | :heavy_check_mark: | :heavy_check_mark: | | | | | |
|
||||
| Perplexity | | :heavy_check_mark: | | | | | |
|
||||
| Replicate | :heavy_check_mark: | | | | :heavy_check_mark: | | |
|
||||
| PPIO | :heavy_check_mark: | | | | | | |
|
||||
| SILICONFLOW | :heavy_check_mark: | :heavy_check_mark: | | | :heavy_check_mark: | :heavy_check_mark: | |
|
||||
|
||||
@ -113,7 +113,7 @@ class OpenAIEmbed(Base):
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
try:
|
||||
return np.array(res.data[0].embedding), total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
@ -358,7 +358,7 @@ class JinaMultiVecEmbed(Base):
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list[str|bytes], task="retrieval.passage"):
|
||||
def encode(self, texts: list[str | bytes], task="retrieval.passage"):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
@ -370,9 +370,9 @@ class JinaMultiVecEmbed(Base):
|
||||
img_b64s = None
|
||||
try:
|
||||
base64.b64decode(text, validate=True)
|
||||
img_b64s = text.decode('utf8')
|
||||
img_b64s = text.decode("utf8")
|
||||
except Exception:
|
||||
img_b64s = base64.b64encode(text).decode('utf8')
|
||||
img_b64s = base64.b64encode(text).decode("utf8")
|
||||
input.append({"image": img_b64s}) # base64 encoded image
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "input": input[i : i + batch_size]}
|
||||
@ -380,20 +380,20 @@ class JinaMultiVecEmbed(Base):
|
||||
data["return_multivector"] = True
|
||||
|
||||
if "v3" in self.model_name or "v4" in self.model_name:
|
||||
data['task'] = task
|
||||
data['truncate'] = True
|
||||
data["task"] = task
|
||||
data["truncate"] = True
|
||||
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
for d in res['data']:
|
||||
if data.get("return_multivector", False): # v4
|
||||
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
|
||||
for d in res["data"]:
|
||||
if data.get("return_multivector", False): # v4
|
||||
token_embs = np.asarray(d["embeddings"], dtype=np.float32)
|
||||
chunk_emb = token_embs.mean(axis=0)
|
||||
|
||||
else:
|
||||
# v2/v3
|
||||
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
|
||||
chunk_emb = np.asarray(d["embedding"], dtype=np.float32)
|
||||
|
||||
ress.append(chunk_emb)
|
||||
|
||||
@ -444,6 +444,7 @@ class MistralEmbed(Base):
|
||||
def encode_queries(self, text):
|
||||
import time
|
||||
import random
|
||||
|
||||
retry_max = 5
|
||||
while retry_max > 0:
|
||||
try:
|
||||
@ -462,6 +463,7 @@ class BedrockEmbed(Base):
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
import boto3
|
||||
|
||||
# `key` protocol (backend stores as JSON string in `api_key`):
|
||||
# - Must decode into a dict.
|
||||
# - Required: `auth_mode`, `bedrock_region`.
|
||||
@ -497,10 +499,9 @@ class BedrockEmbed(Base):
|
||||
aws_secret_access_key=creds["SecretAccessKey"],
|
||||
aws_session_token=creds["SessionToken"],
|
||||
)
|
||||
else: # assume_role
|
||||
else: # assume_role
|
||||
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
||||
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
embeddings = []
|
||||
@ -1038,6 +1039,7 @@ class GiteeEmbed(SILICONFLOWEmbed):
|
||||
base_url = "https://ai.gitee.com/v1/embeddings"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class DeepInfraEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "DeepInfra"
|
||||
|
||||
@ -1064,6 +1066,7 @@ class CometAPIEmbed(OpenAIEmbed):
|
||||
base_url = "https://api.cometapi.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class DeerAPIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "DeerAPI"
|
||||
|
||||
@ -1081,16 +1084,90 @@ class JiekouAIEmbed(OpenAIEmbed):
|
||||
base_url = "https://api.jiekou.ai/openai/v1/embeddings"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class RAGconEmbed(OpenAIEmbed):
    """RAGcon embedding provider — routes requests through a LiteLLM proxy.

    Thin wrapper over :class:`OpenAIEmbed` that only pins the default
    endpoint; all encoding logic is inherited.

    NOTE(review): the previous docstring advertised
    ``https://connect.ragcon.ai/v1`` while the code defaults to
    ``https://connect.ragcon.com/v1``. The code's value is kept here —
    confirm the correct host with the provider.
    Fixed: the constructor previously called ``super().__init__`` twice,
    constructing the underlying client two times.
    """

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name="text-embedding-3-small", base_url=None):
        # Fall back to the proxy's public endpoint when no URL is given.
        if not base_url:
            base_url = "https://connect.ragcon.com/v1"
        super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class PerplexityEmbed(Base):
    """Perplexity embeddings provider.

    Perplexity's API returns embeddings encoded as ``base64_int8`` (a base64
    string wrapping raw int8 bytes), not OpenAI-style float lists, so this
    provider issues raw ``requests`` calls instead of using the OpenAI SDK.

    Two endpoints exist and are auto-selected from the model name:
      * ``/v1/embeddings`` — standard models (``pplx-embed-v1-*``);
        response is a flat list of embeddings.
      * ``/v1/contextualizedembeddings`` — contextualized models
        (``pplx-embed-context-v1-*``); input is a list of documents, each a
        list of chunks, and the response nests one embedding list per document.

    Fixed: the two endpoint branches previously duplicated the entire
    batch/request/decode/error loop; the loop is now shared, with only the
    URL, payload shape, and response traversal varying.
    """

    _FACTORY_NAME = "Perplexity"

    def __init__(self, key, model_name="pplx-embed-v1-0.6b", base_url="https://api.perplexity.ai"):
        # An explicitly empty/None base_url falls back to the public host.
        if not base_url:
            base_url = "https://api.perplexity.ai"
        self.base_url = base_url.rstrip("/")
        self.api_key = key
        self.model_name = model_name
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _decode_base64_int8(b64_str):
        """Decode a base64-encoded int8 vector into a float32 numpy array."""
        raw = base64.b64decode(b64_str)
        return np.frombuffer(raw, dtype=np.int8).astype(np.float32)

    def _is_contextualized(self):
        """Contextualized models carry 'context' in the model name."""
        return "context" in self.model_name

    def encode(self, texts: list):
        """Embed *texts*; returns (np.ndarray of float32 vectors, token count).

        Texts are sent in batches of 512. Raises ``Exception`` with the raw
        response text when a response cannot be parsed/decoded.
        """
        batch_size = 512
        ress = []
        token_count = 0

        contextual = self._is_contextualized()
        endpoint = "contextualizedembeddings" if contextual else "embeddings"
        url = f"{self.base_url}/v1/{endpoint}"

        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            payload = {
                "model": self.model_name,
                # The contextualized endpoint expects one chunk-list per
                # document; we send each text as its own single-chunk doc.
                "input": [[chunk] for chunk in batch] if contextual else batch,
                "encoding_format": "base64_int8",
            }
            response = requests.post(url, headers=self.headers, json=payload)
            try:
                res = response.json()
                if contextual:
                    # Nested shape: data -> per-document -> per-chunk embedding.
                    for doc in res["data"]:
                        for chunk_emb in doc["data"]:
                            ress.append(self._decode_base64_int8(chunk_emb["embedding"]))
                else:
                    for d in res["data"]:
                        ress.append(self._decode_base64_int8(d["embedding"]))
                token_count += res.get("usage", {}).get("total_tokens", 0)
            except Exception as _e:
                log_exception(_e, response)
                raise Exception(f"Error: {response.text}")

        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed a single query string; returns (1-D vector, token count)."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
|
||||
|
||||
0
test/unit_test/rag/llm/__init__.py
Normal file
0
test/unit_test/rag/llm/__init__.py
Normal file
61
test/unit_test/rag/llm/conftest.py
Normal file
61
test/unit_test/rag/llm/conftest.py
Normal file
@ -0,0 +1,61 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Prevent rag.llm.__init__ from running its heavy auto-discovery loop.
|
||||
|
||||
The __init__.py dynamically imports ALL model modules (chat_model,
|
||||
cv_model, ocr_model, etc.), which pull in deepdoc, xgboost, torch,
|
||||
and other heavy native deps. We pre-install a lightweight stub for
|
||||
the rag.llm package so that `from rag.llm.embedding_model import X`
|
||||
works without triggering the full init.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Resolve the real path to rag/llm/ so sub-module imports can find files
|
||||
_RAGFLOW_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
_RAG_LLM_DIR = os.path.join(_RAGFLOW_ROOT, "rag", "llm")
|
||||
|
||||
|
||||
def _install_rag_llm_stub():
|
||||
"""Replace rag.llm with a minimal package stub if not yet loaded.
|
||||
|
||||
The stub has __path__ pointing to the real rag/llm/ directory so that
|
||||
`from rag.llm.embedding_model import X` resolves to the actual file,
|
||||
but the __init__.py auto-discovery loop is skipped.
|
||||
"""
|
||||
if "rag.llm" in sys.modules:
|
||||
return
|
||||
|
||||
# Create a stub rag.llm package that does NOT run the real __init__
|
||||
llm_pkg = types.ModuleType("rag.llm")
|
||||
llm_pkg.__path__ = [_RAG_LLM_DIR]
|
||||
llm_pkg.__package__ = "rag.llm"
|
||||
# Provide empty dicts for the mappings the real __init__ would build
|
||||
llm_pkg.EmbeddingModel = {}
|
||||
llm_pkg.ChatModel = {}
|
||||
llm_pkg.CvModel = {}
|
||||
llm_pkg.RerankModel = {}
|
||||
llm_pkg.Seq2txtModel = {}
|
||||
llm_pkg.TTSModel = {}
|
||||
llm_pkg.OcrModel = {}
|
||||
sys.modules["rag.llm"] = llm_pkg
|
||||
|
||||
|
||||
_install_rag_llm_stub()
|
||||
250
test/unit_test/rag/llm/test_perplexity_embed.py
Normal file
250
test/unit_test/rag/llm/test_perplexity_embed.py
Normal file
@ -0,0 +1,250 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from rag.llm.embedding_model import PerplexityEmbed
|
||||
|
||||
|
||||
def _make_b64_int8(values):
    """Helper: encode a list of int8 values to a base64 string."""
    packed = np.asarray(values, dtype=np.int8).tobytes()
    return base64.b64encode(packed).decode()
|
||||
|
||||
|
||||
def _mock_standard_response(embeddings_b64, total_tokens=10):
    """Build a mock JSON response for the standard embeddings endpoint."""
    items = [{"object": "embedding", "index": idx, "embedding": emb} for idx, emb in enumerate(embeddings_b64)]
    return {
        "object": "list",
        "data": items,
        "model": "pplx-embed-v1-0.6b",
        "usage": {"total_tokens": total_tokens},
    }
|
||||
|
||||
|
||||
def _mock_contextualized_response(docs_embeddings_b64, total_tokens=20):
    """Build a mock JSON response for the contextualized embeddings endpoint."""
    docs = [
        {
            "index": doc_idx,
            "data": [{"object": "embedding", "index": chunk_idx, "embedding": emb} for chunk_idx, emb in enumerate(chunks)],
        }
        for doc_idx, chunks in enumerate(docs_embeddings_b64)
    ]
    return {
        "object": "list",
        "data": docs,
        "model": "pplx-embed-context-v1-0.6b",
        "usage": {"total_tokens": total_tokens},
    }
|
||||
|
||||
|
||||
class TestPerplexityEmbedInit:
    """Constructor behavior: base-URL normalization and auth header."""

    def test_default_base_url(self):
        provider = PerplexityEmbed("test-key", "pplx-embed-v1-0.6b")
        assert provider.base_url == "https://api.perplexity.ai"
        assert provider.api_key == "test-key"
        assert provider.model_name == "pplx-embed-v1-0.6b"

    def test_custom_base_url(self):
        # A trailing slash must be stripped from custom URLs.
        provider = PerplexityEmbed("key", "pplx-embed-v1-4b", base_url="https://custom.api.com/")
        assert provider.base_url == "https://custom.api.com"

    def test_empty_base_url_uses_default(self):
        provider = PerplexityEmbed("key", "pplx-embed-v1-0.6b", base_url="")
        assert provider.base_url == "https://api.perplexity.ai"

    def test_auth_header(self):
        provider = PerplexityEmbed("my-secret-key", "pplx-embed-v1-0.6b")
        assert provider.headers["Authorization"] == "Bearer my-secret-key"
|
||||
|
||||
|
||||
class TestPerplexityEmbedModelDetection:
    """Endpoint selection is driven purely by the 'context' model-name substring."""

    def test_standard_model_not_contextualized(self):
        assert not PerplexityEmbed("key", "pplx-embed-v1-0.6b")._is_contextualized()

    def test_standard_4b_not_contextualized(self):
        assert not PerplexityEmbed("key", "pplx-embed-v1-4b")._is_contextualized()

    def test_contextualized_0_6b(self):
        assert PerplexityEmbed("key", "pplx-embed-context-v1-0.6b")._is_contextualized()

    def test_contextualized_4b(self):
        assert PerplexityEmbed("key", "pplx-embed-context-v1-4b")._is_contextualized()
|
||||
|
||||
|
||||
class TestDecodeBase64Int8:
    """_decode_base64_int8 must round-trip int8 payloads to float32 arrays."""

    def test_basic_decode(self):
        src = [-1, 0, 1, 127]
        decoded = PerplexityEmbed._decode_base64_int8(_make_b64_int8(src))
        np.testing.assert_array_equal(decoded, np.array(src, dtype=np.float32))

    def test_empty_decode(self):
        # Decoding an empty payload yields a zero-length vector.
        decoded = PerplexityEmbed._decode_base64_int8(base64.b64encode(b"").decode())
        assert len(decoded) == 0

    def test_full_range(self):
        src = list(range(-128, 128))
        decoded = PerplexityEmbed._decode_base64_int8(_make_b64_int8(src))
        np.testing.assert_array_equal(decoded, np.array(src, dtype=np.float32))

    def test_output_dtype_is_float32(self):
        decoded = PerplexityEmbed._decode_base64_int8(_make_b64_int8([1, 2, 3]))
        assert decoded.dtype == np.float32
|
||||
|
||||
|
||||
class TestPerplexityEmbedStandardEncode:
    """encode() against the standard /v1/embeddings endpoint (mocked HTTP)."""

    @staticmethod
    def _response_with(payload):
        # Build a fake requests.Response whose .json() yields *payload*.
        fake = MagicMock()
        fake.json.return_value = payload
        return fake

    @patch("rag.llm.embedding_model.requests.post")
    def test_encode_single_text(self, mock_post):
        mock_post.return_value = self._response_with(_mock_standard_response([_make_b64_int8([10, 20, 30])], total_tokens=5))

        provider = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
        vectors, tokens = provider.encode(["hello"])

        assert vectors.shape == (1, 3)
        np.testing.assert_array_equal(vectors[0], np.array([10, 20, 30], dtype=np.float32))
        assert tokens == 5
        mock_post.assert_called_once()
        assert mock_post.call_args[0][0] == "https://api.perplexity.ai/v1/embeddings"

    @patch("rag.llm.embedding_model.requests.post")
    def test_encode_multiple_texts(self, mock_post):
        batch = [_make_b64_int8(pair) for pair in ([1, 2], [3, 4], [5, 6])]
        mock_post.return_value = self._response_with(_mock_standard_response(batch, total_tokens=15))

        provider = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
        vectors, tokens = provider.encode(["a", "b", "c"])

        assert vectors.shape == (3, 2)
        assert tokens == 15

    @patch("rag.llm.embedding_model.requests.post")
    def test_encode_sends_correct_payload(self, mock_post):
        mock_post.return_value = self._response_with(_mock_standard_response([_make_b64_int8([1])], total_tokens=1))

        PerplexityEmbed("key", "pplx-embed-v1-4b").encode(["test text"])

        payload = mock_post.call_args[1]["json"]
        assert payload["model"] == "pplx-embed-v1-4b"
        assert payload["input"] == ["test text"]
        assert payload["encoding_format"] == "base64_int8"

    @patch("rag.llm.embedding_model.requests.post")
    def test_encode_api_error_raises(self, mock_post):
        # A response whose body is not JSON must surface as an Exception.
        broken = MagicMock()
        broken.json.side_effect = Exception("Invalid JSON")
        broken.text = "Internal Server Error"
        mock_post.return_value = broken

        provider = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
        with pytest.raises(Exception, match="Error"):
            provider.encode(["hello"])
|
||||
|
||||
|
||||
class TestPerplexityEmbedContextualizedEncode:
    """encode() against the /v1/contextualizedembeddings endpoint (mocked HTTP)."""

    @patch("rag.llm.embedding_model.requests.post")
    def test_contextualized_encode(self, mock_post):
        fake = MagicMock()
        fake.json.return_value = _mock_contextualized_response([[_make_b64_int8([10, 20])], [_make_b64_int8([30, 40])]], total_tokens=12)
        mock_post.return_value = fake

        provider = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b")
        vectors, tokens = provider.encode(["chunk1", "chunk2"])

        assert vectors.shape == (2, 2)
        np.testing.assert_array_equal(vectors[0], np.array([10, 20], dtype=np.float32))
        np.testing.assert_array_equal(vectors[1], np.array([30, 40], dtype=np.float32))
        assert tokens == 12

    @patch("rag.llm.embedding_model.requests.post")
    def test_contextualized_uses_correct_endpoint(self, mock_post):
        fake = MagicMock()
        fake.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)
        mock_post.return_value = fake

        PerplexityEmbed("key", "pplx-embed-context-v1-4b").encode(["chunk"])

        assert mock_post.call_args[0][0] == "https://api.perplexity.ai/v1/contextualizedembeddings"

    @patch("rag.llm.embedding_model.requests.post")
    def test_contextualized_sends_nested_input(self, mock_post):
        fake = MagicMock()
        fake.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)
        mock_post.return_value = fake

        PerplexityEmbed("key", "pplx-embed-context-v1-0.6b").encode(["text1"])

        # Each text must be wrapped as its own single-chunk document.
        payload = mock_post.call_args[1]["json"]
        assert payload["input"] == [["text1"]]
        assert payload["model"] == "pplx-embed-context-v1-0.6b"
|
||||
|
||||
|
||||
class TestPerplexityEmbedEncodeQueries:
    """encode_queries() must return a flat 1-D vector plus the token count."""

    @patch("rag.llm.embedding_model.requests.post")
    def test_encode_queries_returns_single_vector(self, mock_post):
        fake = MagicMock()
        fake.json.return_value = _mock_standard_response([_make_b64_int8([5, 10, 15, 20])], total_tokens=3)
        mock_post.return_value = fake

        provider = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
        vector, tokens = provider.encode_queries("search query")

        assert vector.shape == (4,)
        np.testing.assert_array_equal(vector, np.array([5, 10, 15, 20], dtype=np.float32))
        assert tokens == 3
|
||||
|
||||
|
||||
class TestPerplexityEmbedFactoryRegistration:
    """Registration metadata consumed by the provider factory."""

    def test_factory_name(self):
        assert PerplexityEmbed._FACTORY_NAME == "Perplexity"

    def test_is_subclass_of_base(self):
        from rag.llm.embedding_model import Base

        assert issubclass(PerplexityEmbed, Base)
|
||||
@ -65,6 +65,7 @@ export enum LLMFactory {
|
||||
N1n = 'n1n',
|
||||
Avian = 'Avian',
|
||||
RAGcon = 'RAGcon',
|
||||
Perplexity = 'Perplexity',
|
||||
}
|
||||
|
||||
// Please lowercase the file name
|
||||
@ -135,6 +136,7 @@ export const IconMap = {
|
||||
[LLMFactory.N1n]: 'n1n',
|
||||
[LLMFactory.Avian]: 'avian',
|
||||
[LLMFactory.RAGcon]: 'ragcon',
|
||||
[LLMFactory.Perplexity]: 'perplexity',
|
||||
};
|
||||
|
||||
export const APIMapUrl = {
|
||||
@ -188,4 +190,6 @@ export const APIMapUrl = {
|
||||
[LLMFactory.PaddleOCR]: 'https://www.paddleocr.ai/latest/',
|
||||
[LLMFactory.N1n]: 'https://docs.n1n.ai',
|
||||
[LLMFactory.Avian]: 'https://avian.io',
|
||||
[LLMFactory.Perplexity]:
|
||||
'https://docs.perplexity.ai/docs/embeddings/quickstart',
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user