Files
ragflow/internal/service/model_bundle.go
Jin Hai 70e9743ef1 RAGFlow go API server (#13240)
# RAGFlow Go Implementation Plan 🚀

This repository tracks the progress of porting RAGFlow to Go. We will
implement the core features and provide performance comparisons between
the Python and Go versions.

## Implementation Checklist

- [x] User Management APIs
- [x] Dataset Management Operations
- [x] Retrieval Test
- [x] Chat Management Operations
- [x] Infinity Go SDK

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Yingfeng Zhang <yingfeng.zhang@gmail.com>
2026-03-04 19:17:16 +08:00

174 lines
5.1 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"fmt"
"ragflow/internal/model"
)
// ModelBundle wraps a single tenant-scoped model instance (embedding, chat,
// or rerank) behind one type, mirroring Python's LLMBundle under a more
// generic name. Construct it with NewModelBundle; the methods Encode,
// EncodeQuery, Chat, and Similarity each check modelType before using model.
type ModelBundle struct {
	// tenantID identifies the tenant the model was resolved for.
	tenantID string
	// modelType selects which operations the bundle supports.
	modelType model.ModelType
	// modelName is the requested model name; empty when none was given.
	modelName string
	// model holds the underlying instance returned by the ModelProvider.
	model interface{}
}
// NewModelBundle builds a ModelBundle for tenantID that wraps a model of the
// requested modelType (embedding, chat, or rerank).
//
// Only the first element of the optional modelName is consulted. When it is
// absent or empty, an empty name is handed to the ModelProvider, which
// presumably resolves the tenant's default model for that type (confirm in
// ModelProvider). A provider lookup failure is returned wrapped with %w.
// Any other model type yields an "unsupported model type" error.
//
// NOTE(review): lookups run under context.Background(), so callers cannot
// cancel or time-bound them.
func NewModelBundle(tenantID string, modelType model.ModelType, modelName ...string) (*ModelBundle, error) {
	b := &ModelBundle{tenantID: tenantID, modelType: modelType}
	if len(modelName) > 0 {
		// Assigning "" leaves the zero value, so no separate emptiness check is needed.
		b.modelName = modelName[0]
	}

	p := NewModelProvider()
	ctx := context.Background()
	var err error
	switch modelType {
	case model.ModelTypeEmbedding:
		if b.model, err = p.GetEmbeddingModel(ctx, tenantID, b.modelName); err != nil {
			return nil, fmt.Errorf("failed to get embedding model: %w", err)
		}
	case model.ModelTypeChat:
		if b.model, err = p.GetChatModel(ctx, tenantID, b.modelName); err != nil {
			return nil, fmt.Errorf("failed to get chat model: %w", err)
		}
	case model.ModelTypeRerank:
		if b.model, err = p.GetRerankModel(ctx, tenantID, b.modelName); err != nil {
			return nil, fmt.Errorf("failed to get rerank model: %w", err)
		}
	default:
		return nil, fmt.Errorf("unsupported model type: %s", modelType)
	}
	return b, nil
}
// Encode embeds a batch of texts with the bundle's embedding model.
//
// It returns the embeddings produced by the underlying model.EmbeddingModel
// (presumably one per text), an estimated token count, and an error. An error
// is returned if the bundle was not created for model.ModelTypeEmbedding or if
// the stored model does not implement model.EmbeddingModel. Errors from the
// underlying Encode call are returned unchanged.
//
// The token count exists for compatibility with the Python interface. It is
// only a heuristic (about 4 bytes per token), not a real tokenizer count.
func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
	if b.modelType != model.ModelTypeEmbedding {
		return nil, 0, fmt.Errorf("model type %s does not support encode", b.modelType)
	}
	embeddingModel, ok := b.model.(model.EmbeddingModel)
	if !ok {
		return nil, 0, fmt.Errorf("model is not an embedding model")
	}
	embeddings, err := embeddingModel.Encode(texts)
	if err != nil {
		return nil, 0, err
	}
	// TODO: Calculate actual token count with a real tokenizer.
	// Sum the byte lengths before dividing. Flooring each text separately
	// under-counts systematically and reports 0 tokens for a batch of short
	// (<4-byte) texts. For a single text this equals EncodeQuery's estimate.
	var totalBytes int64
	for _, text := range texts {
		totalBytes += int64(len(text))
	}
	return embeddings, totalBytes / 4, nil
}
// EncodeQuery embeds a single query string with the bundle's embedding model.
//
// It fails if the bundle's type is not model.ModelTypeEmbedding, or if the
// wrapped instance does not implement model.EmbeddingModel. Errors from the
// model's own EncodeQuery are passed through unchanged. The returned token
// count is a placeholder estimate of len(query)/4 (bytes), not a tokenizer
// result.
func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
	if b.modelType != model.ModelTypeEmbedding {
		return nil, 0, fmt.Errorf("model type %s does not support encode query", b.modelType)
	}

	switch em := b.model.(type) {
	case model.EmbeddingModel:
		vector, encErr := em.EncodeQuery(query)
		if encErr != nil {
			return nil, 0, encErr
		}
		// TODO: Calculate actual token count
		return vector, int64(len(query) / 4), nil
	default:
		return nil, 0, fmt.Errorf("model is not an embedding model")
	}
}
// Chat forwards system, history, and genConf unchanged to the wrapped
// model.ChatModel and returns its reply.
//
// It fails if the bundle's type is not model.ModelTypeChat, or if the wrapped
// instance does not implement model.ChatModel. Errors from the model are
// returned as-is. The token count is a placeholder estimate of
// len(response)/4 (bytes) and covers only the reply, not the prompt or
// history.
func (b *ModelBundle) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, int64, error) {
	if b.modelType != model.ModelTypeChat {
		return "", 0, fmt.Errorf("model type %s does not support chat", b.modelType)
	}

	switch cm := b.model.(type) {
	case model.ChatModel:
		reply, chatErr := cm.Chat(system, history, genConf)
		if chatErr != nil {
			return "", 0, chatErr
		}
		// TODO: Calculate actual token count
		estimated := int64(len(reply) / 4)
		return reply, estimated, nil
	default:
		return "", 0, fmt.Errorf("model is not a chat model")
	}
}
// Similarity scores each of texts against query using the wrapped
// model.RerankModel and returns the scores produced by that model.
//
// It fails if the bundle's type is not model.ModelTypeRerank, or if the
// wrapped instance does not implement model.RerankModel. Errors from the
// model are returned unchanged. The token count is a placeholder:
// len(query)/4 plus a flat 10 per candidate text.
func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64, error) {
	if b.modelType != model.ModelTypeRerank {
		return nil, 0, fmt.Errorf("model type %s does not support similarity", b.modelType)
	}

	switch rm := b.model.(type) {
	case model.RerankModel:
		scores, rankErr := rm.Similarity(query, texts)
		if rankErr != nil {
			return nil, 0, rankErr
		}
		// TODO: Calculate actual token count
		queryTokens := int64(len(query) / 4)
		textTokens := int64(len(texts) * 10)
		return scores, queryTokens + textTokens, nil
	default:
		return nil, 0, fmt.Errorf("model is not a rerank model")
	}
}
// GetModel exposes the wrapped model instance exactly as stored in the
// bundle. Callers must type-assert it (for example to model.EmbeddingModel)
// to use it directly.
func (b *ModelBundle) GetModel() interface{} {
	underlying := b.model
	return underlying
}