From 8ff623fbc44e92e3faf32ae392e0ff7c2c8ded5f Mon Sep 17 00:00:00 2001 From: Jack Storment <88656337+jack-stormentswe@users.noreply.github.com> Date: Mon, 11 May 2026 06:50:15 +0200 Subject: [PATCH] Go: implement Encode (embeddings) in Ollama driver (#14664) ### What problem does this PR solve? The Ollama Go driver shipped with a stub \`Encode\` method that returned \`no such method\`, even though Ollama is one of the most common local LLM runners and exposes an OpenAI-compatible embeddings endpoint at \`/v1/embeddings\`. Ollama users routinely run local embedding models such as \`nomic-embed-text\`, \`mxbai-embed-large\`, or \`bge-m3\`. Pulled with \`ollama pull \` and served on the same \`/v1\` namespace as chat. The existing \`ListModels\` already discovers them, but because \`Encode\` was a stub, a tenant who picked one of these models in the Go layer could not actually run an embedding call. ### What this PR includes - \`conf/models/ollama.json\`: add \`\"embedding\": \"embeddings\"\` under \`url_suffix\` so the driver can build the URL from config. - \`internal/entity/models/ollama.go\`: replace the \`Encode\` stub with a real implementation. Adds a small local response type that matches the OpenAI-compatible shape. No factory change. No interface change. ### How the driver works - Validate the model name. The API key is optional for local Ollama, so the Authorization header is only set when both \`apiConfig\` and \`ApiKey\` are non-nil and non-empty, the same pattern the recently merged CheckConnection PR (#14614) uses. - Resolve the region with a default fallback. Return a clear "missing base URL" error when the user has not configured the local access address yet. - Use a per-call \`context.WithTimeout(30s)\` and \`http.NewRequestWithContext\`, the same pattern the merged Aliyun Encode (#14647) uses. - Send \`{model, input: [texts]}\` in one request. - Parse \`data[*].embedding\` and copy each slice into a \`[][]float64\` indexed by \`data[*].index\`, so the output order matches the input order. - Handle both \`float64\` and \`float32\` element types. - Empty input returns \`[][]float64{}\` with no HTTP call. - Length mismatch between input and result, out-of-range index, and any missing slot all return clear errors instead of silent zero vectors. ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - \`go build ./internal/entity/models/...\` in a clean go 1.25 image returns exit 0. - The full method set on \`OllamaModel\` still matches the \`ModelDriver\` interface. - Pattern parity with the merged Aliyun Encode (#14647) and the existing SiliconFlow Encode. Closes #14662 --- conf/models/ollama.json | 3 +- internal/entity/models/factory.go | 2 + internal/entity/models/ollama.go | 108 +++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 2 deletions(-) diff --git a/conf/models/ollama.json b/conf/models/ollama.json index ed0a1e011..58adb17ef 100644 --- a/conf/models/ollama.json +++ b/conf/models/ollama.json @@ -2,7 +2,8 @@ "name": "ollama", "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings" }, "class": "local" } \ No newline at end of file diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index 8475049c5..1c0de11c6 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -57,6 +57,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewXAIModel(baseURL, urlSuffix), nil case "lmstudio": return NewLmStudioModel(baseURL, urlSuffix), nil + case "ollama": + return NewOllamaModel(baseURL, urlSuffix), nil case "openai": return NewOpenAIModel(baseURL, urlSuffix), nil case "nvidia": diff --git a/internal/entity/models/ollama.go b/internal/entity/models/ollama.go index 4e8e42ad0..3b22039c3 100644 --- a/internal/entity/models/ollama.go +++ b/internal/entity/models/ollama.go @@ -3,6 +3,7 @@ package models import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -359,8 +360,113 @@ func (o *OllamaModel) ChatStreamlyWithSender(modelName string, messages []Messag return scanner.Err() } +type ollamaEmbeddingResponse struct { + Data []struct { + Index int `json:"index"` + Embedding []interface{} `json:"embedding"` + } `json:"data"` +} + func (o *OllamaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("no such method") + if len(texts) == 0 { + return [][]float64{}, nil + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := o.BaseURL[region] + if baseURL == "" { + baseURL = o.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434/v1)") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), o.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Ollama embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed ollamaEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if len(parsed.Data) != len(texts) { + return nil, fmt.Errorf("ollama embeddings: expected %d results, got %d", len(texts), len(parsed.Data)) + } + + embeddings := make([][]float64, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("unexpected embedding index %d for %d inputs", item.Index, len(texts)) + } + vec := make([]float64, len(item.Embedding)) + for j, v := range item.Embedding { + switch val := v.(type) { + case float64: + vec[j] = val + case float32: + vec[j] = float64(val) + default: + return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j) + } + } + embeddings[item.Index] = vec + } + + for i, vec := range embeddings { + if vec == nil { + return nil, fmt.Errorf("missing embedding for input at index %d", i) + } + } + + return embeddings, nil } func (o *OllamaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {