mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-08 08:07:21 +08:00
Go: implement Rerank in LocalAI driver (#14813)
### What problem does this PR solve? The LocalAI Go driver landed in #14809 and Embed landed in #14811. `Rerank` was left as a stub that returns `"not implemented"`. This PR fills the gap. LocalAI exposes a public rerank endpoint at `<tenant-url>/v1/rerank` with a Cohere-shaped request and response (`{model, query, documents, top_n}` → `{results: [{index, relevance_score}]}`). The Python side has had `LocalAIRerank` in `rag/llm/rerank_model.py` for a long time. Until this PR, a tenant who wanted to use LocalAI for reranking in the Go layer got `"not implemented"`. ### What this PR includes - `conf/models/localai.json`: add `"rerank": "rerank"` under `url_suffix` so the driver can build the URL from config. This matches the `URLSuffix.Rerank` field already used by aliyun and siliconflow. - `internal/entity/models/localai.go`: replace the `Rerank` stub with a real implementation that POSTs to `/v1/rerank`. Adds local request/response types `localAIRerankRequest` and `localAIRerankResponse`. No factory change. No interface change. ### How the implementation works - Validate the model name and resolve the tenant-supplied base URL with the existing `resolveBaseURL` helper. - Wrap the request with `context.WithTimeout(nonStreamCallTimeout)` so the call has a clear deadline. Same pattern `ChatWithMessages`, `ListModels`, and `Embed` already use in this file. - Only set the `Authorization` header when a non-empty API key was supplied. LocalAI accepts an empty key by default, so this preserves the optional-auth contract. - Default `top_n` to `len(documents)` when `rerankConfig.TopN == 0`, matching the existing Aliyun and SiliconFlow rerank implementations. - Validate every `results[].index` against `len(documents)`. If the upstream returns an out-of-range index, fail clearly instead of silently writing past the slice. - An empty `documents` slice returns `&RerankResponse{}` with no HTTP call. - Non-200 responses propagate the upstream status line and body. 
### Note on stacking This PR builds on #14809 (LocalAI driver) and #14811 (LocalAI Embed). Until both merge, this PR's diff on GitHub will include all three commits. After #14809 and #14811 land on `main`, GitHub will auto-reduce this PR to only the `Rerank` changes (one commit, ~99 line diff in `localai.go` plus 1 line in `localai.json`). ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - `go build ./internal/entity/models/...` returns exit 0 on go 1.25 (the `go.mod` minimum). - The full method set on `LocalAIModel` still matches the `ModelDriver` interface. - Pattern parity with the existing Aliyun Rerank (`internal/entity/models/aliyun.go`) and SiliconFlow Rerank (`internal/entity/models/siliconflow.go`) implementations. Closes #14812 Depends on #14809, #14811 Tracking: #14736 Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
10
conf/models/localai.json
Normal file
10
conf/models/localai.json
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"name": "localai",
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models",
|
||||
"embedding": "embeddings",
|
||||
"rerank": "rerank"
|
||||
},
|
||||
"class": "local"
|
||||
}
|
||||
@ -83,6 +83,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
|
||||
return NewBaichuanModel(baseURL, urlSuffix), nil
|
||||
case "jina":
|
||||
return NewJinaModel(baseURL, urlSuffix), nil
|
||||
case "localai":
|
||||
return NewLocalAIModel(baseURL, urlSuffix), nil
|
||||
case "longcat":
|
||||
return NewLongCatModel(baseURL, urlSuffix), nil
|
||||
case "novita":
|
||||
|
||||
825
internal/entity/models/localai.go
Normal file
825
internal/entity/models/localai.go
Normal file
@ -0,0 +1,825 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// localAIStreamIdleTimeout bounds how long ChatStreamlyWithSender will
// wait between SSE lines before assuming the upstream has stalled and
// aborting the request. A local LLM normally emits at least one token
// every few seconds; 60s is generous enough to never break a working
// stream but tight enough to bound a worst-case mid-body hang.
//
// The watchdog in ChatStreamlyWithSender polls at a quarter of this
// value, so a stall is detected at the first poll at or after the timeout.
//
// It is a var (not a const) so tests can lower it without waiting a
// real minute.
var localAIStreamIdleTimeout = 60 * time.Second
|
||||
|
||||
// LocalAIModel implements ModelDriver for LocalAI, a self-hosted
// OpenAI-compatible inference server (https://localai.io).
//
// Unlike cloud providers, LocalAI runs on a tenant-supplied base URL
// (for example http://127.0.0.1:8080/v1). The driver therefore reads
// the base URL from the per-instance map at call time and does not
// assume a public endpoint. The API key is optional: LocalAI accepts
// an empty key by default, and the driver only sets the Authorization
// header when a non-empty key was supplied.
type LocalAIModel struct {
	// BaseURL maps a region name to the base URL to call. resolveBaseURL
	// looks up the requested region first and the "default" key second.
	BaseURL map[string]string
	// URLSuffix holds the per-endpoint path suffixes (Chat, Embedding,
	// Rerank, Models) that are joined onto the resolved base URL.
	URLSuffix URLSuffix
	// httpClient is the client every request goes through. It is built
	// in NewLocalAIModel without a client-level Timeout; callers bound
	// each request through its context instead.
	httpClient *http.Client
}
|
||||
|
||||
// NewLocalAIModel creates a new LocalAI model instance.
|
||||
//
|
||||
// We clone http.DefaultTransport so we keep Go's defaults for
|
||||
// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2,
|
||||
// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override
|
||||
// the connection-pool fields we care about.
|
||||
//
|
||||
// The Client itself has no Timeout. http.Client.Timeout would also
|
||||
// cap the time spent reading the response body, which would cut off
|
||||
// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming
|
||||
// callers wrap each request with context.WithTimeout instead.
|
||||
func NewLocalAIModel(baseURL map[string]string, urlSuffix URLSuffix) *LocalAIModel {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.MaxIdleConns = 100
|
||||
transport.MaxIdleConnsPerHost = 10
|
||||
transport.IdleConnTimeout = 90 * time.Second
|
||||
transport.DisableCompression = false
|
||||
transport.ResponseHeaderTimeout = 60 * time.Second
|
||||
|
||||
return &LocalAIModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LocalAIModel) NewInstance(baseURL map[string]string) ModelDriver {
|
||||
return NewLocalAIModel(baseURL, l.URLSuffix)
|
||||
}
|
||||
|
||||
func (l *LocalAIModel) Name() string {
|
||||
return "localai"
|
||||
}
|
||||
|
||||
// resolveBaseURL returns the tenant-supplied base URL for the given
|
||||
// region, falling back to the "default" entry, and fails with a clear
|
||||
// message when nothing is configured. LocalAI is self-hosted so the
|
||||
// driver cannot fall back to a public endpoint.
|
||||
func (l *LocalAIModel) resolveBaseURL(region string) (string, error) {
|
||||
if base, ok := l.BaseURL[region]; ok && base != "" {
|
||||
return strings.TrimSuffix(base, "/"), nil
|
||||
}
|
||||
if base, ok := l.BaseURL["default"]; ok && base != "" {
|
||||
return strings.TrimSuffix(base, "/"), nil
|
||||
}
|
||||
return "", fmt.Errorf("localai: missing base URL, configure the local access address (e.g., http://127.0.0.1:8080/v1)")
|
||||
}
|
||||
|
||||
// setAuth sets the Authorization header only when a non-empty API key
|
||||
// is supplied. LocalAI accepts an empty key by default, so sending
|
||||
// "Bearer " (with an empty value) would be wrong in both directions:
|
||||
// some local proxies reject it, and it leaks the fact that the
|
||||
// driver was misconfigured.
|
||||
func setLocalAIAuth(req *http.Request, apiConfig *APIConfig) {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
}
|
||||
|
||||
// localAIReasoningFields lists the JSON field names that different
// upstream models put their chain-of-thought into. LocalAI is a proxy
// that can route to any of these, so the driver tries each in turn:
//
//   - reasoning_content: OpenAI o-series, kimi-k2.6, DeepSeek-R1,
//     magistral when proxied through an OpenAI-shim
//   - reasoning: Upstage solar-pro3 (and its proxies)
//   - thinking: Qwen3 (Ollama-style) and some local llama-r1
//     variants exposed through LocalAI's OpenAI shim
//
// extractLocalAIReasoning walks this slice in order and returns the first
// non-empty string it finds, so the order is a priority order:
// reasoning_content is the OpenAI-conformant name and the most widely
// used, so it is tried first.
var localAIReasoningFields = []string{"reasoning_content", "reasoning", "thinking"}
|
||||
|
||||
// extractLocalAIReasoning pulls the chain-of-thought out of a message
|
||||
// or delta object regardless of which field name the upstream model
|
||||
// chose. Returns "" when no reasoning field is present or non-string.
|
||||
func extractLocalAIReasoning(m map[string]interface{}) string {
|
||||
for _, k := range localAIReasoningFields {
|
||||
if v, ok := m[k].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// addLocalAIReasoningRequestParams propagates the caller's request-side
|
||||
// reasoning intent into the body. Different upstream models behind
|
||||
// LocalAI accept different parameters:
|
||||
//
|
||||
// - reasoning_effort: OpenAI-compatible reasoning APIs (kimi, magistral,
|
||||
// solar-pro2/pro3, gpt-o-series, R1 proxies)
|
||||
// - enable_thinking: Qwen3 explicit thinking toggle
|
||||
//
|
||||
// Both are emitted when the caller opts in, so the request works
|
||||
// against whichever family the LocalAI instance routes to. A non-
|
||||
// supporting upstream simply ignores the extra field.
|
||||
func addLocalAIReasoningRequestParams(reqBody map[string]interface{}, cfg *ChatConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if cfg.Effort != nil && *cfg.Effort != "" {
|
||||
reqBody["reasoning_effort"] = *cfg.Effort
|
||||
}
|
||||
if cfg.Thinking != nil {
|
||||
reqBody["enable_thinking"] = *cfg.Thinking
|
||||
}
|
||||
}
|
||||
|
||||
// ChatWithMessages sends multiple messages with roles and returns the response.
|
||||
func (l *LocalAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("messages is empty")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := l.resolveBaseURL(region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat)
|
||||
|
||||
apiMessages := make([]map[string]interface{}, len(messages))
|
||||
for i, msg := range messages {
|
||||
apiMessages[i] = map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": apiMessages,
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
// Note: do NOT propagate chatModelConfig.Stream into the request body
|
||||
// here. ChatWithMessages parses a single JSON response, so stream must
|
||||
// always be off for this code path.
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
// LocalAI is a proxy; emit both reasoning_effort and
|
||||
// enable_thinking so the request works regardless of which
|
||||
// model family the LocalAI instance routes to. See
|
||||
// addLocalAIReasoningRequestParams.
|
||||
addLocalAIReasoningRequestParams(reqBody, chatModelConfig)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
setLocalAIAuth(req, apiConfig)
|
||||
|
||||
resp, err := l.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
choices, ok := result["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid choice format")
|
||||
}
|
||||
|
||||
messageMap, ok := firstChoice["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
content, ok := messageMap["content"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content format")
|
||||
}
|
||||
|
||||
// Pull the chain-of-thought from whichever field the upstream model
|
||||
// used. See localAIReasoningFields for the priority order.
|
||||
reasonContent := extractLocalAIReasoning(messageMap)
|
||||
|
||||
return &ChatResponse{
|
||||
Answer: &content,
|
||||
ReasonContent: &reasonContent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends messages to the chat endpoint with
// "stream": true and forwards the reply to sender as it arrives. The
// LocalAI SSE stream uses the same shape as OpenAI: "data:" lines carrying
// JSON events, with a final "[DONE]" line.
//
// sender is called with (nil, &reasoning) for a reasoning chunk and with
// (&content, nil) for an answer chunk; when one event carries both, the
// reasoning call comes first. After the upstream signals the end of the
// stream (a "[DONE]" line or a non-empty finish_reason), sender is called
// one last time with a pointer to the literal "[DONE]". Any error returned
// by sender stops the stream and is returned unchanged.
//
// Errors are also returned when sender is nil, messages is empty,
// chatModelConfig.Stream is explicitly false, no base URL is configured,
// the HTTP exchange fails or answers with a non-200 status, no data arrives
// for localAIStreamIdleTimeout, or the body ends without a terminal marker.
func (l *LocalAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
	if sender == nil {
		return fmt.Errorf("sender is required")
	}

	if len(messages) == 0 {
		return fmt.Errorf("messages is empty")
	}

	// Fall back to the "default" region unless the caller named one.
	region := "default"
	if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}

	baseURL, err := l.resolveBaseURL(region)
	if err != nil {
		return err
	}
	url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat)

	// Only role and content are forwarded for each message.
	apiMessages := make([]map[string]interface{}, len(messages))
	for i, msg := range messages {
		apiMessages[i] = map[string]interface{}{
			"role":    msg.Role,
			"content": msg.Content,
		}
	}

	reqBody := map[string]interface{}{
		"model":    modelName,
		"messages": apiMessages,
		"stream":   true,
	}

	if chatModelConfig != nil {
		// Refuse to run if the caller explicitly asked for stream=false.
		// The body of this method only knows how to read SSE, so a
		// non-SSE JSON response would be parsed as if it were a stream
		// and produce no chunks. Better to fail clearly.
		if chatModelConfig.Stream != nil && !*chatModelConfig.Stream {
			return fmt.Errorf("stream must be true in ChatStreamlyWithSender")
		}

		if chatModelConfig.MaxTokens != nil {
			reqBody["max_tokens"] = *chatModelConfig.MaxTokens
		}
		if chatModelConfig.Temperature != nil {
			reqBody["temperature"] = *chatModelConfig.Temperature
		}
		if chatModelConfig.TopP != nil {
			reqBody["top_p"] = *chatModelConfig.TopP
		}
		if chatModelConfig.Stop != nil {
			reqBody["stop"] = *chatModelConfig.Stop
		}
		// LocalAI is a proxy; emit both reasoning_effort and
		// enable_thinking so the streaming request works regardless of
		// which model family the LocalAI instance routes to.
		addLocalAIReasoningRequestParams(reqBody, chatModelConfig)
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	// SSE streams are long-lived, so we cannot attach a hard deadline:
	// a legitimate response may take many minutes to finish on a busy
	// local model. Instead, wrap the request with WithCancel and run
	// an idle watchdog below that calls cancel() if no new data has
	// arrived for localAIStreamIdleTimeout. That bounds the worst-case
	// stall to a known finite window without breaking working long
	// streams.
	//
	// Threading a real caller-supplied ctx through the ModelDriver
	// interface remains a wider follow-up; this is the contained fix.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	setLocalAIAuth(req, apiConfig)

	resp, err := l.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a failed read still reports the status, just
		// with whatever part of the body was received.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Idle watchdog: every successful Scan resets lastActive. If
	// localAIStreamIdleTimeout passes without a reset, the watchdog
	// calls cancel() on the request context. The blocking scanner.Scan()
	// is then expected to return false with the context error in
	// scanner.Err(), and we surface it to the caller instead of hanging
	// the goroutine forever.
	//
	// lastActive is shared between this goroutine and the watchdog, so
	// every access goes through lastActiveMu. Closing done (deferred, so
	// it happens on every return path) stops the watchdog.
	lastActive := time.Now()
	var lastActiveMu sync.Mutex
	done := make(chan struct{})
	defer close(done)
	go func() {
		// Poll at a quarter of the timeout.
		ticker := time.NewTicker(localAIStreamIdleTimeout / 4)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case now := <-ticker.C:
				lastActiveMu.Lock()
				idle := now.Sub(lastActive)
				lastActiveMu.Unlock()
				if idle >= localAIStreamIdleTimeout {
					cancel()
					return
				}
			}
		}
	}()

	// SSE parsing: start with a 64KB buffer and let it grow to 1MB so a
	// long data: line is accepted instead of failing at the scanner's
	// default limit. A line longer than 1MB still ends the scan with an
	// error, which is reported below.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 64*1024), 1024*1024)
	// sawTerminal records whether the upstream explicitly ended the
	// stream; a body that simply stops is treated as an error below.
	sawTerminal := false
	for scanner.Scan() {
		lastActiveMu.Lock()
		lastActive = time.Now()
		lastActiveMu.Unlock()

		line := scanner.Text()

		// Only "data:" lines carry events; every other line (blank
		// separators, comments, other SSE fields) is skipped.
		if !strings.HasPrefix(line, "data:") {
			continue
		}

		// 5 == len("data:").
		data := strings.TrimSpace(line[5:])

		if data == "[DONE]" {
			sawTerminal = true
			break
		}

		// Events that are not valid JSON, or that lack a first choice
		// with a delta object, are skipped rather than treated as fatal.
		var event map[string]interface{}
		if err = json.Unmarshal([]byte(data), &event); err != nil {
			continue
		}

		choices, ok := event["choices"].([]interface{})
		if !ok || len(choices) == 0 {
			continue
		}

		firstChoice, ok := choices[0].(map[string]interface{})
		if !ok {
			continue
		}

		delta, ok := firstChoice["delta"].(map[string]interface{})
		if !ok {
			continue
		}

		// Reasoning chunk first, content second. When an SSE event
		// carries both, callers that pipe them to a UI render the
		// chain-of-thought before the answer for that token, matching
		// the wire ordering Upstage solar-pro3 and kimi-k2.6 emit.
		// extractLocalAIReasoning tries reasoning_content, reasoning,
		// and thinking in that order so this works against whichever
		// model family LocalAI routes to.
		if reasoning := extractLocalAIReasoning(delta); reasoning != "" {
			if err := sender(nil, &reasoning); err != nil {
				return err
			}
		}

		content, ok := delta["content"].(string)
		if ok && content != "" {
			if err := sender(&content, nil); err != nil {
				return err
			}
		}

		// A non-empty finish_reason ends the stream even if no
		// "[DONE]" line follows. It is checked after the delta so a
		// final event carrying content is still delivered.
		finishReason, ok := firstChoice["finish_reason"].(string)
		if ok && finishReason != "" {
			sawTerminal = true
			break
		}
	}

	if err := scanner.Err(); err != nil {
		// If the watchdog fired, the context is done; surface that as
		// a clearer "idle" error instead of leaking the raw
		// "context canceled" string. The deferred cancel() has not run
		// yet at this point, so a done context here can only come from
		// the watchdog.
		if ctx.Err() != nil {
			return fmt.Errorf("localai: stream idle for more than %s, aborted", localAIStreamIdleTimeout)
		}
		return fmt.Errorf("failed to scan response body: %w", err)
	}
	if !sawTerminal {
		return fmt.Errorf("localai: stream ended before [DONE] or finish_reason")
	}

	// Tell the consumer the stream is complete.
	endOfStream := "[DONE]"
	if err := sender(&endOfStream, nil); err != nil {
		return err
	}

	return nil
}
|
||||
|
||||
// localAIEmbeddingData is one element of the "data" array in an
// embeddings response.
type localAIEmbeddingData struct {
	// Embedding is the vector for one input text.
	Embedding []float64 `json:"embedding"`
	// Object is decoded from the "object" field; Embed does not read it.
	Object string `json:"object"`
	// Index is the position of the input text this vector belongs to.
	// Embed uses it to place the vector in the output slice.
	Index int `json:"index"`
}
|
||||
|
||||
// localAIEmbeddingResponse is the JSON body Embed decodes from the
// embeddings endpoint.
type localAIEmbeddingResponse struct {
	// Data holds one entry per embedded input.
	Data []localAIEmbeddingData `json:"data"`
	// Model is decoded from the "model" field; Embed does not read it.
	Model string `json:"model"`
	// Object is decoded from the "object" field; Embed does not read it.
	Object string `json:"object"`
}
|
||||
|
||||
// Embed turns a list of texts into embedding vectors using the LocalAI
|
||||
// /v1/embeddings endpoint. The output has one vector per input, in the
|
||||
// same order the inputs were given.
|
||||
func (l *LocalAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
||||
if len(texts) == 0 {
|
||||
return []EmbeddingData{}, nil
|
||||
}
|
||||
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := l.resolveBaseURL(region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
setLocalAIAuth(req, apiConfig)
|
||||
|
||||
resp, err := l.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("LocalAI embeddings API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed localAIEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
// Reorder by the reported index so the output always lines up with
|
||||
// the input texts, even if the upstream API ever returns items out
|
||||
// of order. A nil slot at the end indicates the upstream did not
|
||||
// return an embedding for that input.
|
||||
embeddings := make([]EmbeddingData, len(texts))
|
||||
filled := make([]bool, len(texts))
|
||||
for _, item := range parsed.Data {
|
||||
if item.Index < 0 || item.Index >= len(texts) {
|
||||
return nil, fmt.Errorf("localai: response index %d out of range for %d inputs", item.Index, len(texts))
|
||||
}
|
||||
if filled[item.Index] {
|
||||
// A malformed response that repeats the same index would
|
||||
// silently overwrite the earlier vector. Fail loudly so
|
||||
// the caller never uses ambiguous output.
|
||||
return nil, fmt.Errorf("localai: duplicate embedding index %d in response", item.Index)
|
||||
}
|
||||
embeddings[item.Index] = EmbeddingData{
|
||||
Embedding: item.Embedding,
|
||||
Index: item.Index,
|
||||
}
|
||||
filled[item.Index] = true
|
||||
}
|
||||
for i, ok := range filled {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("localai: missing embedding for input index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// localAIRerankRequest is the JSON body Rerank posts to the rerank
// endpoint.
type localAIRerankRequest struct {
	// Model is the name of the rerank model to use.
	Model string `json:"model"`
	// Query is the text every document is scored against.
	Query string `json:"query"`
	// Documents are the candidate texts to score.
	Documents []string `json:"documents"`
	// TopN is sent as "top_n". Rerank fills it with the configured TopN,
	// or with len(Documents) when none is configured.
	TopN int `json:"top_n"`
}
|
||||
|
||||
// localAIRerankResponse is the part of the rerank endpoint's JSON reply
// that Rerank reads.
type localAIRerankResponse struct {
	// Results holds one entry per scored document.
	Results []struct {
		// Index is the position of the scored document in the request's
		// documents list.
		Index int `json:"index"`
		// RelevanceScore is the score the upstream assigned to that
		// document for the query.
		RelevanceScore float64 `json:"relevance_score"`
	} `json:"results"`
}
|
||||
|
||||
// Rerank calculates similarity scores between a query and a list of documents
|
||||
// using LocalAI's /v1/rerank endpoint. The response shape is Cohere-style:
|
||||
// {results: [{index, relevance_score}]}. The output is copied into the shared
|
||||
// RerankResponse{Data: []RerankResult{Index, RelevanceScore}} shape that the
|
||||
// rest of the codebase already consumes.
|
||||
func (l *LocalAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
if len(documents) == 0 {
|
||||
return &RerankResponse{}, nil
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := l.resolveBaseURL(region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Rerank)
|
||||
|
||||
topN := len(documents)
|
||||
if rerankConfig != nil && rerankConfig.TopN > 0 {
|
||||
topN = rerankConfig.TopN
|
||||
}
|
||||
|
||||
reqBody := localAIRerankRequest{
|
||||
Model: *modelName,
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
TopN: topN,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
setLocalAIAuth(req, apiConfig)
|
||||
|
||||
resp, err := l.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("LocalAI rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed localAIRerankResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
rerankResponse := &RerankResponse{}
|
||||
for _, r := range parsed.Results {
|
||||
if r.Index < 0 || r.Index >= len(documents) {
|
||||
return nil, fmt.Errorf("localai: rerank result index %d out of range for %d documents", r.Index, len(documents))
|
||||
}
|
||||
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
|
||||
Index: r.Index,
|
||||
RelevanceScore: r.RelevanceScore,
|
||||
})
|
||||
}
|
||||
|
||||
return rerankResponse, nil
|
||||
}
|
||||
|
||||
// ListModels returns the list of model ids the running LocalAI instance has
|
||||
// loaded. There is no fixed model list at the SaaS level because LocalAI is
|
||||
// self-hosted; the answer depends on what the tenant has configured.
|
||||
func (l *LocalAIModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := l.resolveBaseURL(region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Models)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
setLocalAIAuth(req, apiConfig)
|
||||
|
||||
resp, err := l.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
data, ok := result["data"].([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid models list format")
|
||||
}
|
||||
|
||||
models := make([]string, 0)
|
||||
for _, model := range data {
|
||||
modelMap, ok := model.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
modelName, ok := modelMap["id"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
models = append(models, modelName)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// Balance is not exposed by LocalAI (it is self-hosted and free), so this
|
||||
// returns "no such method".
|
||||
func (l *LocalAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// CheckConnection runs a lightweight ListModels call to verify the LocalAI
|
||||
// base URL is reachable.
|
||||
func (l *LocalAIModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
_, err := l.ListModels(apiConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TranscribeAudio (ASR): LocalAI can route audio to a Whisper backend
|
||||
// when one is loaded, but the wire shape and driver-side plumbing for
|
||||
// streaming audio uploads is separate from this PR's scope. Stub here
|
||||
// to satisfy the ModelDriver interface; follow-up PR welcome.
|
||||
func (l *LocalAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", l.Name())
|
||||
}
|
||||
|
||||
func (l *LocalAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", l.Name())
|
||||
}
|
||||
|
||||
// AudioSpeech (TTS): same story as TranscribeAudio above.
|
||||
func (l *LocalAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", l.Name())
|
||||
}
|
||||
|
||||
func (l *LocalAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", l.Name())
|
||||
}
|
||||
|
||||
// OCRFile: LocalAI has no OCR pipeline in its OpenAI-compatible surface;
|
||||
// document parsing belongs to a different interface entirely.
|
||||
func (l *LocalAIModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", l.Name())
|
||||
}
|
||||
626
internal/entity/models/localai_test.go
Normal file
626
internal/entity/models/localai_test.go
Normal file
@ -0,0 +1,626 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newLocalAIForTest(baseURL string) *LocalAIModel {
|
||||
return NewLocalAIModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{
|
||||
Chat: "chat/completions",
|
||||
Models: "models",
|
||||
Embedding: "embeddings",
|
||||
Rerank: "rerank",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// withLocalAIIdleTimeout swaps the package-level idle timeout for the
|
||||
// duration of the test. Tests that exercise the stall watchdog use a
|
||||
// sub-second value so they finish quickly.
|
||||
func withLocalAIIdleTimeout(t *testing.T, d time.Duration) {
|
||||
t.Helper()
|
||||
original := localAIStreamIdleTimeout
|
||||
localAIStreamIdleTimeout = d
|
||||
t.Cleanup(func() {
|
||||
localAIStreamIdleTimeout = original
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalAIName(t *testing.T) {
|
||||
l := newLocalAIForTest("http://unused")
|
||||
if got := l.Name(); got != "localai" {
|
||||
t.Errorf("Name()=%q, want %q", got, "localai")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIStreamCancelsOnIdle(t *testing.T) {
|
||||
// The server emits one valid chunk and then stalls. Without the
|
||||
// watchdog, scanner.Scan() would hang forever. With the watchdog
|
||||
// at 200ms, it must return a clear "stream idle" error in well
|
||||
// under a second.
|
||||
withLocalAIIdleTimeout(t, 200*time.Millisecond)
|
||||
|
||||
// hold blocks the handler until the test closes it. Register
|
||||
// close(hold) FIRST so it runs LAST (defers are LIFO) — wait,
|
||||
// that's the opposite. We want close(hold) to run BEFORE
|
||||
// srv.Close() so the handler can return. Use t.Cleanup, which
|
||||
// runs in reverse-registration order: register srv.Close first
|
||||
// so it runs last, then close(hold) so it runs first.
|
||||
hold := make(chan struct{})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
_, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"hi"}}]}`+"\n")
|
||||
f.Flush()
|
||||
}
|
||||
// Hang until either the client disconnects (watchdog cancels
|
||||
// the request, which causes r.Context() to fire) or the test
|
||||
// teardown signals via `hold`.
|
||||
select {
|
||||
case <-hold:
|
||||
case <-r.Context().Done():
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
t.Cleanup(func() { close(hold) })
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
var got []string
|
||||
var mu sync.Mutex
|
||||
start := time.Now()
|
||||
err := l.ChatStreamlyWithSender("gpt-4",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, nil,
|
||||
func(content *string, _ *string) error {
|
||||
if content == nil || *content == "" {
|
||||
return nil
|
||||
}
|
||||
mu.Lock()
|
||||
got = append(got, *content)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected an idle-timeout error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "idle for more than") {
|
||||
t.Errorf("expected idle-timeout error, got %v", err)
|
||||
}
|
||||
if elapsed > 5*time.Second {
|
||||
t.Errorf("watchdog did not fire promptly; elapsed=%v", elapsed)
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(got) == 0 || got[0] != "hi" {
|
||||
t.Errorf("expected first chunk before stall, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIStreamCompletesWithoutTriggeringWatchdog(t *testing.T) {
|
||||
// Sanity check: a fast, complete stream should not trip the
|
||||
// watchdog even with a moderately tight idle window.
|
||||
withLocalAIIdleTimeout(t, 500*time.Millisecond)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
f, _ := w.(http.Flusher)
|
||||
_, _ = io.WriteString(w,
|
||||
`data: {"choices":[{"delta":{"content":"a"}}]}`+"\n"+
|
||||
`data: {"choices":[{"delta":{"content":"b"}}]}`+"\n"+
|
||||
`data: {"choices":[{"delta":{},"finish_reason":"stop"}]}`+"\n"+
|
||||
`data: [DONE]`+"\n",
|
||||
)
|
||||
if f != nil {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
var chunks []string
|
||||
err := l.ChatStreamlyWithSender("gpt-4",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, nil,
|
||||
func(content *string, _ *string) error {
|
||||
if content != nil && *content != "" && *content != "[DONE]" {
|
||||
chunks = append(chunks, *content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("stream: %v", err)
|
||||
}
|
||||
if strings.Join(chunks, "") != "ab" {
|
||||
t.Errorf("chunks=%v want [a b]", chunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIStreamRequiresSender(t *testing.T) {
|
||||
l := newLocalAIForTest("http://unused")
|
||||
err := l.ChatStreamlyWithSender("gpt-4",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, nil, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "sender is required") {
|
||||
t.Errorf("expected sender-required error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIChatMissingBaseURLFailsClearly(t *testing.T) {
|
||||
// LocalAI has no public default; resolveBaseURL must fail with a
|
||||
// helpful message when neither the requested region nor "default"
|
||||
// is configured.
|
||||
l := NewLocalAIModel(map[string]string{}, URLSuffix{Chat: "chat/completions"})
|
||||
_, err := l.ChatWithMessages("gpt-4",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "missing base URL") {
|
||||
t.Errorf("expected missing-base-URL error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIChatOmitsAuthHeaderWhenKeyEmpty(t *testing.T) {
|
||||
// Optional-auth contract: LocalAI accepts an empty key, so the
|
||||
// driver must NOT send a "Bearer " header in that case.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "" {
|
||||
t.Errorf("expected no Authorization header, got %q", got)
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
resp, err := l.ChatWithMessages("gpt-4",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if *resp.Answer != "ok" {
|
||||
t.Errorf("answer=%q want ok", *resp.Answer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIChatSendsAuthHeaderWhenKeyProvided(t *testing.T) {
|
||||
// And conversely: when a tenant has put LocalAI behind an auth
|
||||
// proxy with a token, the driver does send the Bearer header.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer secret" {
|
||||
t.Errorf("expected Authorization=Bearer secret, got %q", got)
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
key := "secret"
|
||||
_, err := l.ChatWithMessages("gpt-4",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{ApiKey: &key}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIBalanceReturnsNoSuchMethod(t *testing.T) {
|
||||
l := newLocalAIForTest("http://unused")
|
||||
_, err := l.Balance(&APIConfig{})
|
||||
if err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Balance: expected 'no such method', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIEmbedHappyPath(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/embeddings" {
|
||||
t.Errorf("path=%s", r.URL.Path)
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"data":[
|
||||
{"embedding":[0.1,0.2],"index":0},
|
||||
{"embedding":[0.3,0.4],"index":1},
|
||||
{"embedding":[0.5,0.6],"index":2}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
model := "text-embedding-ada-002"
|
||||
vecs, err := l.Embed(&model, []string{"a", "b", "c"}, &APIConfig{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Embed: %v", err)
|
||||
}
|
||||
if len(vecs) != 3 {
|
||||
t.Fatalf("len=%d want 3", len(vecs))
|
||||
}
|
||||
if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 {
|
||||
t.Errorf("vecs[1]=%+v", vecs[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIEmbedRejectsDuplicateIndex(t *testing.T) {
|
||||
// CodeRabbit caught that a response repeating data[*].index would
|
||||
// silently overwrite the earlier vector. Verify the driver fails
|
||||
// loudly instead.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, `{"data":[
|
||||
{"embedding":[1],"index":0},
|
||||
{"embedding":[2],"index":0}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
model := "text-embedding-ada-002"
|
||||
_, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") {
|
||||
t.Errorf("expected duplicate-index error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIEmbedRejectsOutOfRangeIndex(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":7}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
model := "text-embedding-ada-002"
|
||||
_, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "out of range") {
|
||||
t.Errorf("expected out-of-range error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIEmbedRejectsMissingSlot(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":0}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
model := "text-embedding-ada-002"
|
||||
_, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") {
|
||||
t.Errorf("expected missing-slot error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalAIEmbedEmptyInputShortCircuits(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("Embed([]) made an unexpected HTTP call")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
model := "text-embedding-ada-002"
|
||||
vecs, err := l.Embed(&model, []string{}, &APIConfig{}, nil)
|
||||
if err != nil || len(vecs) != 0 {
|
||||
t.Errorf("Embed([])=(%v,%v) want ([],nil)", vecs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- reasoning extraction (multi-field) ----------
|
||||
|
||||
// Table-driven unit coverage of the helper. Mirrors the priority order
|
||||
// reasoning_content > reasoning > thinking declared in
|
||||
// localAIReasoningFields. New upstream field names can be added by
|
||||
// extending that slice without touching call sites.
|
||||
func TestExtractLocalAIReasoning(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in map[string]interface{}
|
||||
want string
|
||||
}{
|
||||
{"empty map", map[string]interface{}{}, ""},
|
||||
{"reasoning_content wins", map[string]interface{}{
|
||||
"reasoning_content": "rc",
|
||||
"reasoning": "r",
|
||||
"thinking": "t",
|
||||
}, "rc"},
|
||||
{"reasoning when no reasoning_content", map[string]interface{}{
|
||||
"reasoning": "r",
|
||||
"thinking": "t",
|
||||
}, "r"},
|
||||
{"thinking when only that is set", map[string]interface{}{
|
||||
"thinking": "qwen3-thought",
|
||||
}, "qwen3-thought"},
|
||||
{"empty string treated as absent", map[string]interface{}{
|
||||
"reasoning_content": "",
|
||||
"reasoning": "fallback",
|
||||
}, "fallback"},
|
||||
{"non-string ignored", map[string]interface{}{
|
||||
"reasoning_content": 42,
|
||||
"reasoning": "fallback",
|
||||
}, "fallback"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := extractLocalAIReasoning(tc.in)
|
||||
if got != tc.want {
|
||||
t.Errorf("got=%q want=%q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Non-streaming chat against an upstream that puts the trace in
|
||||
// message.reasoning_content (kimi-k2.6, OpenAI o-series, DeepSeek-R1
|
||||
// when proxied through OpenAI-shim). The driver must surface it on
|
||||
// ChatResponse.ReasonContent.
|
||||
func TestLocalAIChatExtractsReasoningContent(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{
|
||||
"role":"assistant",
|
||||
"content":"The answer is 12.",
|
||||
"reasoning_content":"15% = 0.15; 0.15 * 80 = 12."
|
||||
}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
resp, err := l.ChatWithMessages("kimi-k2.6",
|
||||
[]Message{{Role: "user", Content: "15% of 80?"}},
|
||||
&APIConfig{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if *resp.Answer != "The answer is 12." {
|
||||
t.Errorf("Answer=%q", *resp.Answer)
|
||||
}
|
||||
if *resp.ReasonContent != "15% = 0.15; 0.15 * 80 = 12." {
|
||||
t.Errorf("ReasonContent=%q", *resp.ReasonContent)
|
||||
}
|
||||
}
|
||||
|
||||
// Non-streaming chat that uses message.thinking (Qwen3 via Ollama-shim
|
||||
// inside LocalAI). The driver must surface it on ReasonContent too.
|
||||
func TestLocalAIChatExtractsThinking(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{
|
||||
"role":"assistant",
|
||||
"content":"12",
|
||||
"thinking":"Compute 15/100 * 80"
|
||||
}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
resp, err := l.ChatWithMessages("qwen3-32b",
|
||||
[]Message{{Role: "user", Content: "15% of 80?"}},
|
||||
&APIConfig{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if *resp.ReasonContent != "Compute 15/100 * 80" {
|
||||
t.Errorf("ReasonContent=%q want %q", *resp.ReasonContent, "Compute 15/100 * 80")
|
||||
}
|
||||
}
|
||||
|
||||
// Regression net: a response with no reasoning field at all (any
|
||||
// non-reasoning model) must produce empty ReasonContent without
|
||||
// crashing or erroring.
|
||||
func TestLocalAIChatHandlesAbsentReasoning(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{
|
||||
"role":"assistant","content":"hello"
|
||||
}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
resp, err := l.ChatWithMessages("llama-3-8b-instruct",
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
&APIConfig{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if *resp.Answer != "hello" {
|
||||
t.Errorf("Answer=%q", *resp.Answer)
|
||||
}
|
||||
if *resp.ReasonContent != "" {
|
||||
t.Errorf("ReasonContent=%q want empty", *resp.ReasonContent)
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming chat where the upstream interleaves delta.reasoning_content
|
||||
// chunks and delta.content chunks (kimi-k2.6, o-series shape).
|
||||
// Reasoning must reach the sender's 2nd arg, content the 1st.
|
||||
func TestLocalAIStreamExtractsReasoningContentDelta(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w,
|
||||
`data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}`+"\n"+
|
||||
`data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 1. "}}]}`+"\n"+
|
||||
`data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 2."}}]}`+"\n"+
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"Answer."},"finish_reason":"stop"}]}`+"\n"+
|
||||
`data: [DONE]`+"\n",
|
||||
)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
var content, reasoning []string
|
||||
err := l.ChatStreamlyWithSender("kimi-k2.6",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, nil,
|
||||
func(c *string, r *string) error {
|
||||
if c != nil && r != nil {
|
||||
t.Errorf("sender called with both args non-nil")
|
||||
}
|
||||
if r != nil && *r != "" {
|
||||
reasoning = append(reasoning, *r)
|
||||
}
|
||||
if c != nil && *c != "" && *c != "[DONE]" {
|
||||
content = append(content, *c)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("stream: %v", err)
|
||||
}
|
||||
if got := strings.Join(reasoning, ""); got != "step 1. step 2." {
|
||||
t.Errorf("reasoning joined=%q", got)
|
||||
}
|
||||
if got := strings.Join(content, ""); got != "Answer." {
|
||||
t.Errorf("content joined=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming chat where the upstream uses delta.thinking (Qwen3 shape).
|
||||
// The same handler must work.
|
||||
func TestLocalAIStreamExtractsThinkingDelta(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w,
|
||||
`data: {"choices":[{"index":0,"delta":{"thinking":"qwen-trace"}}]}`+"\n"+
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"final"},"finish_reason":"stop"}]}`+"\n"+
|
||||
`data: [DONE]`+"\n",
|
||||
)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
var got []string
|
||||
err := l.ChatStreamlyWithSender("qwen3-32b",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, nil,
|
||||
func(c *string, r *string) error {
|
||||
if r != nil && *r != "" {
|
||||
got = append(got, "R:"+*r)
|
||||
}
|
||||
if c != nil && *c != "" && *c != "[DONE]" {
|
||||
got = append(got, "C:"+*c)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("stream: %v", err)
|
||||
}
|
||||
want := []string{"R:qwen-trace", "C:final"}
|
||||
if len(got) != 2 || got[0] != want[0] || got[1] != want[1] {
|
||||
t.Errorf("seq=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Request-side: ChatConfig.Effort must flow into request body as
|
||||
// reasoning_effort.
|
||||
func TestLocalAIChatPropagatesReasoningEffort(t *testing.T) {
|
||||
var seen map[string]interface{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read body: %v", err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(raw, &seen); err != nil {
|
||||
t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw))
|
||||
return
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
effort := "high"
|
||||
_, err := l.ChatWithMessages("kimi-k2.6",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, &ChatConfig{Effort: &effort})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if seen["reasoning_effort"] != "high" {
|
||||
t.Errorf("reasoning_effort=%v want high", seen["reasoning_effort"])
|
||||
}
|
||||
if _, present := seen["enable_thinking"]; present {
|
||||
t.Errorf("enable_thinking should be absent when Thinking nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Request-side: ChatConfig.Thinking must flow into request body as
|
||||
// enable_thinking (Qwen3-style).
|
||||
func TestLocalAIChatPropagatesEnableThinking(t *testing.T) {
|
||||
var seen map[string]interface{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read body: %v", err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(raw, &seen); err != nil {
|
||||
t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw))
|
||||
return
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
think := true
|
||||
_, err := l.ChatWithMessages("qwen3-32b",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, &ChatConfig{Thinking: &think})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if seen["enable_thinking"] != true {
|
||||
t.Errorf("enable_thinking=%v want true", seen["enable_thinking"])
|
||||
}
|
||||
}
|
||||
|
||||
// Stream request also propagates the reasoning params.
|
||||
func TestLocalAIStreamPropagatesReasoningParams(t *testing.T) {
|
||||
var seen map[string]interface{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read body: %v", err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(raw, &seen); err != nil {
|
||||
t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w,
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"x"},"finish_reason":"stop"}]}`+"\n"+
|
||||
`data: [DONE]`+"\n",
|
||||
)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := newLocalAIForTest(srv.URL)
|
||||
effort := "medium"
|
||||
think := true
|
||||
err := l.ChatStreamlyWithSender("kimi-k2.6",
|
||||
[]Message{{Role: "user", Content: "x"}},
|
||||
&APIConfig{}, &ChatConfig{Effort: &effort, Thinking: &think},
|
||||
func(*string, *string) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("stream: %v", err)
|
||||
}
|
||||
if seen["reasoning_effort"] != "medium" {
|
||||
t.Errorf("reasoning_effort=%v want medium", seen["reasoning_effort"])
|
||||
}
|
||||
if seen["enable_thinking"] != true {
|
||||
t.Errorf("enable_thinking=%v want true", seen["enable_thinking"])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user