mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-01 05:17:51 +08:00
Perf(Go): batch SiliconFlow Encode requests with 32-item chunking (#14719)
### What problem does this PR solve?

The SiliconFlow `Encode` method sent one HTTP request per text, which is wasteful and slow when indexing many documents (e.g., 100 docs = 100 round-trips). SiliconFlow's `/v1/embeddings` endpoint is OpenAI-compatible and accepts an array of strings in `input` (officially documented at https://docs.siliconflow.cn/en/api-reference/embeddings/create-embeddings, with a documented maximum array size of 32). This PR batches the texts into requests of at most 32 items each, reducing 100 docs to 4 round-trips, and replaces `map[string]interface{}` parsing with a typed struct using the same 3-layer validation (count mismatch, out-of-range index, duplicate index) used in the other drivers.

### Type of change

- [x] Performance Improvement
This commit is contained in:
@ -19,6 +19,7 @@ package models
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -368,11 +369,24 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName string, messages []M
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// siliconflowEmbeddingItem is one element of the "data" array in an
// embeddings response. encodeBatch uses Index as the position of the
// corresponding text within the batch it sent.
type siliconflowEmbeddingItem struct {
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}

// siliconflowEmbeddingResponse is the part of the embeddings response body
// that gets decoded; any other top-level fields are ignored.
type siliconflowEmbeddingResponse struct {
	Data []siliconflowEmbeddingItem `json:"data"`
}
|
||||
|
||||
// siliconflowMaxBatchSize is the maximum number of texts sent in the "input"
// array of a single embeddings request. Encode splits its texts into chunks of
// at most this size. The value is the per-request input limit documented at
// https://docs.siliconflow.cn/en/api-reference/embeddings/create-embeddings.
const siliconflowMaxBatchSize = 32
|
||||
|
||||
func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
@ -386,84 +400,103 @@ func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *
|
||||
apiKey = *apiConfig.ApiKey
|
||||
}
|
||||
|
||||
dimension := 0
|
||||
if embeddingConfig != nil {
|
||||
dimension = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
|
||||
for i, text := range texts {
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"input": text,
|
||||
for start := 0; start < len(texts); start += siliconflowMaxBatchSize {
|
||||
end := start + siliconflowMaxBatchSize
|
||||
if end > len(texts) {
|
||||
end = len(texts)
|
||||
}
|
||||
batch := texts[start:end]
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
if err := s.encodeBatch(url, *modelName, apiKey, dimension, batch, embeddings[start:end]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
data, ok := result["data"].([]interface{})
|
||||
if !ok || len(data) == 0 {
|
||||
return nil, fmt.Errorf("no data in response")
|
||||
}
|
||||
|
||||
firstData, ok := data[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid data format")
|
||||
}
|
||||
|
||||
embeddingSlice, ok := firstData["embedding"].([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid embedding format")
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddingSlice))
|
||||
for j, v := range embeddingSlice {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
embedding[j] = val
|
||||
case float32:
|
||||
embedding[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type")
|
||||
}
|
||||
}
|
||||
|
||||
embeddings[i] = embedding
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (s *SiliconflowModel) encodeBatch(url, modelName, apiKey string, dimension int, batch []string, out [][]float64) error {
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"input": batch,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
if dimension > 0 {
|
||||
reqBody["dimensions"] = dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result siliconflowEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Data) != len(batch) {
|
||||
return fmt.Errorf("expected %d embeddings, got %d", len(batch), len(result.Data))
|
||||
}
|
||||
|
||||
seen := make([]bool, len(batch))
|
||||
for _, item := range result.Data {
|
||||
if item.Index < 0 || item.Index >= len(batch) {
|
||||
return fmt.Errorf("embedding index %d out of range", item.Index)
|
||||
}
|
||||
if seen[item.Index] {
|
||||
return fmt.Errorf("duplicate embedding index %d", item.Index)
|
||||
}
|
||||
if len(item.Embedding) == 0 {
|
||||
return fmt.Errorf("empty embedding at index %d", item.Index)
|
||||
}
|
||||
seen[item.Index] = true
|
||||
out[item.Index] = item.Embedding
|
||||
}
|
||||
|
||||
for i, ok := range seen {
|
||||
if !ok {
|
||||
return fmt.Errorf("missing embedding index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
|
||||
Reference in New Issue
Block a user