Files
ragflow/internal/dao/model_provider.go
Jin Hai 70e9743ef1 RAGFlow go API server (#13240)
# RAGFlow Go Implementation Plan 🚀

This repository tracks the progress of porting RAGFlow to Go. We'll
implement core features and provide performance comparisons between
Python and Go versions.

## Implementation Checklist

- [x] User Management APIs
- [x] Dataset Management Operations
- [x] Retrieval Test
- [x] Chat Management Operations
- [x] Infinity Go SDK

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Yingfeng Zhang <yingfeng.zhang@gmail.com>
2026-03-04 19:17:16 +08:00

124 lines
3.6 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package dao
import (
	"strings"
	"sync"

	"ragflow/internal/server"
)
// ModelProviderDAO is a stateless accessor for the model provider
// configuration exposed by the server package. It has no fields; every
// method delegates to, or filters the results of, server lookups.
type ModelProviderDAO struct{}

var (
	// modelProviderDAOInstance is the shared value returned by NewModelProviderDAO.
	modelProviderDAOInstance *ModelProviderDAO
	// modelProviderDAOOnce ensures modelProviderDAOInstance is created only once.
	modelProviderDAOOnce sync.Once
)

// NewModelProviderDAO returns the process-wide ModelProviderDAO, allocating
// it on the first call. Every call yields the same pointer; creation is
// guarded by sync.Once, so concurrent first calls are safe.
func NewModelProviderDAO() *ModelProviderDAO {
	modelProviderDAOOnce.Do(func() { modelProviderDAOInstance = new(ModelProviderDAO) })
	return modelProviderDAOInstance
}
// GetAllProviders lists every model provider known to the server package.
// The slice from server.GetModelProviders is returned as-is; nothing is
// filtered or copied here.
func (dao *ModelProviderDAO) GetAllProviders() []server.ModelProvider {
	providers := server.GetModelProviders()
	return providers
}
// GetProviderByName looks up a single model provider by name through
// server.GetModelProviderByName and returns its result unchanged.
// NOTE(review): other code in this file treats a nil result as "no such
// provider"; confirm against server.GetModelProviderByName.
func (dao *ModelProviderDAO) GetProviderByName(name string) *server.ModelProvider {
	provider := server.GetModelProviderByName(name)
	return provider
}
// GetLLMByProviderAndName fetches the LLM identified by providerName and
// modelName via server.GetLLMByProviderAndName, returning its result unchanged.
// NOTE(review): presumably nil when either name is unknown — confirm in server.
func (dao *ModelProviderDAO) GetLLMByProviderAndName(providerName, modelName string) *server.LLM {
	llm := server.GetLLMByProviderAndName(providerName, modelName)
	return llm
}
// GetLLMsByType collects, from every provider returned by
// server.GetModelProviders, each LLM whose ModelType equals modelType
// exactly (case-sensitive). It returns nil when nothing matches.
func (dao *ModelProviderDAO) GetLLMsByType(modelType string) []server.LLM {
	var matched []server.LLM
	providers := server.GetModelProviders()
	for i := range providers {
		for _, candidate := range providers[i].LLMs {
			if candidate.ModelType != modelType {
				continue
			}
			matched = append(matched, candidate)
		}
	}
	return matched
}
// GetProvidersByTag returns the providers whose comma-separated Tags field
// contains tag as a whole entry (as decided by containsTag). Matches keep
// the order of server.GetModelProviders; nil is returned when none match.
func (dao *ModelProviderDAO) GetProvidersByTag(tag string) []server.ModelProvider {
	var matches []server.ModelProvider
	providers := server.GetModelProviders()
	for i := range providers {
		if !containsTag(providers[i].Tags, tag) {
			continue
		}
		matches = append(matches, providers[i])
	}
	return matches
}
// GetLLMsByProviderAndType returns the LLMs of the provider named
// providerName whose ModelType equals modelType exactly. It returns nil
// both when the provider is not found and when no LLM matches.
func (dao *ModelProviderDAO) GetLLMsByProviderAndType(providerName, modelType string) []server.LLM {
	provider := server.GetModelProviderByName(providerName)
	if provider == nil {
		// Unknown provider: nothing to filter.
		return nil
	}

	var matched []server.LLM
	for _, llm := range provider.LLMs {
		if llm.ModelType != modelType {
			continue
		}
		matched = append(matched, llm)
	}
	return matched
}
// containsTag reports whether tag equals one of the entries produced by
// splitTags(tags), i.e. one of the comma-separated items in tags. Each entry
// is compared as a whole with an exact, case-sensitive match, so "LLM" does
// not match an entry such as "LLMX". It returns false when tags has no entries.
func containsTag(tags, tag string) bool {
	entries := splitTags(tags)
	for i := range entries {
		if entries[i] == tag {
			return true
		}
	}
	return false
}
// splitTags splits a comma-separated tag string into its individual tags.
// Leading and trailing whitespace is trimmed from each entry, so
// "LLM, TEXT EMBEDDING" yields ["LLM", "TEXT EMBEDDING"]; whitespace inside a
// tag is preserved. Entries that are empty after trimming (from ",,", a
// leading/trailing comma, or blank input) are dropped. The result is nil when
// no tags are present.
func splitTags(tags string) []string {
	var result []string
	for _, part := range strings.Split(tags, ",") {
		if t := strings.TrimSpace(part); t != "" {
			result = append(result, t)
		}
	}
	return result
}