Perf(Go): batch SiliconFlow Encode requests with 32-item chunking (#14719)

### What problem does this PR solve?

The SiliconFlow `Encode` method sent one HTTP request per text, which is
wasteful and slow when indexing many documents (e.g., 100 docs = 100
round-trips).

SiliconFlow's `/v1/embeddings` is OpenAI-compatible and accepts an array
of strings in `input` (officially documented at
https://docs.siliconflow.cn/en/api-reference/embeddings/create-embeddings,
with a documented max array size of 32). This PR batches the requests up
to that limit, reducing 100 docs to ~4 round-trips, and replaces
`map[string]interface{}` parsing with a typed struct using the same
3-layer validation (count mismatch, out-of-range index, duplicate index)
used in the other drivers.

### Type of change

- [x] Performance Improvement
This commit is contained in:
Joseff
2026-05-11 00:55:27 -04:00
committed by GitHub
parent 4b96362092
commit 0580c137fa

View File

@ -19,6 +19,7 @@ package models
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@ -368,11 +369,24 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName string, messages []M
return scanner.Err()
}
// Encode encodes a list of texts into embeddings
type siliconflowEmbeddingResponse struct {
Data []struct {
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
} `json:"data"`
}
// siliconflowMaxBatchSize is the per-request input limit documented at
// https://docs.siliconflow.cn/en/api-reference/embeddings/create-embeddings.
const siliconflowMaxBatchSize = 32
func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
if len(texts) == 0 {
return [][]float64{}, nil
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
var region = "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
@ -386,84 +400,103 @@ func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *
apiKey = *apiConfig.ApiKey
}
dimension := 0
if embeddingConfig != nil {
dimension = embeddingConfig.Dimension
}
embeddings := make([][]float64, len(texts))
for i, text := range texts {
reqBody := map[string]interface{}{
"model": modelName,
"input": text,
for start := 0; start < len(texts); start += siliconflowMaxBatchSize {
end := start + siliconflowMaxBatchSize
if end > len(texts) {
end = len(texts)
}
batch := texts[start:end]
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
if err := s.encodeBatch(url, *modelName, apiKey, dimension, batch, embeddings[start:end]); err != nil {
return nil, err
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err := s.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body))
}
// Parse response
var result map[string]interface{}
if err = json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
data, ok := result["data"].([]interface{})
if !ok || len(data) == 0 {
return nil, fmt.Errorf("no data in response")
}
firstData, ok := data[0].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid data format")
}
embeddingSlice, ok := firstData["embedding"].([]interface{})
if !ok {
return nil, fmt.Errorf("invalid embedding format")
}
embedding := make([]float64, len(embeddingSlice))
for j, v := range embeddingSlice {
switch val := v.(type) {
case float64:
embedding[j] = val
case float32:
embedding[j] = float64(val)
default:
return nil, fmt.Errorf("unexpected embedding value type")
}
}
embeddings[i] = embedding
}
return embeddings, nil
}
func (s *SiliconflowModel) encodeBatch(url, modelName, apiKey string, dimension int, batch []string, out [][]float64) error {
reqBody := map[string]interface{}{
"model": modelName,
"input": batch,
"encoding_format": "float",
}
if dimension > 0 {
reqBody["dimensions"] = dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err := s.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body))
}
var result siliconflowEmbeddingResponse
if err = json.Unmarshal(body, &result); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if len(result.Data) != len(batch) {
return fmt.Errorf("expected %d embeddings, got %d", len(batch), len(result.Data))
}
seen := make([]bool, len(batch))
for _, item := range result.Data {
if item.Index < 0 || item.Index >= len(batch) {
return fmt.Errorf("embedding index %d out of range", item.Index)
}
if seen[item.Index] {
return fmt.Errorf("duplicate embedding index %d", item.Index)
}
if len(item.Embedding) == 0 {
return fmt.Errorf("empty embedding at index %d", item.Index)
}
seen[item.Index] = true
out[item.Index] = item.Embedding
}
for i, ok := range seen {
if !ok {
return fmt.Errorf("missing embedding index %d", i)
}
}
return nil
}
func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {