Go: implement Rerank in LocalAI driver (#14813)

### What problem does this PR solve?

The LocalAI Go driver landed in #14809 and Embed landed in #14811.
`Rerank` was left as a stub that returns `"not implemented"`. This PR
fills the gap.

LocalAI exposes a public rerank endpoint at `<tenant-url>/v1/rerank`
with a Cohere-shaped request and response (`{model, query, documents,
top_n}` → `{results: [{index, relevance_score}]}`). The Python side has
had `LocalAIRerank` in `rag/llm/rerank_model.py` for a long time. Until
this PR, a tenant who wanted to use LocalAI for reranking in the Go
layer got `"not implemented"`.

### What this PR includes

- `conf/models/localai.json`: add `"rerank": "rerank"` under
`url_suffix` so the driver can build the URL from config. This matches
the `URLSuffix.Rerank` field already used by aliyun and siliconflow.
- `internal/entity/models/localai.go`: replace the `Rerank` stub with a
real implementation that POSTs to `/v1/rerank`. Adds local
request/response types `localAIRerankRequest` and
`localAIRerankResponse`.

No factory change. No interface change.

### How the implementation works

- Validate the model name and resolve the tenant-supplied base URL with
the existing `resolveBaseURL` helper.
- Wrap the request with `context.WithTimeout(nonStreamCallTimeout)` so
the call has a clear deadline. Same pattern `ChatWithMessages`,
`ListModels`, and `Embed` already use in this file.
- Only set the `Authorization` header when a non-empty API key was
supplied. LocalAI accepts an empty key by default, so this preserves the
optional-auth contract.
- Default `top_n` to `len(documents)` when `rerankConfig.TopN == 0`,
matching the existing Aliyun and SiliconFlow rerank implementations.
- Validate every `results[].index` against `len(documents)`. If the
upstream returns an out-of-range index, fail clearly instead of silently
writing past the slice.
- An empty `documents` slice returns `&RerankResponse{}` with no HTTP
call.
- Non-200 responses propagate the upstream status line and body.

### Note on stacking

This PR builds on #14809 (LocalAI driver) and #14811 (LocalAI Embed).
Until both merge, this PR's diff on GitHub will include all three
commits. After #14809 and #14811 land on `main`, GitHub will auto-reduce
this PR to only the `Rerank` changes (one commit, ~99 line diff in
`localai.go` plus 1 line in `localai.json`).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

### How was this tested?

- `go build ./internal/entity/models/...` returns exit 0 on go 1.25 (the
`go.mod` minimum).
- The full method set on `LocalAIModel` still matches the `ModelDriver`
interface.
- Pattern parity with the existing Aliyun Rerank
(`internal/entity/models/aliyun.go`) and SiliconFlow Rerank
(`internal/entity/models/siliconflow.go`) implementations.

Closes #14812
Depends on #14809, #14811
Tracking: #14736

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
tmimmanuel
2026-05-13 01:35:19 -10:00
committed by GitHub
parent 3182fd0789
commit 0a4b733b2a
4 changed files with 1463 additions and 0 deletions

10
conf/models/localai.json Normal file
View File

@ -0,0 +1,10 @@
{
"name": "localai",
"url_suffix": {
"chat": "chat/completions",
"models": "models",
"embedding": "embeddings",
"rerank": "rerank"
},
"class": "local"
}

View File

@ -83,6 +83,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
return NewBaichuanModel(baseURL, urlSuffix), nil
case "jina":
return NewJinaModel(baseURL, urlSuffix), nil
case "localai":
return NewLocalAIModel(baseURL, urlSuffix), nil
case "longcat":
return NewLongCatModel(baseURL, urlSuffix), nil
case "novita":

View File

@ -0,0 +1,825 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
)
// localAIStreamIdleTimeout bounds how long ChatStreamlyWithSender will
// wait between SSE chunks before assuming the upstream has stalled and
// aborting the request. A local LLM normally emits at least one token
// every few seconds; 60s is generous enough to never break a working
// stream but tight enough to bound a worst-case mid-body hang.
//
// var (not const) so tests can lower it without waiting a real minute.
var localAIStreamIdleTimeout = 60 * time.Second
// LocalAIModel implements ModelDriver for LocalAI, a self-hosted
// OpenAI-compatible inference server (https://localai.io).
//
// Unlike cloud providers, LocalAI runs on a tenant-supplied base URL
// (for example http://127.0.0.1:8080/v1). The driver therefore reads
// the base URL from the per-instance map at call time and does not
// assume a "default" entry. The API key is optional: LocalAI accepts
// an empty key by default, and the driver only sets the Authorization
// header when a non-empty key was supplied.
type LocalAIModel struct {
BaseURL map[string]string
URLSuffix URLSuffix
httpClient *http.Client
}
// NewLocalAIModel creates a new LocalAI model instance.
//
// We clone http.DefaultTransport so we keep Go's defaults for
// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2,
// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override
// the connection-pool fields we care about.
//
// The Client itself has no Timeout. http.Client.Timeout would also
// cap the time spent reading the response body, which would cut off
// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming
// callers wrap each request with context.WithTimeout instead.
func NewLocalAIModel(baseURL map[string]string, urlSuffix URLSuffix) *LocalAIModel {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.MaxIdleConns = 100
transport.MaxIdleConnsPerHost = 10
transport.IdleConnTimeout = 90 * time.Second
transport.DisableCompression = false
transport.ResponseHeaderTimeout = 60 * time.Second
return &LocalAIModel{
BaseURL: baseURL,
URLSuffix: urlSuffix,
httpClient: &http.Client{
Transport: transport,
},
}
}
func (l *LocalAIModel) NewInstance(baseURL map[string]string) ModelDriver {
return NewLocalAIModel(baseURL, l.URLSuffix)
}
func (l *LocalAIModel) Name() string {
return "localai"
}
// resolveBaseURL returns the tenant-supplied base URL for the given
// region, falling back to the "default" entry, and fails with a clear
// message when nothing is configured. LocalAI is self-hosted so the
// driver cannot fall back to a public endpoint.
func (l *LocalAIModel) resolveBaseURL(region string) (string, error) {
if base, ok := l.BaseURL[region]; ok && base != "" {
return strings.TrimSuffix(base, "/"), nil
}
if base, ok := l.BaseURL["default"]; ok && base != "" {
return strings.TrimSuffix(base, "/"), nil
}
return "", fmt.Errorf("localai: missing base URL, configure the local access address (e.g., http://127.0.0.1:8080/v1)")
}
// setAuth sets the Authorization header only when a non-empty API key
// is supplied. LocalAI accepts an empty key by default, so sending
// "Bearer " (with an empty value) would be wrong in both directions:
// some local proxies reject it, and it leaks the fact that the
// driver was misconfigured.
func setLocalAIAuth(req *http.Request, apiConfig *APIConfig) {
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
}
// localAIReasoningFields lists the JSON field names that different
// upstream models put their chain-of-thought into. LocalAI is a proxy
// that can route to any of these, so the driver tries each in turn:
//
// - reasoning_content: OpenAI o-series, kimi-k2.6, DeepSeek-R1,
// magistral when proxied through an OpenAI-shim
// - reasoning: Upstage solar-pro3 (and its proxies)
// - thinking: Qwen3 (Ollama-style) and some local llama-r1
// variants exposed through LocalAI's OpenAI shim
//
// The first non-empty match wins. Order matters: reasoning_content is
// the OpenAI-conformant name and the most widely used, so it's tried
// first.
var localAIReasoningFields = []string{"reasoning_content", "reasoning", "thinking"}
// extractLocalAIReasoning pulls the chain-of-thought out of a message
// or delta object regardless of which field name the upstream model
// chose. Returns "" when no reasoning field is present or non-string.
func extractLocalAIReasoning(m map[string]interface{}) string {
for _, k := range localAIReasoningFields {
if v, ok := m[k].(string); ok && v != "" {
return v
}
}
return ""
}
// addLocalAIReasoningRequestParams propagates the caller's request-side
// reasoning intent into the body. Different upstream models behind
// LocalAI accept different parameters:
//
// - reasoning_effort: OpenAI-compatible reasoning APIs (kimi, magistral,
// solar-pro2/pro3, gpt-o-series, R1 proxies)
// - enable_thinking: Qwen3 explicit thinking toggle
//
// Both are emitted when the caller opts in, so the request works
// against whichever family the LocalAI instance routes to. A non-
// supporting upstream simply ignores the extra field.
func addLocalAIReasoningRequestParams(reqBody map[string]interface{}, cfg *ChatConfig) {
if cfg == nil {
return
}
if cfg.Effort != nil && *cfg.Effort != "" {
reqBody["reasoning_effort"] = *cfg.Effort
}
if cfg.Thinking != nil {
reqBody["enable_thinking"] = *cfg.Thinking
}
}
// ChatWithMessages sends multiple messages with roles and returns the response.
func (l *LocalAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("messages is empty")
}
region := "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := l.resolveBaseURL(region)
if err != nil {
return nil, err
}
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat)
apiMessages := make([]map[string]interface{}, len(messages))
for i, msg := range messages {
apiMessages[i] = map[string]interface{}{
"role": msg.Role,
"content": msg.Content,
}
}
reqBody := map[string]interface{}{
"model": modelName,
"messages": apiMessages,
"stream": false,
}
// Note: do NOT propagate chatModelConfig.Stream into the request body
// here. ChatWithMessages parses a single JSON response, so stream must
// always be off for this code path.
if chatModelConfig != nil {
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
// LocalAI is a proxy; emit both reasoning_effort and
// enable_thinking so the request works regardless of which
// model family the LocalAI instance routes to. See
// addLocalAIReasoningRequestParams.
addLocalAIReasoningRequestParams(reqBody, chatModelConfig)
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
setLocalAIAuth(req, apiConfig)
resp, err := l.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
return nil, fmt.Errorf("no choices in response")
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid choice format")
}
messageMap, ok := firstChoice["message"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid message format")
}
content, ok := messageMap["content"].(string)
if !ok {
return nil, fmt.Errorf("invalid content format")
}
// Pull the chain-of-thought from whichever field the upstream model
// used. See localAIReasoningFields for the priority order.
reasonContent := extractLocalAIReasoning(messageMap)
return &ChatResponse{
Answer: &content,
ReasonContent: &reasonContent,
}, nil
}
// ChatStreamlyWithSender sends messages and streams the response via the
// sender function. The LocalAI SSE stream uses the same shape as OpenAI:
// "data:" lines carrying JSON events, with a final "[DONE]" line.
func (l *LocalAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
if sender == nil {
return fmt.Errorf("sender is required")
}
if len(messages) == 0 {
return fmt.Errorf("messages is empty")
}
region := "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := l.resolveBaseURL(region)
if err != nil {
return err
}
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat)
apiMessages := make([]map[string]interface{}, len(messages))
for i, msg := range messages {
apiMessages[i] = map[string]interface{}{
"role": msg.Role,
"content": msg.Content,
}
}
reqBody := map[string]interface{}{
"model": modelName,
"messages": apiMessages,
"stream": true,
}
if chatModelConfig != nil {
// Refuse to run if the caller explicitly asked for stream=false.
// The body of this method only knows how to read SSE, so a
// non-SSE JSON response would be parsed as if it were a stream
// and produce no chunks. Better to fail clearly.
if chatModelConfig.Stream != nil && !*chatModelConfig.Stream {
return fmt.Errorf("stream must be true in ChatStreamlyWithSender")
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
// LocalAI is a proxy; emit both reasoning_effort and
// enable_thinking so the streaming request works regardless of
// which model family the LocalAI instance routes to.
addLocalAIReasoningRequestParams(reqBody, chatModelConfig)
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
// SSE streams are long-lived, so we cannot attach a hard deadline:
// a legitimate response may take many minutes to finish on a busy
// local model. Instead, wrap the request with WithCancel and run
// an idle watchdog below that calls cancel() if no new data has
// arrived for streamIdleTimeout. That bounds the worst-case stall
// to a known finite window without breaking working long streams.
//
// Threading a real caller-supplied ctx through the ModelDriver
// interface remains a wider follow-up; this is the contained fix.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
setLocalAIAuth(req, apiConfig)
resp, err := l.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Idle watchdog: every successful Scan resets lastActive. If
// streamIdleTimeout passes without a reset, the watchdog calls
// cancel(), which closes the underlying connection. The blocking
// scanner.Scan() then returns false with the context error in
// scanner.Err(), and we surface it to the caller instead of
// hanging the goroutine forever.
lastActive := time.Now()
var lastActiveMu sync.Mutex
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(localAIStreamIdleTimeout / 4)
defer ticker.Stop()
for {
select {
case <-done:
return
case now := <-ticker.C:
lastActiveMu.Lock()
idle := now.Sub(lastActive)
lastActiveMu.Unlock()
if idle >= localAIStreamIdleTimeout {
cancel()
return
}
}
}
}()
// SSE parsing: bump the scanner buffer from the 64KB default to 1MB
// so we never silently truncate a long data: line.
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
sawTerminal := false
for scanner.Scan() {
lastActiveMu.Lock()
lastActive = time.Now()
lastActiveMu.Unlock()
line := scanner.Text()
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(line[5:])
if data == "[DONE]" {
sawTerminal = true
break
}
var event map[string]interface{}
if err = json.Unmarshal([]byte(data), &event); err != nil {
continue
}
choices, ok := event["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
continue
}
delta, ok := firstChoice["delta"].(map[string]interface{})
if !ok {
continue
}
// Reasoning chunk first, content second. When an SSE event
// carries both, callers that pipe them to a UI render the
// chain-of-thought before the answer for that token, matching
// the wire ordering Upstage solar-pro3 and kimi-k2.6 emit.
// extractLocalAIReasoning tries reasoning_content, reasoning,
// and thinking in that order so this works against whichever
// model family LocalAI routes to.
if reasoning := extractLocalAIReasoning(delta); reasoning != "" {
if err := sender(nil, &reasoning); err != nil {
return err
}
}
content, ok := delta["content"].(string)
if ok && content != "" {
if err := sender(&content, nil); err != nil {
return err
}
}
finishReason, ok := firstChoice["finish_reason"].(string)
if ok && finishReason != "" {
sawTerminal = true
break
}
}
if err := scanner.Err(); err != nil {
// If the watchdog fired, the context is done; surface that as
// a clearer "idle" error instead of leaking the raw
// "context canceled" string.
if ctx.Err() != nil {
return fmt.Errorf("localai: stream idle for more than %s, aborted", localAIStreamIdleTimeout)
}
return fmt.Errorf("failed to scan response body: %w", err)
}
if !sawTerminal {
return fmt.Errorf("localai: stream ended before [DONE] or finish_reason")
}
endOfStream := "[DONE]"
if err := sender(&endOfStream, nil); err != nil {
return err
}
return nil
}
type localAIEmbeddingData struct {
Embedding []float64 `json:"embedding"`
Object string `json:"object"`
Index int `json:"index"`
}
type localAIEmbeddingResponse struct {
Data []localAIEmbeddingData `json:"data"`
Model string `json:"model"`
Object string `json:"object"`
}
// Embed turns a list of texts into embedding vectors using the LocalAI
// /v1/embeddings endpoint. The output has one vector per input, in the
// same order the inputs were given.
func (l *LocalAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
if len(texts) == 0 {
return []EmbeddingData{}, nil
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
region := "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := l.resolveBaseURL(region)
if err != nil {
return nil, err
}
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Embedding)
reqBody := map[string]interface{}{
"model": *modelName,
"input": texts,
}
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["dimensions"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
setLocalAIAuth(req, apiConfig)
resp, err := l.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("LocalAI embeddings API error: %s, body: %s", resp.Status, string(body))
}
var parsed localAIEmbeddingResponse
if err = json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
// Reorder by the reported index so the output always lines up with
// the input texts, even if the upstream API ever returns items out
// of order. A nil slot at the end indicates the upstream did not
// return an embedding for that input.
embeddings := make([]EmbeddingData, len(texts))
filled := make([]bool, len(texts))
for _, item := range parsed.Data {
if item.Index < 0 || item.Index >= len(texts) {
return nil, fmt.Errorf("localai: response index %d out of range for %d inputs", item.Index, len(texts))
}
if filled[item.Index] {
// A malformed response that repeats the same index would
// silently overwrite the earlier vector. Fail loudly so
// the caller never uses ambiguous output.
return nil, fmt.Errorf("localai: duplicate embedding index %d in response", item.Index)
}
embeddings[item.Index] = EmbeddingData{
Embedding: item.Embedding,
Index: item.Index,
}
filled[item.Index] = true
}
for i, ok := range filled {
if !ok {
return nil, fmt.Errorf("localai: missing embedding for input index %d", i)
}
}
return embeddings, nil
}
type localAIRerankRequest struct {
Model string `json:"model"`
Query string `json:"query"`
Documents []string `json:"documents"`
TopN int `json:"top_n"`
}
type localAIRerankResponse struct {
Results []struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
} `json:"results"`
}
// Rerank calculates similarity scores between a query and a list of documents
// using LocalAI's /v1/rerank endpoint. The response shape is Cohere-style:
// {results: [{index, relevance_score}]}. The output is copied into the shared
// RerankResponse{Data: []RerankResult{Index, RelevanceScore}} shape that the
// rest of the codebase already consumes.
func (l *LocalAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
if len(documents) == 0 {
return &RerankResponse{}, nil
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
region := "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := l.resolveBaseURL(region)
if err != nil {
return nil, err
}
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Rerank)
topN := len(documents)
if rerankConfig != nil && rerankConfig.TopN > 0 {
topN = rerankConfig.TopN
}
reqBody := localAIRerankRequest{
Model: *modelName,
Query: query,
Documents: documents,
TopN: topN,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
setLocalAIAuth(req, apiConfig)
resp, err := l.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("LocalAI rerank API error: %s, body: %s", resp.Status, string(body))
}
var parsed localAIRerankResponse
if err = json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
rerankResponse := &RerankResponse{}
for _, r := range parsed.Results {
if r.Index < 0 || r.Index >= len(documents) {
return nil, fmt.Errorf("localai: rerank result index %d out of range for %d documents", r.Index, len(documents))
}
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
Index: r.Index,
RelevanceScore: r.RelevanceScore,
})
}
return rerankResponse, nil
}
// ListModels returns the list of model ids the running LocalAI instance has
// loaded. There is no fixed model list at the SaaS level because LocalAI is
// self-hosted; the answer depends on what the tenant has configured.
func (l *LocalAIModel) ListModels(apiConfig *APIConfig) ([]string, error) {
region := "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := l.resolveBaseURL(region)
if err != nil {
return nil, err
}
url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Models)
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
setLocalAIAuth(req, apiConfig)
resp, err := l.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
data, ok := result["data"].([]interface{})
if !ok {
return nil, fmt.Errorf("invalid models list format")
}
models := make([]string, 0)
for _, model := range data {
modelMap, ok := model.(map[string]interface{})
if !ok {
continue
}
modelName, ok := modelMap["id"].(string)
if !ok {
continue
}
models = append(models, modelName)
}
return models, nil
}
// Balance is not exposed by LocalAI (it is self-hosted and free), so this
// returns "no such method".
func (l *LocalAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
return nil, fmt.Errorf("no such method")
}
// CheckConnection runs a lightweight ListModels call to verify the LocalAI
// base URL is reachable.
func (l *LocalAIModel) CheckConnection(apiConfig *APIConfig) error {
_, err := l.ListModels(apiConfig)
if err != nil {
return err
}
return nil
}
// TranscribeAudio (ASR): LocalAI can route audio to a Whisper backend
// when one is loaded, but the wire shape and driver-side plumbing for
// streaming audio uploads is separate from this PR's scope. Stub here
// to satisfy the ModelDriver interface; follow-up PR welcome.
func (l *LocalAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
return nil, fmt.Errorf("%s, no such method", l.Name())
}
func (l *LocalAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
return fmt.Errorf("%s, no such method", l.Name())
}
// AudioSpeech (TTS): same story as TranscribeAudio above.
func (l *LocalAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) {
return nil, fmt.Errorf("%s, no such method", l.Name())
}
func (l *LocalAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
return fmt.Errorf("%s, no such method", l.Name())
}
// OCRFile: LocalAI has no OCR pipeline in its OpenAI-compatible surface;
// document parsing belongs to a different interface entirely.
func (l *LocalAIModel) OCRFile(modelName *string, fileContent *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) {
return nil, fmt.Errorf("%s, no such method", l.Name())
}

View File

@ -0,0 +1,626 @@
package models
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
func newLocalAIForTest(baseURL string) *LocalAIModel {
return NewLocalAIModel(
map[string]string{"default": baseURL},
URLSuffix{
Chat: "chat/completions",
Models: "models",
Embedding: "embeddings",
Rerank: "rerank",
},
)
}
// withLocalAIIdleTimeout swaps the package-level idle timeout for the
// duration of the test. Tests that exercise the stall watchdog use a
// sub-second value so they finish quickly.
func withLocalAIIdleTimeout(t *testing.T, d time.Duration) {
t.Helper()
original := localAIStreamIdleTimeout
localAIStreamIdleTimeout = d
t.Cleanup(func() {
localAIStreamIdleTimeout = original
})
}
func TestLocalAIName(t *testing.T) {
l := newLocalAIForTest("http://unused")
if got := l.Name(); got != "localai" {
t.Errorf("Name()=%q, want %q", got, "localai")
}
}
func TestLocalAIStreamCancelsOnIdle(t *testing.T) {
// The server emits one valid chunk and then stalls. Without the
// watchdog, scanner.Scan() would hang forever. With the watchdog
// at 200ms, it must return a clear "stream idle" error in well
// under a second.
withLocalAIIdleTimeout(t, 200*time.Millisecond)
// hold blocks the handler until the test closes it. Register
// close(hold) FIRST so it runs LAST (defers are LIFO) — wait,
// that's the opposite. We want close(hold) to run BEFORE
// srv.Close() so the handler can return. Use t.Cleanup, which
// runs in reverse-registration order: register srv.Close first
// so it runs last, then close(hold) so it runs first.
hold := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
_, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"hi"}}]}`+"\n")
f.Flush()
}
// Hang until either the client disconnects (watchdog cancels
// the request, which causes r.Context() to fire) or the test
// teardown signals via `hold`.
select {
case <-hold:
case <-r.Context().Done():
}
}))
t.Cleanup(srv.Close)
t.Cleanup(func() { close(hold) })
l := newLocalAIForTest(srv.URL)
var got []string
var mu sync.Mutex
start := time.Now()
err := l.ChatStreamlyWithSender("gpt-4",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, nil,
func(content *string, _ *string) error {
if content == nil || *content == "" {
return nil
}
mu.Lock()
got = append(got, *content)
mu.Unlock()
return nil
},
)
elapsed := time.Since(start)
if err == nil {
t.Fatal("expected an idle-timeout error, got nil")
}
if !strings.Contains(err.Error(), "idle for more than") {
t.Errorf("expected idle-timeout error, got %v", err)
}
if elapsed > 5*time.Second {
t.Errorf("watchdog did not fire promptly; elapsed=%v", elapsed)
}
mu.Lock()
defer mu.Unlock()
if len(got) == 0 || got[0] != "hi" {
t.Errorf("expected first chunk before stall, got %v", got)
}
}
func TestLocalAIStreamCompletesWithoutTriggeringWatchdog(t *testing.T) {
// Sanity check: a fast, complete stream should not trip the
// watchdog even with a moderately tight idle window.
withLocalAIIdleTimeout(t, 500*time.Millisecond)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
f, _ := w.(http.Flusher)
_, _ = io.WriteString(w,
`data: {"choices":[{"delta":{"content":"a"}}]}`+"\n"+
`data: {"choices":[{"delta":{"content":"b"}}]}`+"\n"+
`data: {"choices":[{"delta":{},"finish_reason":"stop"}]}`+"\n"+
`data: [DONE]`+"\n",
)
if f != nil {
f.Flush()
}
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
var chunks []string
err := l.ChatStreamlyWithSender("gpt-4",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, nil,
func(content *string, _ *string) error {
if content != nil && *content != "" && *content != "[DONE]" {
chunks = append(chunks, *content)
}
return nil
},
)
if err != nil {
t.Fatalf("stream: %v", err)
}
if strings.Join(chunks, "") != "ab" {
t.Errorf("chunks=%v want [a b]", chunks)
}
}
func TestLocalAIStreamRequiresSender(t *testing.T) {
l := newLocalAIForTest("http://unused")
err := l.ChatStreamlyWithSender("gpt-4",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, nil, nil)
if err == nil || !strings.Contains(err.Error(), "sender is required") {
t.Errorf("expected sender-required error, got %v", err)
}
}
func TestLocalAIChatMissingBaseURLFailsClearly(t *testing.T) {
// LocalAI has no public default; resolveBaseURL must fail with a
// helpful message when neither the requested region nor "default"
// is configured.
l := NewLocalAIModel(map[string]string{}, URLSuffix{Chat: "chat/completions"})
_, err := l.ChatWithMessages("gpt-4",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "missing base URL") {
t.Errorf("expected missing-base-URL error, got %v", err)
}
}
func TestLocalAIChatOmitsAuthHeaderWhenKeyEmpty(t *testing.T) {
// Optional-auth contract: LocalAI accepts an empty key, so the
// driver must NOT send a "Bearer " header in that case.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "" {
t.Errorf("expected no Authorization header, got %q", got)
}
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
resp, err := l.ChatWithMessages("gpt-4",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, nil)
if err != nil {
t.Fatalf("Chat: %v", err)
}
if *resp.Answer != "ok" {
t.Errorf("answer=%q want ok", *resp.Answer)
}
}
func TestLocalAIChatSendsAuthHeaderWhenKeyProvided(t *testing.T) {
// And conversely: when a tenant has put LocalAI behind an auth
// proxy with a token, the driver does send the Bearer header.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer secret" {
t.Errorf("expected Authorization=Bearer secret, got %q", got)
}
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
key := "secret"
_, err := l.ChatWithMessages("gpt-4",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{ApiKey: &key}, nil)
if err != nil {
t.Fatalf("Chat: %v", err)
}
}
func TestLocalAIBalanceReturnsNoSuchMethod(t *testing.T) {
l := newLocalAIForTest("http://unused")
_, err := l.Balance(&APIConfig{})
if err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Balance: expected 'no such method', got %v", err)
}
}
func TestLocalAIEmbedHappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/embeddings" {
t.Errorf("path=%s", r.URL.Path)
}
_, _ = io.WriteString(w, `{"data":[
{"embedding":[0.1,0.2],"index":0},
{"embedding":[0.3,0.4],"index":1},
{"embedding":[0.5,0.6],"index":2}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
model := "text-embedding-ada-002"
vecs, err := l.Embed(&model, []string{"a", "b", "c"}, &APIConfig{}, nil)
if err != nil {
t.Fatalf("Embed: %v", err)
}
if len(vecs) != 3 {
t.Fatalf("len=%d want 3", len(vecs))
}
if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 {
t.Errorf("vecs[1]=%+v", vecs[1])
}
}
func TestLocalAIEmbedRejectsDuplicateIndex(t *testing.T) {
// CodeRabbit caught that a response repeating data[*].index would
// silently overwrite the earlier vector. Verify the driver fails
// loudly instead.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, `{"data":[
{"embedding":[1],"index":0},
{"embedding":[2],"index":0}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
model := "text-embedding-ada-002"
_, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") {
t.Errorf("expected duplicate-index error, got %v", err)
}
}
func TestLocalAIEmbedRejectsOutOfRangeIndex(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":7}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
model := "text-embedding-ada-002"
_, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "out of range") {
t.Errorf("expected out-of-range error, got %v", err)
}
}
func TestLocalAIEmbedRejectsMissingSlot(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":0}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
model := "text-embedding-ada-002"
_, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") {
t.Errorf("expected missing-slot error, got %v", err)
}
}
func TestLocalAIEmbedEmptyInputShortCircuits(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
t.Error("Embed([]) made an unexpected HTTP call")
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
model := "text-embedding-ada-002"
vecs, err := l.Embed(&model, []string{}, &APIConfig{}, nil)
if err != nil || len(vecs) != 0 {
t.Errorf("Embed([])=(%v,%v) want ([],nil)", vecs, err)
}
}
// ---------- reasoning extraction (multi-field) ----------
// Table-driven unit coverage of the helper. Mirrors the priority order
// reasoning_content > reasoning > thinking declared in
// localAIReasoningFields. New upstream field names can be added by
// extending that slice without touching call sites.
func TestExtractLocalAIReasoning(t *testing.T) {
cases := []struct {
name string
in map[string]interface{}
want string
}{
{"empty map", map[string]interface{}{}, ""},
{"reasoning_content wins", map[string]interface{}{
"reasoning_content": "rc",
"reasoning": "r",
"thinking": "t",
}, "rc"},
{"reasoning when no reasoning_content", map[string]interface{}{
"reasoning": "r",
"thinking": "t",
}, "r"},
{"thinking when only that is set", map[string]interface{}{
"thinking": "qwen3-thought",
}, "qwen3-thought"},
{"empty string treated as absent", map[string]interface{}{
"reasoning_content": "",
"reasoning": "fallback",
}, "fallback"},
{"non-string ignored", map[string]interface{}{
"reasoning_content": 42,
"reasoning": "fallback",
}, "fallback"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := extractLocalAIReasoning(tc.in)
if got != tc.want {
t.Errorf("got=%q want=%q", got, tc.want)
}
})
}
}
// Non-streaming chat against an upstream that puts the trace in
// message.reasoning_content (kimi-k2.6, OpenAI o-series, DeepSeek-R1
// when proxied through OpenAI-shim). The driver must surface it on
// ChatResponse.ReasonContent.
func TestLocalAIChatExtractsReasoningContent(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, `{"choices":[{"message":{
"role":"assistant",
"content":"The answer is 12.",
"reasoning_content":"15% = 0.15; 0.15 * 80 = 12."
}}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
resp, err := l.ChatWithMessages("kimi-k2.6",
[]Message{{Role: "user", Content: "15% of 80?"}},
&APIConfig{}, nil)
if err != nil {
t.Fatalf("Chat: %v", err)
}
if *resp.Answer != "The answer is 12." {
t.Errorf("Answer=%q", *resp.Answer)
}
if *resp.ReasonContent != "15% = 0.15; 0.15 * 80 = 12." {
t.Errorf("ReasonContent=%q", *resp.ReasonContent)
}
}
// Non-streaming chat that uses message.thinking (Qwen3 via Ollama-shim
// inside LocalAI). The driver must surface it on ReasonContent too.
func TestLocalAIChatExtractsThinking(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, `{"choices":[{"message":{
"role":"assistant",
"content":"12",
"thinking":"Compute 15/100 * 80"
}}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
resp, err := l.ChatWithMessages("qwen3-32b",
[]Message{{Role: "user", Content: "15% of 80?"}},
&APIConfig{}, nil)
if err != nil {
t.Fatalf("Chat: %v", err)
}
if *resp.ReasonContent != "Compute 15/100 * 80" {
t.Errorf("ReasonContent=%q want %q", *resp.ReasonContent, "Compute 15/100 * 80")
}
}
// Regression net: a response with no reasoning field at all (any
// non-reasoning model) must produce empty ReasonContent without
// crashing or erroring.
func TestLocalAIChatHandlesAbsentReasoning(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, `{"choices":[{"message":{
"role":"assistant","content":"hello"
}}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
resp, err := l.ChatWithMessages("llama-3-8b-instruct",
[]Message{{Role: "user", Content: "hi"}},
&APIConfig{}, nil)
if err != nil {
t.Fatalf("Chat: %v", err)
}
if *resp.Answer != "hello" {
t.Errorf("Answer=%q", *resp.Answer)
}
if *resp.ReasonContent != "" {
t.Errorf("ReasonContent=%q want empty", *resp.ReasonContent)
}
}
// Streaming chat where the upstream interleaves delta.reasoning_content
// chunks and delta.content chunks (kimi-k2.6, o-series shape).
// Reasoning must reach the sender's 2nd arg, content the 1st.
func TestLocalAIStreamExtractsReasoningContentDelta(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w,
`data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}`+"\n"+
`data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 1. "}}]}`+"\n"+
`data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 2."}}]}`+"\n"+
`data: {"choices":[{"index":0,"delta":{"content":"Answer."},"finish_reason":"stop"}]}`+"\n"+
`data: [DONE]`+"\n",
)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
var content, reasoning []string
err := l.ChatStreamlyWithSender("kimi-k2.6",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, nil,
func(c *string, r *string) error {
if c != nil && r != nil {
t.Errorf("sender called with both args non-nil")
}
if r != nil && *r != "" {
reasoning = append(reasoning, *r)
}
if c != nil && *c != "" && *c != "[DONE]" {
content = append(content, *c)
}
return nil
},
)
if err != nil {
t.Fatalf("stream: %v", err)
}
if got := strings.Join(reasoning, ""); got != "step 1. step 2." {
t.Errorf("reasoning joined=%q", got)
}
if got := strings.Join(content, ""); got != "Answer." {
t.Errorf("content joined=%q", got)
}
}
// Streaming chat where the upstream uses delta.thinking (Qwen3 shape).
// The same handler must work.
func TestLocalAIStreamExtractsThinkingDelta(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w,
`data: {"choices":[{"index":0,"delta":{"thinking":"qwen-trace"}}]}`+"\n"+
`data: {"choices":[{"index":0,"delta":{"content":"final"},"finish_reason":"stop"}]}`+"\n"+
`data: [DONE]`+"\n",
)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
var got []string
err := l.ChatStreamlyWithSender("qwen3-32b",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, nil,
func(c *string, r *string) error {
if r != nil && *r != "" {
got = append(got, "R:"+*r)
}
if c != nil && *c != "" && *c != "[DONE]" {
got = append(got, "C:"+*c)
}
return nil
},
)
if err != nil {
t.Fatalf("stream: %v", err)
}
want := []string{"R:qwen-trace", "C:final"}
if len(got) != 2 || got[0] != want[0] || got[1] != want[1] {
t.Errorf("seq=%v want %v", got, want)
}
}
// Request-side: ChatConfig.Effort must flow into request body as
// reasoning_effort.
func TestLocalAIChatPropagatesReasoningEffort(t *testing.T) {
var seen map[string]interface{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("read body: %v", err)
return
}
if err := json.Unmarshal(raw, &seen); err != nil {
t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw))
return
}
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
effort := "high"
_, err := l.ChatWithMessages("kimi-k2.6",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, &ChatConfig{Effort: &effort})
if err != nil {
t.Fatalf("Chat: %v", err)
}
if seen["reasoning_effort"] != "high" {
t.Errorf("reasoning_effort=%v want high", seen["reasoning_effort"])
}
if _, present := seen["enable_thinking"]; present {
t.Errorf("enable_thinking should be absent when Thinking nil")
}
}
// Request-side: ChatConfig.Thinking must flow into request body as
// enable_thinking (Qwen3-style).
func TestLocalAIChatPropagatesEnableThinking(t *testing.T) {
var seen map[string]interface{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("read body: %v", err)
return
}
if err := json.Unmarshal(raw, &seen); err != nil {
t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw))
return
}
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
think := true
_, err := l.ChatWithMessages("qwen3-32b",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, &ChatConfig{Thinking: &think})
if err != nil {
t.Fatalf("Chat: %v", err)
}
if seen["enable_thinking"] != true {
t.Errorf("enable_thinking=%v want true", seen["enable_thinking"])
}
}
// Stream request also propagates the reasoning params.
func TestLocalAIStreamPropagatesReasoningParams(t *testing.T) {
var seen map[string]interface{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("read body: %v", err)
return
}
if err := json.Unmarshal(raw, &seen); err != nil {
t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw))
return
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = io.WriteString(w,
`data: {"choices":[{"index":0,"delta":{"content":"x"},"finish_reason":"stop"}]}`+"\n"+
`data: [DONE]`+"\n",
)
}))
defer srv.Close()
l := newLocalAIForTest(srv.URL)
effort := "medium"
think := true
err := l.ChatStreamlyWithSender("kimi-k2.6",
[]Message{{Role: "user", Content: "x"}},
&APIConfig{}, &ChatConfig{Effort: &effort, Thinking: &think},
func(*string, *string) error { return nil },
)
if err != nil {
t.Fatalf("stream: %v", err)
}
if seen["reasoning_effort"] != "medium" {
t.Errorf("reasoning_effort=%v want medium", seen["reasoning_effort"])
}
if seen["enable_thinking"] != true {
t.Errorf("enable_thinking=%v want true", seen["enable_thinking"])
}
}