mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-22 17:08:23 +08:00
Go: implement Encode (embeddings) in Ollama driver (#14664)
### What problem does this PR solve? The Ollama Go driver shipped with a stub \`Encode\` method that returned \`no such method\`, even though Ollama is one of the most common local LLM runners and exposes an OpenAI-compatible embeddings endpoint at \`/v1/embeddings\`. Ollama users routinely run local embedding models such as \`nomic-embed-text\`, \`mxbai-embed-large\`, or \`bge-m3\`. Pulled with \`ollama pull <model>\` and served on the same \`/v1\` namespace as chat. The existing \`ListModels\` already discovers them, but because \`Encode\` was a stub, a tenant who picked one of these models in the Go layer could not actually run an embedding call. ### What this PR includes - \`conf/models/ollama.json\`: add \`\"embedding\": \"embeddings\"\` under \`url_suffix\` so the driver can build the URL from config. - \`internal/entity/models/ollama.go\`: replace the \`Encode\` stub with a real implementation. Adds a small local response type that matches the OpenAI-compatible shape. No factory change. No interface change. ### How the driver works - Validate the model name. The API key is optional for local Ollama, so the Authorization header is only set when both \`apiConfig\` and \`ApiKey\` are non-nil and non-empty, the same pattern the recently merged CheckConnection PR (#14614) uses. - Resolve the region with a default fallback. Return a clear "missing base URL" error when the user has not configured the local access address yet. - Use a per-call \`context.WithTimeout(30s)\` and \`http.NewRequestWithContext\`, the same pattern the merged Aliyun Encode (#14647) uses. - Send \`{model, input: [texts]}\` in one request. - Parse \`data[*].embedding\` and copy each slice into a \`[][]float64\` indexed by \`data[*].index\`, so the output order matches the input order. - Handle both \`float64\` and \`float32\` element types. - Empty input returns \`[][]float64{}\` with no HTTP call. - Length mismatch between input and result, out-of-range index, and any missing slot all return clear errors instead of silent zero vectors. ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - \`go build ./internal/entity/models/...\` in a clean go 1.25 image returns exit 0. - The full method set on \`OllamaModel\` still matches the \`ModelDriver\` interface. - Pattern parity with the merged Aliyun Encode (#14647) and the existing SiliconFlow Encode. Closes #14662
This commit is contained in:
@ -2,7 +2,8 @@
|
||||
"name": "ollama",
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models"
|
||||
"models": "models",
|
||||
"embedding": "embeddings"
|
||||
},
|
||||
"class": "local"
|
||||
}
|
||||
@ -57,6 +57,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
|
||||
return NewXAIModel(baseURL, urlSuffix), nil
|
||||
case "lmstudio":
|
||||
return NewLmStudioModel(baseURL, urlSuffix), nil
|
||||
case "ollama":
|
||||
return NewOllamaModel(baseURL, urlSuffix), nil
|
||||
case "openai":
|
||||
return NewOpenAIModel(baseURL, urlSuffix), nil
|
||||
case "nvidia":
|
||||
|
||||
@ -3,6 +3,7 @@ package models
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -359,8 +360,113 @@ func (o *OllamaModel) ChatStreamlyWithSender(modelName string, messages []Messag
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
type ollamaEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Index int `json:"index"`
|
||||
Embedding []interface{} `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (o *OllamaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := o.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = o.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434/v1)")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), o.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Ollama embeddings API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed ollamaEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(parsed.Data) != len(texts) {
|
||||
return nil, fmt.Errorf("ollama embeddings: expected %d results, got %d", len(texts), len(parsed.Data))
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, item := range parsed.Data {
|
||||
if item.Index < 0 || item.Index >= len(texts) {
|
||||
return nil, fmt.Errorf("unexpected embedding index %d for %d inputs", item.Index, len(texts))
|
||||
}
|
||||
vec := make([]float64, len(item.Embedding))
|
||||
for j, v := range item.Embedding {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
vec[j] = val
|
||||
case float32:
|
||||
vec[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j)
|
||||
}
|
||||
}
|
||||
embeddings[item.Index] = vec
|
||||
}
|
||||
|
||||
for i, vec := range embeddings {
|
||||
if vec == nil {
|
||||
return nil, fmt.Errorf("missing embedding for input at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (o *OllamaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
|
||||
Reference in New Issue
Block a user