mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-03-16 04:17:49 +08:00
# RAGFlow Go Implementation Plan 🚀 This repository tracks the progress of porting RAGFlow to Go. We'll implement core features and provide performance comparisons between Python and Go versions. ## Implementation Checklist - [x] User Management APIs - [x] Dataset Management Operations - [x] Retrieval Test - [x] Chat Management Operations - [x] Infinity Go SDK --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com> Co-authored-by: Yingfeng Zhang <yingfeng.zhang@gmail.com>
174 lines
5.1 KiB
Go
174 lines
5.1 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"ragflow/internal/model"
|
|
)
|
|
|
|
// ModelBundle provides a unified interface for various model operations
|
|
// Similar to Python's LLMBundle but with a more generic name
|
|
type ModelBundle struct {
|
|
tenantID string
|
|
modelType model.ModelType
|
|
modelName string
|
|
model interface{} // underlying model instance
|
|
}
|
|
|
|
// NewModelBundle creates a new ModelBundle for the given tenant and model type
|
|
// If modelName is empty, uses the default model for the tenant and type
|
|
func NewModelBundle(tenantID string, modelType model.ModelType, modelName ...string) (*ModelBundle, error) {
|
|
bundle := &ModelBundle{
|
|
tenantID: tenantID,
|
|
modelType: modelType,
|
|
}
|
|
|
|
// Use provided model name if available
|
|
if len(modelName) > 0 && modelName[0] != "" {
|
|
bundle.modelName = modelName[0]
|
|
}
|
|
|
|
// Get model instance based on type
|
|
provider := NewModelProvider()
|
|
switch modelType {
|
|
case model.ModelTypeEmbedding:
|
|
embeddingModel, err := provider.GetEmbeddingModel(context.Background(), tenantID, bundle.modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
|
}
|
|
bundle.model = embeddingModel
|
|
case model.ModelTypeChat:
|
|
chatModel, err := provider.GetChatModel(context.Background(), tenantID, bundle.modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get chat model: %w", err)
|
|
}
|
|
bundle.model = chatModel
|
|
case model.ModelTypeRerank:
|
|
rerankModel, err := provider.GetRerankModel(context.Background(), tenantID, bundle.modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get rerank model: %w", err)
|
|
}
|
|
bundle.model = rerankModel
|
|
default:
|
|
return nil, fmt.Errorf("unsupported model type: %s", modelType)
|
|
}
|
|
|
|
return bundle, nil
|
|
}
|
|
|
|
// Encode encodes a list of texts into embeddings
|
|
// Returns embeddings and token count (for compatibility with Python interface)
|
|
func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
|
|
if b.modelType != model.ModelTypeEmbedding {
|
|
return nil, 0, fmt.Errorf("model type %s does not support encode", b.modelType)
|
|
}
|
|
|
|
embeddingModel, ok := b.model.(model.EmbeddingModel)
|
|
if !ok {
|
|
return nil, 0, fmt.Errorf("model is not an embedding model")
|
|
}
|
|
|
|
embeddings, err := embeddingModel.Encode(texts)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// TODO: Calculate actual token count
|
|
// For now, return a dummy token count
|
|
tokenCount := int64(0)
|
|
for _, text := range texts {
|
|
tokenCount += int64(len(text) / 4) // rough approximation
|
|
}
|
|
|
|
return embeddings, tokenCount, nil
|
|
}
|
|
|
|
// EncodeQuery encodes a single query string into embedding
|
|
// Returns embedding and token count
|
|
func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
|
|
if b.modelType != model.ModelTypeEmbedding {
|
|
return nil, 0, fmt.Errorf("model type %s does not support encode query", b.modelType)
|
|
}
|
|
|
|
embeddingModel, ok := b.model.(model.EmbeddingModel)
|
|
if !ok {
|
|
return nil, 0, fmt.Errorf("model is not an embedding model")
|
|
}
|
|
|
|
embedding, err := embeddingModel.EncodeQuery(query)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// TODO: Calculate actual token count
|
|
tokenCount := int64(len(query) / 4)
|
|
|
|
return embedding, tokenCount, nil
|
|
}
|
|
|
|
// Chat sends a chat message and returns response
|
|
func (b *ModelBundle) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, int64, error) {
|
|
if b.modelType != model.ModelTypeChat {
|
|
return "", 0, fmt.Errorf("model type %s does not support chat", b.modelType)
|
|
}
|
|
|
|
chatModel, ok := b.model.(model.ChatModel)
|
|
if !ok {
|
|
return "", 0, fmt.Errorf("model is not a chat model")
|
|
}
|
|
|
|
response, err := chatModel.Chat(system, history, genConf)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
// TODO: Calculate actual token count
|
|
tokenCount := int64(len(response) / 4)
|
|
|
|
return response, tokenCount, nil
|
|
}
|
|
|
|
// Similarity calculates similarity between query and texts
|
|
func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64, error) {
|
|
if b.modelType != model.ModelTypeRerank {
|
|
return nil, 0, fmt.Errorf("model type %s does not support similarity", b.modelType)
|
|
}
|
|
|
|
rerankModel, ok := b.model.(model.RerankModel)
|
|
if !ok {
|
|
return nil, 0, fmt.Errorf("model is not a rerank model")
|
|
}
|
|
|
|
similarities, err := rerankModel.Similarity(query, texts)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// TODO: Calculate actual token count
|
|
tokenCount := int64(len(query)/4) + int64(len(texts)*10)
|
|
|
|
return similarities, tokenCount, nil
|
|
}
|
|
|
|
// GetModel returns the underlying model instance
|
|
func (b *ModelBundle) GetModel() interface{} {
|
|
return b.model
|
|
}
|