mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-08 08:07:21 +08:00
## Summary - Replaces the `"no such method"` stub on `NvidiaModel.Rerank` (`internal/entity/models/nvidia.go`) with a real implementation against NVIDIA NIM's `/ranking` endpoint. - Mirrors the existing Python `NvidiaRerank` class at `rag/llm/rerank_model.py:149-190` for behavior parity: same `passages`/`query.text`/`logit` payload shape; `top_n` set to `len(documents)` so every input gets a score returned in original order (the issue body's spec omitted `top_n`, which would cause silent data loss). - Adds the `"rerank": "ranking"` URL suffix and two NIM rerank model entries (`nvidia/nv-rerankqa-mistral-4b-v3`, `nvidia/llama-3.2-nv-rerankqa-1b-v2`) to `conf/models/nvidia.json` so the picker exposes them. - Follows the same shape as the recently merged Aliyun (#14676), Gitee (#14656), and ZhipuAI (#14608) Rerank implementations: lowercase per-driver request/response types, conversion to the project-wide `RerankResponse{Data: []RerankResult}`, per-call `context.WithTimeout` of 30s. Closes #14720 ## Test plan - [x] `gofmt -l internal/entity/models/nvidia.go` — clean - [x] `go vet ./internal/entity/models/...` — no new errors introduced (the two pre-existing vet errors in `baidu.go:642` and `openrouter.go:566` are unrelated to this PR) - [x] `go build ./internal/entity/models/...` — succeeds - [x] `python3 -c "import json; json.load(open('conf/models/nvidia.json'))"` — JSON valid - [ ] Live smoke test against NVIDIA NIM with a real API key (requires reviewer with NIM credentials) ## Notes for reviewers - The issue body suggested omitting `top_n`. The Python reference includes it (`top_n: len(texts)`), and without it NVIDIA returns only the default top-K rankings rather than scores for every input. This PR follows the Python. - The URL host is `integrate.api.nvidia.com` (kept consistent with the existing chat/embeddings BaseURL in `nvidia.go`), not the legacy `ai.api.nvidia.com` host the Python uses. NIM's unified endpoint accepts the model names as-is, so no per-model URL transform is needed.
196 lines
6.3 KiB
Go
196 lines
6.3 KiB
Go
package models
|
|
|
|
import (
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func newNvidiaRerankServer(t *testing.T, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
|
|
t.Helper()
|
|
// Use t.Errorf + return inside the handler goroutine; t.Fatalf would
|
|
// only Goexit the handler goroutine and the test would silently pass.
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
t.Errorf("expected POST, got %s", r.Method)
|
|
return
|
|
}
|
|
if r.URL.Path != "/ranking" {
|
|
t.Errorf("expected path=/ranking, got %s", r.URL.Path)
|
|
return
|
|
}
|
|
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
|
t.Errorf("expected Authorization=Bearer test-key, got %q", got)
|
|
return
|
|
}
|
|
if got := r.Header.Get("Content-Type"); got != "application/json" {
|
|
t.Errorf("expected Content-Type=application/json, got %q", got)
|
|
return
|
|
}
|
|
raw, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
t.Errorf("failed to read body: %v", err)
|
|
return
|
|
}
|
|
var body map[string]interface{}
|
|
if err := json.Unmarshal(raw, &body); err != nil {
|
|
t.Errorf("invalid JSON body: %v\n%s", err, string(raw))
|
|
return
|
|
}
|
|
handler(t, body, w)
|
|
}))
|
|
}
|
|
|
|
func newNvidiaModelForTest(baseURL string) *NvidiaModel {
|
|
return NewNvidiaModel(
|
|
map[string]string{"default": baseURL},
|
|
URLSuffix{Rerank: "ranking"},
|
|
)
|
|
}
|
|
|
|
func TestNvidiaRerankHappyPath(t *testing.T) {
|
|
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
if body["model"] != "nvidia/nv-rerankqa-mistral-4b-v3" {
|
|
t.Errorf("expected model=nvidia/nv-rerankqa-mistral-4b-v3, got %v", body["model"])
|
|
}
|
|
query, ok := body["query"].(map[string]interface{})
|
|
if !ok || query["text"] != "What is RAPTOR?" {
|
|
t.Errorf("expected query.text=What is RAPTOR?, got %v", body["query"])
|
|
}
|
|
passages, ok := body["passages"].([]interface{})
|
|
if !ok || len(passages) != 3 {
|
|
t.Errorf("expected 3 passages, got %v", body["passages"])
|
|
return
|
|
}
|
|
if body["truncate"] != "END" {
|
|
t.Errorf("expected truncate=END, got %v", body["truncate"])
|
|
}
|
|
if body["top_n"] != float64(3) {
|
|
t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"])
|
|
}
|
|
// Return rankings out of input order to verify Index preservation.
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"rankings": []map[string]interface{}{
|
|
{"index": 2, "logit": 9.5},
|
|
{"index": 0, "logit": 4.25},
|
|
{"index": 1, "logit": 7.8},
|
|
},
|
|
})
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newNvidiaModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
|
resp, err := model.Rerank(
|
|
&modelName,
|
|
"What is RAPTOR?",
|
|
[]string{"doc-zero", "doc-one", "doc-two"},
|
|
&APIConfig{ApiKey: &apiKey},
|
|
&RerankConfig{},
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Rerank failed: %v", err)
|
|
}
|
|
if len(resp.Data) != 3 {
|
|
t.Fatalf("expected 3 results, got %d", len(resp.Data))
|
|
}
|
|
want := map[int]float64{0: 4.25, 1: 7.8, 2: 9.5}
|
|
for _, r := range resp.Data {
|
|
if got, ok := want[r.Index]; !ok || got != r.RelevanceScore {
|
|
t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNvidiaRerankTopNClamp(t *testing.T) {
|
|
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
if body["top_n"] != float64(2) {
|
|
t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", body["top_n"])
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{"rankings": []map[string]interface{}{}})
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newNvidiaModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
|
if _, err := model.Rerank(
|
|
&modelName, "q",
|
|
[]string{"a", "b", "c", "d"},
|
|
&APIConfig{ApiKey: &apiKey},
|
|
&RerankConfig{TopN: 2},
|
|
); err != nil {
|
|
t.Fatalf("Rerank failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNvidiaRerankEmptyDocuments(t *testing.T) {
|
|
model := newNvidiaModelForTest("http://unused")
|
|
apiKey := "test-key"
|
|
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
|
resp, err := model.Rerank(&modelName, "q", nil, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err != nil {
|
|
t.Fatalf("expected nil error for empty documents, got %v", err)
|
|
}
|
|
if len(resp.Data) != 0 {
|
|
t.Errorf("expected empty Data, got %d entries", len(resp.Data))
|
|
}
|
|
}
|
|
|
|
func TestNvidiaRerankRequiresAPIKey(t *testing.T) {
|
|
model := newNvidiaModelForTest("http://unused")
|
|
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
|
_, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{})
|
|
if err == nil || !strings.Contains(err.Error(), "api key is required") {
|
|
t.Errorf("expected api-key error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNvidiaRerankRequiresModelName(t *testing.T) {
|
|
model := newNvidiaModelForTest("http://unused")
|
|
apiKey := "test-key"
|
|
_, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err == nil || !strings.Contains(err.Error(), "model name is required") {
|
|
t.Errorf("expected model-name error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNvidiaRerankRejectsHTTPError(t *testing.T) {
|
|
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newNvidiaModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
|
_, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err == nil || !strings.Contains(err.Error(), "Nvidia rerank API error") {
|
|
t.Errorf("expected API error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNvidiaRerankRejectsOutOfRangeIndex(t *testing.T) {
|
|
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"rankings": []map[string]interface{}{
|
|
{"index": 5, "logit": 1.0}, // out of range for 2-input request
|
|
},
|
|
})
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newNvidiaModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
|
_, err := model.Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err == nil || !strings.Contains(err.Error(), "unexpected rerank index") {
|
|
t.Errorf("expected out-of-range error, got %v", err)
|
|
}
|
|
}
|