feat: Added LLM factory initialization functionality and knowledge base related API interfaces (#13472)

### What problem does this PR solve?

feat: Added LLM factory initialization functionality and knowledge base
related API interfaces

refactor(dao): Refactored the TenantLLMDAO query method
feat(handler): Implemented knowledge base related API endpoints
feat(service): Added LLM API key setting functionality
feat(model): Extended the knowledge base model definition
feat(config): Added default user LLM configuration

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
chanx
2026-03-09 15:52:14 +08:00
committed by GitHub
parent d0465ba909
commit 25ace613b0
20 changed files with 2446 additions and 340 deletions

10
.gitignore vendored
View File

@ -220,7 +220,13 @@ uv-aarch64*.tar.gz
uv-aarch64-unknown-linux-gnu.tar.gz
docker/launch_backend_service_windows.sh
# C++ build directories
internal/cpp/build/
internal/cpp/cmake-build-release/
internal/cpp/cmake-build-debug/
# Trae IDE config
.trae/
# Go server build output
bin/
internal/cpp/cmake-build-release/
internal/cpp/cmake-build-debug/

View File

@ -6,6 +6,7 @@ import (
"net/http"
"os"
"os/signal"
"ragflow/internal/init_data"
"ragflow/internal/server"
"ragflow/internal/utility"
"strings"
@ -71,6 +72,13 @@ func main() {
logger.Fatal("Failed to initialize database", zap.Error(err))
}
// Initialize LLM factory data models from configuration file
if err := init_data.InitLLMFactory(); err != nil {
logger.Error("Failed to initialize LLM factory", err)
} else {
logger.Info("LLM factory initialized successfully")
}
// Initialize doc engine
if err := engine.Init(&cfg.DocEngine); err != nil {
logger.Fatal("Failed to initialize doc engine", zap.Error(err))

View File

@ -62,8 +62,13 @@ func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pag
user.nickname,
user.avatar as tenant_avatar
`).
Joins("LEFT JOIN user ON dialog.tenant_id = user.id").
Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1")
Joins("LEFT JOIN user ON dialog.tenant_id = user.id")
if len(tenantIDs) > 0 {
query = query.Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1")
} else {
query = query.Where("dialog.tenant_id = ? AND dialog.status = ?", userID, "1")
}
// Apply keyword filter
if keywords != "" {

View File

@ -19,6 +19,7 @@ package dao
import (
"ragflow/internal/model"
"strings"
"time"
)
// KnowledgebaseDAO knowledge base data access object
@ -29,15 +30,133 @@ func NewKnowledgebaseDAO() *KnowledgebaseDAO {
return &KnowledgebaseDAO{}
}
// ListByTenantIDs list knowledge bases by tenant IDs
func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) {
// Create creates a new knowledge base record
func (dao *KnowledgebaseDAO) Create(kb *model.Knowledgebase) error {
return DB.Create(kb).Error
}
// Update updates a knowledge base record
func (dao *KnowledgebaseDAO) Update(kb *model.Knowledgebase) error {
return DB.Save(kb).Error
}
// UpdateByID updates a knowledge base by ID with the given fields
func (dao *KnowledgebaseDAO) UpdateByID(id string, updates map[string]interface{}) error {
return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Updates(updates).Error
}
// Delete soft deletes a knowledge base by setting status to invalid
func (dao *KnowledgebaseDAO) Delete(id string) error {
return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Update("status", string(model.StatusInvalid)).Error
}
// GetByID retrieves a knowledge base by ID
func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) {
var kb model.Knowledgebase
err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error
if err != nil {
return nil, err
}
return &kb, nil
}
// GetByIDAndTenantID retrieves a knowledge base by ID and tenant ID
func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) {
var kb model.Knowledgebase
err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, string(model.StatusValid)).First(&kb).Error
if err != nil {
return nil, err
}
return &kb, nil
}
// GetByIDs retrieves multiple knowledge bases by IDs
func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) {
var kbs []*model.Knowledgebase
err := DB.Where("id IN ? AND status = ?", ids, string(model.StatusValid)).Find(&kbs).Error
return kbs, err
}
// GetByName retrieves a knowledge base by name and tenant ID
func (dao *KnowledgebaseDAO) GetByName(name, tenantID string) (*model.Knowledgebase, error) {
var kb model.Knowledgebase
err := DB.Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, string(model.StatusValid)).First(&kb).Error
if err != nil {
return nil, err
}
return &kb, nil
}
// GetByCreatedBy retrieves knowledge bases created by a specific user
func (dao *KnowledgebaseDAO) GetByCreatedBy(createdBy string) ([]*model.Knowledgebase, error) {
var kbs []*model.Knowledgebase
err := DB.Where("created_by = ? AND status = ?", createdBy, string(model.StatusValid)).Find(&kbs).Error
return kbs, err
}
// Query retrieves knowledge bases with filters
func (dao *KnowledgebaseDAO) Query(filters map[string]interface{}) ([]*model.Knowledgebase, error) {
var kbs []*model.Knowledgebase
query := DB.Where("status = ?", string(model.StatusValid))
for key, value := range filters {
if value != nil && value != "" {
query = query.Where(key+" = ?", value)
}
}
err := query.Find(&kbs).Error
return kbs, err
}
// QueryOne retrieves a single knowledge base with filters
func (dao *KnowledgebaseDAO) QueryOne(filters map[string]interface{}) (*model.Knowledgebase, error) {
var kb model.Knowledgebase
query := DB.Where("status = ?", string(model.StatusValid))
for key, value := range filters {
if value != nil && value != "" {
query = query.Where(key+" = ?", value)
}
}
err := query.First(&kb).Error
if err != nil {
return nil, err
}
return &kb, nil
}
// Count returns the count of knowledge bases matching the filters
func (dao *KnowledgebaseDAO) Count(filters map[string]interface{}) (int64, error) {
var count int64
query := DB.Model(&model.Knowledgebase{}).Where("status = ?", string(model.StatusValid))
for key, value := range filters {
if value != nil && value != "" {
query = query.Where(key+" = ?", value)
}
}
err := query.Count(&count).Error
return count, err
}
// GetByTenantIDs retrieves knowledge bases by tenant IDs with pagination
// This matches the Python get_by_tenant_ids method
func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, keywords, parserID string) ([]*model.KnowledgebaseListItem, int64, error) {
var kbs []*model.KnowledgebaseListItem
var total int64
query := DB.Model(&model.Knowledgebase{}).
Select(`knowledgebase.id, knowledgebase.avatar, knowledgebase.name,
knowledgebase.language, knowledgebase.description, knowledgebase.tenant_id,
knowledgebase.permission, knowledgebase.doc_num, knowledgebase.token_num,
knowledgebase.chunk_num, knowledgebase.parser_id, knowledgebase.embd_id,
user.nickname, user.avatar as tenant_avatar, knowledgebase.update_time`).
Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id").
Where("(knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?", tenantIDs, "team", userID).
Where("knowledgebase.status = ?", "1")
Where("((knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?) AND knowledgebase.status = ?",
tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid))
if keywords != "" {
query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%")
@ -47,22 +166,287 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string,
query = query.Where("knowledgebase.parser_id = ?", parserID)
}
// Order
if desc {
query = query.Order("knowledgebase." + orderby + " DESC")
} else {
query = query.Order("knowledgebase." + orderby + " ASC")
}
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if pageNumber > 0 && itemsPerPage > 0 {
offset := (pageNumber - 1) * itemsPerPage
if err := query.Offset(offset).Limit(itemsPerPage).Scan(&kbs).Error; err != nil {
return nil, 0, err
}
} else {
if err := query.Scan(&kbs).Error; err != nil {
return nil, 0, err
}
}
return kbs, total, nil
}
// GetAllByTenantIDs retrieves all permitted knowledge bases by tenant IDs
// This matches the Python get_all_kb_by_tenant_ids method
func (dao *KnowledgebaseDAO) GetAllByTenantIDs(tenantIDs []string, userID string) ([]*model.Knowledgebase, error) {
var kbs []*model.Knowledgebase
err := DB.Where(
"(tenant_id IN ? AND permission = ?) OR tenant_id = ?",
tenantIDs, string(model.TenantPermissionTeam), userID,
).Order("create_time ASC").Find(&kbs).Error
return kbs, err
}
// GetDetail retrieves detailed knowledge base information with joined pipeline data
// This matches the Python get_detail method
func (dao *KnowledgebaseDAO) GetDetail(kbID string) (*model.KnowledgebaseDetail, error) {
var detail model.KnowledgebaseDetail
err := DB.Table("knowledgebase").
Select(`knowledgebase.id, knowledgebase.embd_id, knowledgebase.avatar, knowledgebase.name,
knowledgebase.language, knowledgebase.description, knowledgebase.permission,
knowledgebase.doc_num, knowledgebase.token_num, knowledgebase.chunk_num,
knowledgebase.parser_id, knowledgebase.pipeline_id,
user_canvas.title as pipeline_name, user_canvas.avatar as pipeline_avatar,
knowledgebase.parser_config, knowledgebase.pagerank,
knowledgebase.graphrag_task_id, knowledgebase.graphrag_task_finish_at,
knowledgebase.raptor_task_id, knowledgebase.raptor_task_finish_at,
knowledgebase.mindmap_task_id, knowledgebase.mindmap_task_finish_at,
knowledgebase.create_time, knowledgebase.update_time`).
Joins("LEFT JOIN user_canvas ON knowledgebase.pipeline_id = user_canvas.id").
Where("knowledgebase.id = ? AND knowledgebase.status = ?", kbID, string(model.StatusValid)).
Scan(&detail).Error
if err != nil {
return nil, err
}
return &detail, nil
}
// Accessible checks if a knowledge base is accessible by a user
// This matches the Python accessible method
func (dao *KnowledgebaseDAO) Accessible(kbID, userID string) bool {
var count int64
err := DB.Table("knowledgebase").
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
Where("knowledgebase.id = ? AND user_tenant.user_id = ? AND knowledgebase.status = ?",
kbID, userID, string(model.StatusValid)).
Count(&count).Error
if err != nil {
return false
}
return count > 0
}
// Accessible4Deletion checks if a knowledge base can be deleted by a user
// This matches the Python accessible4deletion method
func (dao *KnowledgebaseDAO) Accessible4Deletion(kbID, userID string) bool {
var count int64
err := DB.Model(&model.Knowledgebase{}).
Where("id = ? AND created_by = ? AND status = ?", kbID, userID, string(model.StatusValid)).
Count(&count).Error
if err != nil {
return false
}
return count > 0
}
// DuplicateName generates a unique name by appending parentheses if name already exists
// This matches the Python duplicate_name function behavior
func (dao *KnowledgebaseDAO) DuplicateName(name, tenantID string) string {
var existingNames []string
DB.Model(&model.Knowledgebase{}).
Where("name LIKE ? AND tenant_id = ? AND status = ?", name+"%", tenantID, string(model.StatusValid)).
Pluck("name", &existingNames)
if len(existingNames) == 0 {
return name
}
nameSet := make(map[string]bool)
for _, n := range existingNames {
nameSet[strings.ToLower(n)] = true
}
if !nameSet[strings.ToLower(name)] {
return name
}
for i := 1; ; i++ {
newName := name + " " + strings.Repeat("(", i) + strings.Repeat(")", i)
if !nameSet[strings.ToLower(newName)] {
return newName
}
}
}
// AtomicIncreaseDocNumByID atomically increments the document count
// This matches the Python atomic_increase_doc_num_by_id method
func (dao *KnowledgebaseDAO) AtomicIncreaseDocNumByID(kbID string) error {
now := time.Now().Unix()
nowDate := time.Now()
return DB.Model(&model.Knowledgebase{}).
Where("id = ?", kbID).
Updates(map[string]interface{}{
"doc_num": DB.Raw("doc_num + 1"),
"update_time": now,
"update_date": nowDate,
}).Error
}
// DecreaseDocumentNum decreases document, chunk, and token counts
// This matches the Python decrease_document_num_in_delete method
func (dao *KnowledgebaseDAO) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error {
now := time.Now().Unix()
nowDate := time.Now()
return DB.Model(&model.Knowledgebase{}).
Where("id = ?", kbID).
Updates(map[string]interface{}{
"doc_num": DB.Raw("doc_num - ?", docNum),
"chunk_num": DB.Raw("chunk_num - ?", chunkNum),
"token_num": DB.Raw("token_num - ?", tokenNum),
"update_time": now,
"update_date": nowDate,
}).Error
}
// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant
// This matches the Python get_kb_ids method
func (dao *KnowledgebaseDAO) GetKBIDsByTenantID(tenantID string) ([]string, error) {
var kbIDs []string
err := DB.Model(&model.Knowledgebase{}).
Where("tenant_id = ? AND status = ?", tenantID, string(model.StatusValid)).
Pluck("id", &kbIDs).Error
return kbIDs, err
}
// GetAllIDs retrieves all knowledge base IDs
// This matches the Python get_all_ids method
func (dao *KnowledgebaseDAO) GetAllIDs() ([]string, error) {
var kbIDs []string
err := DB.Model(&model.Knowledgebase{}).
Where("status = ?", string(model.StatusValid)).
Pluck("id", &kbIDs).Error
return kbIDs, err
}
// UpdateParserConfig updates the parser configuration with deep merge
// This matches the Python update_parser_config method
func (dao *KnowledgebaseDAO) UpdateParserConfig(id string, config map[string]interface{}) error {
var kb model.Knowledgebase
if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil {
return err
}
mergedConfig := mergeConfig(kb.ParserConfig, config)
return DB.Model(&model.Knowledgebase{}).
Where("id = ?", id).
Update("parser_config", mergedConfig).Error
}
// DeleteFieldMap removes the field_map from parser_config
// This matches the Python delete_field_map method
func (dao *KnowledgebaseDAO) DeleteFieldMap(id string) error {
var kb model.Knowledgebase
if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil {
return err
}
if kb.ParserConfig != nil {
delete(kb.ParserConfig, "field_map")
return DB.Model(&model.Knowledgebase{}).
Where("id = ?", id).
Update("parser_config", kb.ParserConfig).Error
}
return nil
}
// GetFieldMap retrieves field mappings from multiple knowledge bases
// This matches the Python get_field_map method
func (dao *KnowledgebaseDAO) GetFieldMap(ids []string) (map[string]interface{}, error) {
conf := make(map[string]interface{})
kbs, err := dao.GetByIDs(ids)
if err != nil {
return nil, err
}
for _, kb := range kbs {
if kb.ParserConfig != nil {
if fieldMap, ok := kb.ParserConfig["field_map"]; ok {
if fm, ok := fieldMap.(map[string]interface{}); ok {
for k, v := range fm {
conf[k] = v
}
}
}
}
}
return conf, nil
}
// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID with tenant join
// This matches the Python get_kb_by_id method
func (dao *KnowledgebaseDAO) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) {
var kbs []*model.Knowledgebase
err := DB.Model(&model.Knowledgebase{}).
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
Where("knowledgebase.id = ? AND user_tenant.user_id = ?", kbID, userID).
Limit(1).
Find(&kbs).Error
return kbs, err
}
// GetKBByNameAndUserID retrieves a knowledge base by name and user ID with tenant join
// This matches the Python get_kb_by_name method
func (dao *KnowledgebaseDAO) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) {
var kbs []*model.Knowledgebase
err := DB.Model(&model.Knowledgebase{}).
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
Where("knowledgebase.name = ? AND user_tenant.user_id = ?", kbName, userID).
Limit(1).
Find(&kbs).Error
return kbs, err
}
// GetList retrieves knowledge bases with filtering by ID and name
// This matches the Python get_list method
func (dao *KnowledgebaseDAO) GetList(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, error) {
var kbs []*model.Knowledgebase
var total int64
query := DB.Model(&model.Knowledgebase{}).
Where("((tenant_id IN ? AND permission = ?) OR tenant_id = ?) AND status = ?",
tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid))
if id != "" {
query = query.Where("id = ?", id)
}
if name != "" {
query = query.Where("name = ?", name)
}
if desc {
query = query.Order(orderby + " DESC")
} else {
query = query.Order(orderby + " ASC")
}
// Count
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// Pagination
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
if err := query.Offset(offset).Limit(pageSize).Find(&kbs).Error; err != nil {
if pageNumber > 0 && itemsPerPage > 0 {
offset := (pageNumber - 1) * itemsPerPage
if err := query.Offset(offset).Limit(itemsPerPage).Find(&kbs).Error; err != nil {
return nil, 0, err
}
} else {
@ -74,76 +458,39 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string,
return kbs, total, nil
}
// ListByOwnerIDs list knowledge bases by owner IDs
func (dao *KnowledgebaseDAO) ListByOwnerIDs(ownerIDs []string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) {
var kbs []*model.Knowledgebase
query := DB.Model(&model.Knowledgebase{}).
Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id").
Where("knowledgebase.tenant_id IN ?", ownerIDs).
Where("knowledgebase.status = ?", "1")
if keywords != "" {
query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%")
// mergeConfig performs a deep merge of configuration maps
func mergeConfig(old, new map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for k, v := range old {
result[k] = v
}
if parserID != "" {
query = query.Where("knowledgebase.parser_id = ?", parserID)
}
// Order
if desc {
query = query.Order(orderby + " DESC")
} else {
query = query.Order(orderby + " ASC")
}
if err := query.Find(&kbs).Error; err != nil {
return nil, 0, err
}
total := int64(len(kbs))
// Manual pagination
if page > 0 && pageSize > 0 {
start := (page - 1) * pageSize
end := start + pageSize
if end > int(total) {
end = int(total)
}
if start < end {
kbs = kbs[start:end]
} else {
kbs = []*model.Knowledgebase{}
for k, v := range new {
if existing, ok := result[k]; ok {
if existingMap, ok := existing.(map[string]interface{}); ok {
if newMap, ok := v.(map[string]interface{}); ok {
result[k] = mergeConfig(existingMap, newMap)
continue
}
}
if existingSlice, ok := existing.([]interface{}); ok {
if newSlice, ok := v.([]interface{}); ok {
merged := append(existingSlice, newSlice...)
seen := make(map[interface{}]bool)
unique := make([]interface{}, 0)
for _, item := range merged {
if !seen[item] {
seen[item] = true
unique = append(unique, item)
}
}
result[k] = unique
continue
}
}
}
result[k] = v
}
return kbs, total, nil
}
// GetByID gets knowledge base by ID
func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) {
var kb model.Knowledgebase
err := DB.Where("id = ? AND status = ?", id, "1").First(&kb).Error
if err != nil {
return nil, err
}
return &kb, nil
}
// GetByIDAndTenantID gets knowledge base by ID and tenant ID
func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) {
var kb model.Knowledgebase
err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, "1").First(&kb).Error
if err != nil {
return nil, err
}
return &kb, nil
}
// GetByIDs gets knowledge bases by IDs
func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) {
var kbs []*model.Knowledgebase
err := DB.Where("id IN ? AND status = ?", ids, "1").Find(&kbs).Error
return kbs, err
return result
}

View File

@ -67,3 +67,31 @@ func (dao *LLMDAO) GetByFactoryAndName(factory, name string) (*model.LLM, error)
}
return &llm, nil
}
// LLMFactoryDAO LLM factory data access object
type LLMFactoryDAO struct{}
// NewLLMFactoryDAO create LLM factory DAO
func NewLLMFactoryDAO() *LLMFactoryDAO {
return &LLMFactoryDAO{}
}
// GetAllValid gets all valid LLM factories
func (dao *LLMFactoryDAO) GetAllValid() ([]*model.LLMFactories, error) {
var factories []*model.LLMFactories
err := DB.Where("status = ?", "1").Find(&factories).Error
if err != nil {
return nil, err
}
return factories, nil
}
// GetByName gets LLM factory by name
func (dao *LLMFactoryDAO) GetByName(name string) (*model.LLMFactories, error) {
var factory model.LLMFactories
err := DB.Where("name = ?", name).First(&factory).Error
if err != nil {
return nil, err
}
return &factory, nil
}

View File

@ -75,7 +75,7 @@ func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) {
Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id").
Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1").
Scan(&results).Error
return results, err
}
@ -98,3 +98,8 @@ func (dao *TenantDAO) Create(tenant *model.Tenant) error {
func (dao *TenantDAO) Delete(id string) error {
return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error
}
// Update updates a tenant by ID
func (dao *TenantDAO) Update(id string, updates map[string]interface{}) error {
return DB.Model(&model.Tenant{}).Where("id = ?", id).Updates(updates).Error
}

View File

@ -94,21 +94,14 @@ func (dao *TenantLLMDAO) Delete(tenantID, factory, modelName string) error {
}
// GetMyLLMs get tenant LLMs with factory details
func (dao *TenantLLMDAO) GetMyLLMs(tenantID string, includeDetails bool) ([]model.MyLLM, error) {
func (dao *TenantLLMDAO) GetMyLLMs(tenantID string) ([]model.MyLLM, error) {
var myLLMs []model.MyLLM
// Base query
query := DB.Table("tenant_llm tl").
Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status").
err := DB.Table("tenant_llm tl").
Select("tl.id, tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status").
Joins("JOIN llm_factories lf ON tl.llm_factory = lf.name").
Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID)
// Add detailed fields if requested
if includeDetails {
query = query.Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status, tl.api_base, tl.max_tokens")
}
err := query.Find(&myLLMs).Error
Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID).
Find(&myLLMs).Error
if err != nil {
return nil, err
}

View File

@ -18,20 +18,21 @@ package handler
import (
"net/http"
"ragflow/internal/common"
"ragflow/internal/service"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"ragflow/internal/service"
)
// KnowledgebaseHandler knowledge base handler
// KnowledgebaseHandler handles knowledge base HTTP requests
type KnowledgebaseHandler struct {
kbService *service.KnowledgebaseService
userService *service.UserService
}
// NewKnowledgebaseHandler create knowledge base handler
// NewKnowledgebaseHandler creates a new knowledge base handler
func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userService *service.UserService) *KnowledgebaseHandler {
return &KnowledgebaseHandler{
kbService: kbService,
@ -39,35 +40,227 @@ func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userServic
}
}
// ListKbs list knowledge bases
// @Summary List Knowledge Bases
// @Description Get list of knowledge bases with filtering and pagination
// getUserID extracts user ID from authorization header
// It validates the authorization token and returns the user ID
// Parameters:
// - c: gin.Context - the HTTP request context
//
// Returns:
// - string: the user ID
// - common.ErrorCode: the error code
// - error: any error that occurred
func (h *KnowledgebaseHandler) getUserID(c *gin.Context) (string, common.ErrorCode, error) {
token := c.GetHeader("Authorization")
if token == "" {
return "", common.CodeUnauthorized, ErrMissingAuth
}
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
return "", code, err
}
return user.ID, common.CodeSuccess, nil
}
// jsonResponse sends a JSON response with code and message
func jsonResponse(c *gin.Context, code common.ErrorCode, data interface{}, message string) {
c.JSON(http.StatusOK, gin.H{
"code": code,
"data": data,
"message": message,
})
}
// jsonError sends a JSON error response
func jsonError(c *gin.Context, code common.ErrorCode, message string) {
c.JSON(http.StatusOK, gin.H{
"code": code,
"data": nil,
"message": message,
})
}
// HTTPError represents an HTTP error
type HTTPError struct {
Code common.ErrorCode
Message string
}
// Error implements the error interface
func (e *HTTPError) Error() string {
return e.Message
}
var (
// ErrMissingAuth indicates missing authorization header
ErrMissingAuth = &HTTPError{Code: common.CodeUnauthorized, Message: "Missing Authorization header"}
// ErrInvalidToken indicates invalid access token
ErrInvalidToken = &HTTPError{Code: common.CodeUnauthorized, Message: "Invalid access token"}
)
// CreateKB handles the create knowledge base request
// @Summary Create Knowledge Base
// @Description Create a new knowledge base (dataset)
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Param keywords query string false "search keywords"
// @Param page query int false "page number"
// @Param page_size query int false "items per page"
// @Param parser_id query string false "parser ID filter"
// @Param orderby query string false "order by field"
// @Param desc query bool false "descending order"
// @Param request body service.ListKbsRequest true "filter options"
// @Success 200 {object} service.ListKbsResponse
// @Security ApiKeyAuth
// @Param request body service.CreateKBRequest true "knowledge base info"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/create [post]
func (h *KnowledgebaseHandler) CreateKB(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
var req service.CreateKBRequest
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
result, code, err := h.kbService.CreateKB(&req, userID)
if err != nil {
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
// UpdateKB handles the update knowledge base request
// @Summary Update Knowledge Base
// @Description Update an existing knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param request body service.UpdateKBRequest true "knowledge base update info"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/update [post]
func (h *KnowledgebaseHandler) UpdateKB(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
var req service.UpdateKBRequest
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
result, code, err := h.kbService.UpdateKB(&req, userID)
if err != nil {
if strings.Contains(err.Error(), "authorization") {
jsonError(c, common.CodeAuthenticationError, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
// UpdateMetadataSetting handles the update metadata setting request
// @Summary Update Metadata Setting
// @Description Update metadata settings for a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param request body service.UpdateMetadataSettingRequest true "metadata setting info"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/update_metadata_setting [post]
func (h *KnowledgebaseHandler) UpdateMetadataSetting(c *gin.Context) {
_, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
var req service.UpdateMetadataSettingRequest
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
result, code, err := h.kbService.UpdateMetadataSetting(&req)
if err != nil {
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
// GetDetail handles the get knowledge base detail request
// @Summary Get Knowledge Base Detail
// @Description Get detailed information about a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_id query string true "Knowledge Base ID"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/detail [get]
func (h *KnowledgebaseHandler) GetDetail(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbID := c.Query("kb_id")
if kbID == "" {
jsonError(c, common.CodeDataError, "kb_id is required")
return
}
result, code, err := h.kbService.GetDetail(kbID, userID)
if err != nil {
if strings.Contains(err.Error(), "authorized") {
jsonError(c, common.CodeOperatingError, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
// ListKbs handles the list knowledge bases request
// @Summary List Knowledge Bases
// @Description List knowledge bases with pagination and filtering
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param request body service.ListKbsRequest true "list options"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/list [post]
func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) {
// Parse request body - allow empty body
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
var req service.ListKbsRequest
if c.Request.ContentLength > 0 {
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": err.Error(),
})
jsonError(c, common.CodeDataError, err.Error())
return
}
}
// Extract parameters from query or request body with defaults
// Get parameters from request or query string
keywords := ""
if req.Keywords != nil {
keywords = *req.Keywords
@ -111,7 +304,7 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) {
if req.Desc != nil {
desc = *req.Desc
} else if descStr := c.Query("desc"); descStr != "" {
desc = descStr == "true"
desc = strings.ToLower(descStr) == "true"
}
var ownerIDs []string
@ -119,40 +312,327 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) {
ownerIDs = *req.OwnerIDs
}
// Get access token from Authorization header
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "Missing Authorization header",
})
return
}
// Get user by access token
user, code, err := h.userService.GetUserByToken(token)
result, code, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"code": code,
"message": err.Error(),
})
return
}
userID := user.ID
// List knowledge bases
result, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
"message": err.Error(),
})
jsonError(c, code, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": result,
"message": "success",
})
jsonResponse(c, common.CodeSuccess, result, "success")
}
// DeleteKB handles the delete knowledge base request
// @Summary Delete Knowledge Base
// @Description Soft delete a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param request body object{kb_id string} true "knowledge base id"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/rm [post]
func (h *KnowledgebaseHandler) DeleteKB(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
var req struct {
KBID string `json:"kb_id" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
code, err = h.kbService.DeleteKB(req.KBID, userID)
if err != nil {
if strings.Contains(err.Error(), "authorization") {
jsonError(c, common.CodeAuthenticationError, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, true, "success")
}
// ListTags handles the list tags request for a knowledge base
// @Summary List Tags
// @Description List tags for a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_id path string true "Knowledge Base ID"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/{kb_id}/tags [get]
func (h *KnowledgebaseHandler) ListTags(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbID := c.Param("kb_id")
if kbID == "" {
jsonError(c, common.CodeDataError, "kb_id is required")
return
}
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
jsonResponse(c, common.CodeSuccess, []string{}, "success")
}
// ListTagsFromKbs handles the list tags from multiple knowledge bases request
// @Summary List Tags from Knowledge Bases
// @Description List tags from multiple knowledge bases
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_ids query string true "Comma-separated Knowledge Base IDs"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/tags [get]
func (h *KnowledgebaseHandler) ListTagsFromKbs(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbIDsStr := c.Query("kb_ids")
if kbIDsStr == "" {
jsonError(c, common.CodeDataError, "kb_ids is required")
return
}
kbIDs := strings.Split(kbIDsStr, ",")
for _, kbID := range kbIDs {
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
}
jsonResponse(c, common.CodeSuccess, []string{}, "success")
}
// RemoveTags handles the remove tags request
// @Summary Remove Tags
// @Description Remove tags from a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_id path string true "Knowledge Base ID"
// @Param request body object{tags []string} true "tags to remove"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/{kb_id}/rm_tags [post]
func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbID := c.Param("kb_id")
if kbID == "" {
jsonError(c, common.CodeDataError, "kb_id is required")
return
}
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
var req struct {
Tags []string `json:"tags" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, true, "success")
}
// RenameTag handles the rename tag request
// @Summary Rename Tag
// @Description Rename a tag in a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_id path string true "Knowledge Base ID"
// @Param request body object{from_tag string, to_tag string} true "tag rename info"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/{kb_id}/rename_tag [post]
func (h *KnowledgebaseHandler) RenameTag(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbID := c.Param("kb_id")
if kbID == "" {
jsonError(c, common.CodeDataError, "kb_id is required")
return
}
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
var req struct {
FromTag string `json:"from_tag" binding:"required"`
ToTag string `json:"to_tag" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, true, "success")
}
// KnowledgeGraph handles the get knowledge graph request
// @Summary Get Knowledge Graph
// @Description Get knowledge graph for a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_id path string true "Knowledge Base ID"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/{kb_id}/knowledge_graph [get]
func (h *KnowledgebaseHandler) KnowledgeGraph(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbID := c.Param("kb_id")
if kbID == "" {
jsonError(c, common.CodeDataError, "kb_id is required")
return
}
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
result := map[string]interface{}{
"graph": map[string]interface{}{},
"mind_map": map[string]interface{}{},
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
// DeleteKnowledgeGraph handles the delete knowledge graph request
// @Summary Delete Knowledge Graph
// @Description Delete knowledge graph for a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_id path string true "Knowledge Base ID"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/{kb_id}/knowledge_graph [delete]
func (h *KnowledgebaseHandler) DeleteKnowledgeGraph(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbID := c.Param("kb_id")
if kbID == "" {
jsonError(c, common.CodeDataError, "kb_id is required")
return
}
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
jsonResponse(c, common.CodeSuccess, true, "success")
}
// GetMeta handles the get metadata request
// @Summary Get Metadata
// @Description Get metadata for knowledge bases
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_ids query string true "Comma-separated Knowledge Base IDs"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/get_meta [get]
func (h *KnowledgebaseHandler) GetMeta(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbIDsStr := c.Query("kb_ids")
if kbIDsStr == "" {
jsonError(c, common.CodeDataError, "kb_ids is required")
return
}
kbIDs := strings.Split(kbIDsStr, ",")
for _, kbID := range kbIDs {
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
}
jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success")
}
// GetBasicInfo handles the get basic info request
// @Summary Get Basic Info
// @Description Get basic information for a knowledge base
// @Tags knowledgebase
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param kb_id query string true "Knowledge Base ID"
// @Success 200 {object} map[string]interface{}
// @Router /v1/kb/basic_info [get]
func (h *KnowledgebaseHandler) GetBasicInfo(c *gin.Context) {
userID, code, err := h.getUserID(c)
if err != nil {
jsonError(c, code, err.Error())
return
}
kbID := c.Query("kb_id")
if kbID == "" {
jsonError(c, common.CodeDataError, "kb_id is required")
return
}
if !h.kbService.Accessible(kbID, userID) {
jsonError(c, common.CodeAuthenticationError, "No authorization.")
return
}
jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success")
}

View File

@ -21,6 +21,7 @@ import (
"github.com/gin-gonic/gin"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/service"
)
@ -60,50 +61,112 @@ func NewLLMHandler(llmService *service.LLMService, userService *service.UserServ
// @Success 200 {object} map[string]interface{}
// @Router /v1/llm/my_llms [get]
func (h *LLMHandler) GetMyLLMs(c *gin.Context) {
// Extract token from request
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "Missing Authorization header",
c.JSON(http.StatusOK, gin.H{
"code": common.CodeUnauthorized,
"message": "Unauthorized!",
"data": false,
})
return
}
// Get user by token
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
"data": false,
})
return
}
// Get tenant ID from user
tenantID := user.ID
if tenantID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "User has no tenant ID",
})
return
}
// Parse include_details query parameter
includeDetailsStr := c.DefaultQuery("include_details", "false")
includeDetails := includeDetailsStr == "true"
// Get LLMs for tenant
llms, err := h.llmService.GetMyLLMs(tenantID, includeDetails)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to get LLMs",
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"data": llms,
"code": common.CodeSuccess,
"message": "success",
"data": llms,
})
}
// SetAPIKey set API key for a LLM factory
// @Summary Set API Key
// @Description Set API key for a LLM factory and test connectivity
// @Tags llm
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param request body service.SetAPIKeyRequest true "API Key configuration"
// @Success 200 {object} map[string]interface{}
// @Router /v1/llm/set_api_key [post]
func (h *LLMHandler) SetAPIKey(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeUnauthorized,
"message": "Unauthorized!",
"data": false,
})
return
}
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
"data": false,
})
return
}
var req service.SetAPIKeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "Invalid request: " + err.Error(),
"data": false,
})
return
}
tenantID := user.ID
result, err := h.llmService.SetAPIKey(tenantID, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeDataError,
"message": err.Error(),
"data": false,
})
return
}
if req.Verify {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": result,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": true,
})
}
@ -198,52 +261,43 @@ func (h *LLMHandler) Factories(c *gin.Context) {
// @Success 200 {object} map[string][]service.LLMListItem
// @Router /v1/llm/list [get]
func (h *LLMHandler) ListApp(c *gin.Context) {
// Extract token from request
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "Missing Authorization header",
c.JSON(http.StatusOK, gin.H{
"code": common.CodeUnauthorized,
"message": "Unauthorized!",
"data": false,
})
return
}
// Get user by token
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
"data": false,
})
return
}
// Get tenant ID from user
tenantID := user.ID
if tenantID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "User has no tenant ID",
})
return
}
// Parse model_type query parameter
modelType := c.Query("model_type")
// Get LLM list
llms, err := h.llmService.ListLLMs(tenantID, modelType)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": llms,
"code": common.CodeSuccess,
"message": "success",
"data": llms,
})
}

View File

@ -21,6 +21,7 @@ import (
"github.com/gin-gonic/gin"
"ragflow/internal/common"
"ragflow/internal/service"
)
@ -48,44 +49,49 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service
// @Success 200 {object} map[string]interface{}
// @Router /v1/user/tenant_info [get]
func (h *TenantHandler) TenantInfo(c *gin.Context) {
// Extract token from request
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "Missing Authorization header",
})
return
}
// Get user by token
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"code": code,
"message": err.Error(),
c.JSON(http.StatusOK, gin.H{
"code": common.CodeUnauthorized,
"message": "Unauthorized!",
"data": false,
})
return
}
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
"data": false,
})
return
}
// Get tenant info
tenantInfo, err := h.tenantService.GetTenantInfo(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to get tenant information",
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": false,
})
return
}
if tenantInfo == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "Tenant not found",
c.JSON(http.StatusOK, gin.H{
"code": common.CodeDataError,
"message": "Tenant not found!",
"data": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": tenantInfo,
"code": common.CodeSuccess,
"message": "success",
"data": tenantInfo,
})
}
@ -99,38 +105,39 @@ func (h *TenantHandler) TenantInfo(c *gin.Context) {
// @Success 200 {object} map[string]interface{}
// @Router /v1/tenant/list [get]
func (h *TenantHandler) TenantList(c *gin.Context) {
// Extract token from request
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "Missing Authorization header",
c.JSON(http.StatusOK, gin.H{
"code": common.CodeUnauthorized,
"message": "Unauthorized!",
"data": false,
})
return
}
// Get user by token
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
"data": false,
})
return
}
// Get tenant list
tenantList, err := h.tenantService.GetTenantList(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
"message": "Failed to get tenant list",
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": tenantList,
"code": common.CodeSuccess,
"message": "success",
"data": tenantList,
})
}

View File

@ -522,3 +522,61 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) {
"data": channels,
})
}
// SetTenantInfo update tenant information
// @Summary Set Tenant Info
// @Description Update tenant model configuration
// @Tags users
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param request body service.SetTenantInfoRequest true "tenant info"
// @Success 200 {object} map[string]interface{}
// @Router /v1/user/set_tenant_info [post]
func (h *UserHandler) SetTenantInfo(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeUnauthorized,
"message": "Unauthorized!",
"data": false,
})
return
}
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
"data": false,
})
return
}
var req service.SetTenantInfoRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": err.Error(),
"data": false,
})
return
}
err = h.userService.SetTenantInfo(user.ID, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeDataError,
"message": err.Error(),
"data": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": true,
})
}

View File

@ -0,0 +1,157 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package init_data
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"ragflow/internal/dao"
"ragflow/internal/model"
)
// LLMFactoryConfig represents a single LLM factory configuration
type LLMFactoryConfig struct {
Name string `json:"name"`
Logo string `json:"logo"`
Tags string `json:"tags"`
Status string `json:"status"`
Rank string `json:"rank"`
LLM []LLMConfig `json:"llm"`
}
// LLMConfig represents a single LLM model configuration
type LLMConfig struct {
LLMName string `json:"llm_name"`
Tags string `json:"tags"`
MaxTokens int64 `json:"max_tokens"`
ModelType string `json:"model_type"`
IsTools bool `json:"is_tools"`
}
// LLMFactoriesFile represents the structure of llm_factories.json
type LLMFactoriesFile struct {
FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"`
}
// InitLLMFactory initializes LLM factories and models from JSON file
func InitLLMFactory() error {
configPath := filepath.Join(getProjectBaseDirectory(), "conf", "llm_factories.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("failed to read llm_factories.json: %w", err)
}
var fileData LLMFactoriesFile
if err := json.Unmarshal(data, &fileData); err != nil {
return fmt.Errorf("failed to parse llm_factories.json: %w", err)
}
db := dao.DB
for _, factory := range fileData.FactoryLLMInfos {
status := factory.Status
if status == "" {
status = "1"
}
llmFactory := &model.LLMFactories{
Name: factory.Name,
Logo: stringPtr(factory.Logo),
Tags: factory.Tags,
Rank: parseInt64(factory.Rank),
Status: &status,
}
var existingFactory model.LLMFactories
result := db.Where("name = ?", factory.Name).First(&existingFactory)
if result.Error != nil {
if err := db.Create(llmFactory).Error; err != nil {
log.Printf("Failed to create LLM factory %s: %v", factory.Name, err)
continue
}
} else {
if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{
"logo": llmFactory.Logo,
"tags": llmFactory.Tags,
"rank": llmFactory.Rank,
"status": llmFactory.Status,
}).Error; err != nil {
log.Printf("Failed to update LLM factory %s: %v", factory.Name, err)
}
}
for _, llm := range factory.LLM {
llmStatus := "1"
llmModel := &model.LLM{
LLMName: llm.LLMName,
ModelType: llm.ModelType,
FID: factory.Name,
MaxTokens: llm.MaxTokens,
Tags: llm.Tags,
IsTools: llm.IsTools,
Status: &llmStatus,
}
var existingLLM model.LLM
result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM)
if result.Error != nil {
if err := db.Create(llmModel).Error; err != nil {
log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err)
}
} else {
if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{
"model_type": llmModel.ModelType,
"max_tokens": llmModel.MaxTokens,
"tags": llmModel.Tags,
"is_tools": llmModel.IsTools,
"status": llmModel.Status,
}).Error; err != nil {
log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err)
}
}
}
}
log.Println("LLM factories initialized successfully")
return nil
}
func getProjectBaseDirectory() string {
cwd, err := os.Getwd()
if err != nil {
return "."
}
return cwd
}
func stringPtr(s string) *string {
if s == "" {
return nil
}
return &s
}
func parseInt64(s string) int64 {
var result int64
fmt.Sscanf(s, "%d", &result)
return result
}

View File

@ -18,7 +18,84 @@ package model
import "time"
// Knowledgebase knowledge base model
// DatasetNameLimit is the maximum length for dataset name
const DatasetNameLimit = 128
// Status represents the status enum values
type Status string
const (
// StatusValid indicates a valid/active record
StatusValid Status = "1"
// StatusInvalid indicates a deleted/inactive record
StatusInvalid Status = "0"
)
// TenantPermission represents the permission level for tenant access
type TenantPermission string
const (
// TenantPermissionMe indicates only the creator can access
TenantPermissionMe TenantPermission = "me"
// TenantPermissionTeam indicates all team members can access
TenantPermissionTeam TenantPermission = "team"
)
// ParserType represents the document parser type
type ParserType string
const (
ParserTypePresentation ParserType = "presentation"
ParserTypeLaws ParserType = "laws"
ParserTypeManual ParserType = "manual"
ParserTypePaper ParserType = "paper"
ParserTypeResume ParserType = "resume"
ParserTypeBook ParserType = "book"
ParserTypeQA ParserType = "qa"
ParserTypeTable ParserType = "table"
ParserTypeNaive ParserType = "naive"
ParserTypePicture ParserType = "picture"
ParserTypeOne ParserType = "one"
ParserTypeAudio ParserType = "audio"
ParserTypeEmail ParserType = "email"
ParserTypeKG ParserType = "knowledge_graph"
ParserTypeTag ParserType = "tag"
)
// TaskStatus represents the status of a processing task
type TaskStatus string
const (
TaskStatusUnstart TaskStatus = "0"
TaskStatusRunning TaskStatus = "1"
TaskStatusCancel TaskStatus = "2"
TaskStatusDone TaskStatus = "3"
TaskStatusFail TaskStatus = "4"
TaskStatusSchedule TaskStatus = "5"
)
// PipelineTaskType represents the type of pipeline task
type PipelineTaskType string
const (
PipelineTaskTypeParse PipelineTaskType = "Parse"
PipelineTaskTypeDownload PipelineTaskType = "Download"
PipelineTaskTypeRAPTOR PipelineTaskType = "RAPTOR"
PipelineTaskTypeGraphRAG PipelineTaskType = "GraphRAG"
PipelineTaskTypeMindmap PipelineTaskType = "Mindmap"
PipelineTaskTypeMemory PipelineTaskType = "Memory"
)
// FileSource represents the source of a file
type FileSource string
const (
FileSourceLocal FileSource = ""
FileSourceKnowledgebase FileSource = "knowledgebase"
FileSourceS3 FileSource = "s3"
)
// Knowledgebase represents the knowledge base model
type Knowledgebase struct {
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"`
@ -27,7 +104,6 @@ type Knowledgebase struct {
Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"`
Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"`
EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"`
TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"`
Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"`
CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"`
DocNum int64 `gorm:"column:doc_num;default:0;index" json:"doc_num"`
@ -37,7 +113,7 @@ type Knowledgebase struct {
VectorSimilarityWeight float64 `gorm:"column:vector_similarity_weight;default:0.3;index" json:"vector_similarity_weight"`
ParserID string `gorm:"column:parser_id;size:32;not null;default:naive;index" json:"parser_id"`
PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"`
ParserConfig JSONMap `gorm:"column:parser_config;type:longtext;not null" json:"parser_config"`
ParserConfig JSONMap `gorm:"column:parser_config;type:json" json:"parser_config"`
Pagerank int64 `gorm:"column:pagerank;default:0" json:"pagerank"`
GraphragTaskID *string `gorm:"column:graphrag_task_id;size:32;index" json:"graphrag_task_id,omitempty"`
GraphragTaskFinishAt *time.Time `gorm:"column:graphrag_task_finish_at" json:"graphrag_task_finish_at,omitempty"`
@ -49,12 +125,118 @@ type Knowledgebase struct {
BaseModel
}
// TableName specify table name
// TableName returns the table name for Knowledgebase model
func (Knowledgebase) TableName() string {
return "knowledgebase"
}
// InvitationCode invitation code model
// ToMap converts Knowledgebase to a map for JSON response
func (kb *Knowledgebase) ToMap() map[string]interface{} {
result := map[string]interface{}{
"id": kb.ID,
"tenant_id": kb.TenantID,
"name": kb.Name,
"embd_id": kb.EmbdID,
"permission": kb.Permission,
"created_by": kb.CreatedBy,
"doc_num": kb.DocNum,
"token_num": kb.TokenNum,
"chunk_num": kb.ChunkNum,
"similarity_threshold": kb.SimilarityThreshold,
"vector_similarity_weight": kb.VectorSimilarityWeight,
"parser_id": kb.ParserID,
"parser_config": kb.ParserConfig,
"pagerank": kb.Pagerank,
"create_time": kb.CreateTime,
}
if kb.Avatar != nil {
result["avatar"] = *kb.Avatar
}
if kb.Language != nil {
result["language"] = *kb.Language
}
if kb.Description != nil {
result["description"] = *kb.Description
}
if kb.PipelineID != nil {
result["pipeline_id"] = *kb.PipelineID
}
if kb.GraphragTaskID != nil {
result["graphrag_task_id"] = *kb.GraphragTaskID
}
if kb.GraphragTaskFinishAt != nil {
result["graphrag_task_finish_at"] = kb.GraphragTaskFinishAt.Format("2006-01-02 15:04:05")
}
if kb.RaptorTaskID != nil {
result["raptor_task_id"] = *kb.RaptorTaskID
}
if kb.RaptorTaskFinishAt != nil {
result["raptor_task_finish_at"] = kb.RaptorTaskFinishAt.Format("2006-01-02 15:04:05")
}
if kb.MindmapTaskID != nil {
result["mindmap_task_id"] = *kb.MindmapTaskID
}
if kb.MindmapTaskFinishAt != nil {
result["mindmap_task_finish_at"] = kb.MindmapTaskFinishAt.Format("2006-01-02 15:04:05")
}
if kb.UpdateTime != nil {
result["update_time"] = *kb.UpdateTime
}
return result
}
// KnowledgebaseDetail represents detailed knowledge base information with joined data
type KnowledgebaseDetail struct {
ID string `json:"id"`
EmbdID string `json:"embd_id"`
Avatar *string `json:"avatar,omitempty"`
Name string `json:"name"`
Language *string `json:"language,omitempty"`
Description *string `json:"description,omitempty"`
Permission string `json:"permission"`
DocNum int64 `json:"doc_num"`
TokenNum int64 `json:"token_num"`
ChunkNum int64 `json:"chunk_num"`
ParserID string `json:"parser_id"`
PipelineID *string `json:"pipeline_id,omitempty"`
PipelineName *string `json:"pipeline_name,omitempty"`
PipelineAvatar *string `json:"pipeline_avatar,omitempty"`
ParserConfig JSONMap `json:"parser_config"`
Pagerank int64 `json:"pagerank"`
GraphragTaskID *string `json:"graphrag_task_id,omitempty"`
GraphragTaskFinishAt *string `json:"graphrag_task_finish_at,omitempty"`
RaptorTaskID *string `json:"raptor_task_id,omitempty"`
RaptorTaskFinishAt *string `json:"raptor_task_finish_at,omitempty"`
MindmapTaskID *string `json:"mindmap_task_id,omitempty"`
MindmapTaskFinishAt *string `json:"mindmap_task_finish_at,omitempty"`
CreateTime *int64 `json:"create_time,omitempty"`
UpdateTime *int64 `json:"update_time,omitempty"`
Size int64 `json:"size"`
Connectors []string `json:"connectors"`
}
// KnowledgebaseListItem represents a knowledge base item in list responses
type KnowledgebaseListItem struct {
ID string `json:"id"`
Avatar *string `json:"avatar,omitempty"`
Name string `json:"name"`
Language *string `json:"language,omitempty"`
Description *string `json:"description,omitempty"`
TenantID string `json:"tenant_id"`
Permission string `json:"permission"`
DocNum int64 `json:"doc_num"`
TokenNum int64 `json:"token_num"`
ChunkNum int64 `json:"chunk_num"`
ParserID string `json:"parser_id"`
EmbdID string `json:"embd_id"`
Nickname string `json:"nickname"`
TenantAvatar *string `json:"tenant_avatar,omitempty"`
UpdateTime *int64 `json:"update_time,omitempty"`
}
// InvitationCode represents the invitation code model
type InvitationCode struct {
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
Code string `gorm:"column:code;size:32;not null;index" json:"code"`
@ -65,7 +247,7 @@ type InvitationCode struct {
BaseModel
}
// TableName specify table name
// TableName returns the table name for InvitationCode model
func (InvitationCode) TableName() string {
return "invitation_code"
}

View File

@ -64,13 +64,14 @@ func (TenantLangfuse) TableName() string {
// MyLLM represents LLM information for a tenant with factory details
type MyLLM struct {
ID string `gorm:"column:id" json:"id"`
LLMFactory string `gorm:"column:llm_factory" json:"llm_factory"`
Logo *string `gorm:"column:logo" json:"logo,omitempty"`
Tags string `gorm:"column:tags" json:"tags"`
ModelType string `gorm:"column:model_type" json:"model_type"`
LLMName string `gorm:"column:llm_name" json:"llm_name"`
UsedTokens int64 `gorm:"column:used_tokens" json:"used_tokens"`
Status string `gorm:"column:status" json:"status"`
APIBase string `gorm:"column:api_base" json:"api_base,omitempty"`
MaxTokens int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"`
Tags *string `gorm:"column:tags" json:"tags"`
ModelType *string `gorm:"column:model_type" json:"model_type"`
LLMName *string `gorm:"column:llm_name" json:"llm_name"`
UsedTokens *int64 `gorm:"column:used_tokens" json:"used_tokens"`
Status *string `gorm:"column:status" json:"status"`
APIBase *string `gorm:"column:api_base" json:"api_base,omitempty"`
MaxTokens *int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"`
}

View File

@ -101,6 +101,8 @@ func (r *Router) Setup(engine *gin.Engine) {
engine.POST("/v1/user/setting", r.userHandler.Setting)
// User change password endpoint
engine.POST("/v1/user/setting/password", r.userHandler.ChangePassword)
// User set tenant info endpoint
engine.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo)
// API v1 route group
v1 := engine.Group("/api/v1")
@ -134,7 +136,25 @@ func (r *Router) Setup(engine *gin.Engine) {
// Knowledge base routes
kb := engine.Group("/v1/kb")
{
kb.POST("/create", r.knowledgebaseHandler.CreateKB)
kb.POST("/update", r.knowledgebaseHandler.UpdateKB)
kb.POST("/update_metadata_setting", r.knowledgebaseHandler.UpdateMetadataSetting)
kb.GET("/detail", r.knowledgebaseHandler.GetDetail)
kb.POST("/list", r.knowledgebaseHandler.ListKbs)
kb.POST("/rm", r.knowledgebaseHandler.DeleteKB)
kb.GET("/tags", r.knowledgebaseHandler.ListTagsFromKbs)
kb.GET("/get_meta", r.knowledgebaseHandler.GetMeta)
kb.GET("/basic_info", r.knowledgebaseHandler.GetBasicInfo)
// KB ID specific routes
kbByID := kb.Group("/:kb_id")
{
kbByID.GET("/tags", r.knowledgebaseHandler.ListTags)
kbByID.POST("/rm_tags", r.knowledgebaseHandler.RemoveTags)
kbByID.POST("/rename_tag", r.knowledgebaseHandler.RenameTag)
kbByID.GET("/knowledge_graph", r.knowledgebaseHandler.KnowledgeGraph)
kbByID.DELETE("/knowledge_graph", r.knowledgebaseHandler.DeleteKnowledgeGraph)
}
}
// Chunk routes
@ -149,6 +169,7 @@ func (r *Router) Setup(engine *gin.Engine) {
llm.GET("/my_llms", r.llmHandler.GetMyLLMs)
llm.GET("/factories", r.llmHandler.Factories)
llm.GET("/list", r.llmHandler.ListApp)
llm.POST("/set_api_key", r.llmHandler.SetAPIKey)
}
// Chat routes

View File

@ -40,6 +40,29 @@ type Config struct {
DocEngine DocEngineConfig `mapstructure:"doc_engine"`
RegisterEnabled int `mapstructure:"register_enabled"`
OAuth map[string]OAuthConfig `mapstructure:"oauth"`
UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"`
}
// UserDefaultLLMConfig user default LLM configuration
type UserDefaultLLMConfig struct {
DefaultModels DefaultModelsConfig `mapstructure:"default_models"`
}
// DefaultModelsConfig default models configuration
type DefaultModelsConfig struct {
ChatModel ModelConfig `mapstructure:"chat_model"`
EmbeddingModel ModelConfig `mapstructure:"embedding_model"`
RerankModel ModelConfig `mapstructure:"rerank_model"`
ASRModel ModelConfig `mapstructure:"asr_model"`
Image2TextModel ModelConfig `mapstructure:"image2text_model"`
}
// ModelConfig model configuration
type ModelConfig struct {
Name string `mapstructure:"name"`
APIKey string `mapstructure:"api_key"`
BaseURL string `mapstructure:"base_url"`
Factory string `mapstructure:"factory"`
}
// OAuthConfig OAuth configuration for a channel
@ -414,6 +437,45 @@ func Init(configPath string) error {
}
}
// Map user_default_llm section to UserDefaultLLMConfig
if v.IsSet("user_default_llm") {
userDefaultLLMConfig := v.Sub("user_default_llm")
if userDefaultLLMConfig != nil {
if defaultModels := userDefaultLLMConfig.Sub("default_models"); defaultModels != nil {
globalConfig.UserDefaultLLM.DefaultModels.ChatModel = ModelConfig{
Name: defaultModels.GetString("chat_model.name"),
APIKey: defaultModels.GetString("chat_model.api_key"),
BaseURL: defaultModels.GetString("chat_model.base_url"),
Factory: defaultModels.GetString("chat_model.factory"),
}
globalConfig.UserDefaultLLM.DefaultModels.EmbeddingModel = ModelConfig{
Name: defaultModels.GetString("embedding_model.name"),
APIKey: defaultModels.GetString("embedding_model.api_key"),
BaseURL: defaultModels.GetString("embedding_model.base_url"),
Factory: defaultModels.GetString("embedding_model.factory"),
}
globalConfig.UserDefaultLLM.DefaultModels.RerankModel = ModelConfig{
Name: defaultModels.GetString("rerank_model.name"),
APIKey: defaultModels.GetString("rerank_model.api_key"),
BaseURL: defaultModels.GetString("rerank_model.base_url"),
Factory: defaultModels.GetString("rerank_model.factory"),
}
globalConfig.UserDefaultLLM.DefaultModels.ASRModel = ModelConfig{
Name: defaultModels.GetString("asr_model.name"),
APIKey: defaultModels.GetString("asr_model.api_key"),
BaseURL: defaultModels.GetString("asr_model.base_url"),
Factory: defaultModels.GetString("asr_model.factory"),
}
globalConfig.UserDefaultLLM.DefaultModels.Image2TextModel = ModelConfig{
Name: defaultModels.GetString("image2text_model.name"),
APIKey: defaultModels.GetString("image2text_model.api_key"),
BaseURL: defaultModels.GetString("image2text_model.base_url"),
Factory: defaultModels.GetString("image2text_model.factory"),
}
}
}
}
return nil
}

View File

@ -82,7 +82,7 @@ func (s *ChatService) ListChats(userID string, status string) (*ListChatsRespons
}
// Enrich with knowledge base names
var chatsWithKBNames []*ChatWithKBNames
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
for _, chat := range chats {
kbNames := s.getKBNames(chat.KBIDs)
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
@ -148,7 +148,7 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
}
// Enrich with knowledge base names
var chatsWithKBNames []*ChatWithKBNames
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
for _, chat := range chats {
kbNames := s.getKBNames(chat.KBIDs)
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{

View File

@ -17,25 +17,76 @@
package service
import (
"errors"
"fmt"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/model"
"ragflow/internal/utility"
"strings"
"time"
"github.com/google/uuid"
)
// KnowledgebaseService knowledge base service
// KnowledgebaseService service class for managing dataset operations
type KnowledgebaseService struct {
kbDAO *dao.KnowledgebaseDAO
userTenantDAO *dao.UserTenantDAO
userDAO *dao.UserDAO
tenantDAO *dao.TenantDAO
connectorDAO *dao.ConnectorDAO
}
// NewKnowledgebaseService create knowledge base service
// NewKnowledgebaseService creates a new knowledge base service
func NewKnowledgebaseService() *KnowledgebaseService {
return &KnowledgebaseService{
kbDAO: dao.NewKnowledgebaseDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
userDAO: dao.NewUserDAO(),
tenantDAO: dao.NewTenantDAO(),
connectorDAO: dao.NewConnectorDAO(),
}
}
// ListKbsRequest list knowledge bases request
// CreateKBRequest represents the request for creating a knowledge base
type CreateKBRequest struct {
Name string `json:"name" binding:"required"`
ParserID *string `json:"parser_id,omitempty"`
Description *string `json:"description,omitempty"`
Language *string `json:"language,omitempty"`
Permission *string `json:"permission,omitempty"`
Avatar *string `json:"avatar,omitempty"`
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
}
// CreateKBResponse represents the response for creating a knowledge base
type CreateKBResponse struct {
KBID string `json:"kb_id"`
}
// UpdateKBRequest represents the request for updating a knowledge base
type UpdateKBRequest struct {
KBID string `json:"kb_id" binding:"required"`
Name string `json:"name" binding:"required"`
Description *string `json:"description"`
ParserID string `json:"parser_id" binding:"required"`
Permission *string `json:"permission,omitempty"`
Language *string `json:"language,omitempty"`
Avatar *string `json:"avatar,omitempty"`
Pagerank *int64 `json:"pagerank,omitempty"`
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
Connectors []string `json:"connectors,omitempty"`
}
// UpdateMetadataSettingRequest represents the request for updating metadata settings
type UpdateMetadataSettingRequest struct {
KBID string `json:"kb_id" binding:"required"`
Metadata map[string]interface{} `json:"metadata" binding:"required"`
EnableMetadata *bool `json:"enable_metadata,omitempty"`
}
// ListKbsRequest represents the request for listing knowledge bases
type ListKbsRequest struct {
Keywords *string `json:"keywords,omitempty"`
Page *int `json:"page,omitempty"`
@ -46,37 +97,461 @@ type ListKbsRequest struct {
OwnerIDs *[]string `json:"owner_ids,omitempty"`
}
// ListKbsResponse list knowledge bases response
// ListKbsResponse represents the response for listing knowledge bases
type ListKbsResponse struct {
KBs []*model.Knowledgebase `json:"kbs"`
Total int64 `json:"total"`
KBs []map[string]interface{} `json:"kbs"`
Total int64 `json:"total"`
}
// ListKbs list knowledge bases
func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, error) {
var kbs []*model.Knowledgebase
// CreateKB creates a new knowledge base
// This matches the Python create endpoint in kb_app.py
func (s *KnowledgebaseService) CreateKB(req *CreateKBRequest, tenantID string) (*CreateKBResponse, common.ErrorCode, error) {
// Validate name is a string
if !isValidString(req.Name) {
return nil, common.CodeDataError, errors.New("Dataset name must be string.")
}
// Trim and validate name
name := strings.TrimSpace(req.Name)
if name == "" {
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
}
// Check name length (using UTF-8 byte length like Python)
if len(name) > model.DatasetNameLimit {
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit)
}
// Verify tenant exists
tenant, err := s.tenantDAO.GetByID(tenantID)
if err != nil {
return nil, common.CodeDataError, errors.New("Tenant not found.")
}
// Deduplicate name within tenant
duplicateName := s.kbDAO.DuplicateName(name, tenantID)
// Get parser ID (default to "naive")
parserID := "naive"
if req.ParserID != nil && *req.ParserID != "" {
parserID = *req.ParserID
}
// Get parser config with defaults
parserConfig := getParserConfig(parserID, req.ParserConfig)
parserConfig["llm_id"] = tenant.LLMID
// Generate KB ID
kbID := strings.ReplaceAll(uuid.New().String(), "-", "")
// Create knowledge base model
now := time.Now().Unix()
nowDate := time.Now()
kb := &model.Knowledgebase{
ID: kbID,
Name: duplicateName,
TenantID: tenantID,
CreatedBy: tenantID,
ParserID: parserID,
ParserConfig: parserConfig,
Permission: "me",
EmbdID: "",
}
kb.CreateTime = &now
kb.UpdateTime = &now
kb.CreateDate = &nowDate
kb.UpdateDate = &nowDate
status := string(model.StatusValid)
kb.Status = &status
// Set optional fields
if req.Description != nil {
kb.Description = req.Description
}
if req.Language != nil {
kb.Language = req.Language
}
if req.Permission != nil {
kb.Permission = *req.Permission
}
if req.Avatar != nil {
kb.Avatar = req.Avatar
}
// Create in database
if err := s.kbDAO.Create(kb); err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to create knowledge base: %w", err)
}
return &CreateKBResponse{KBID: kbID}, common.CodeSuccess, nil
}
// UpdateKB updates an existing knowledge base
// This matches the Python update endpoint in kb_app.py
func (s *KnowledgebaseService) UpdateKB(req *UpdateKBRequest, userID string) (map[string]interface{}, common.ErrorCode, error) {
// Validate name is a string
if !isValidString(req.Name) {
return nil, common.CodeDataError, errors.New("Dataset name must be string.")
}
// Trim and validate name
name := strings.TrimSpace(req.Name)
if name == "" {
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
}
// Check name length
if len(name) > model.DatasetNameLimit {
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit)
}
// Check authorization
if !s.kbDAO.Accessible4Deletion(req.KBID, userID) {
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
}
// Verify ownership
kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": req.KBID})
if err != nil || len(kbs) == 0 {
return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
}
// Get existing KB
kb, err := s.kbDAO.GetByID(req.KBID)
if err != nil {
return nil, common.CodeDataError, errors.New("can't find this dataset")
}
// Check for duplicate name
if strings.ToLower(name) != strings.ToLower(kb.Name) {
existingKB, _ := s.kbDAO.GetByName(name, userID)
if existingKB != nil {
return nil, common.CodeDataError, errors.New("duplicated dataset name")
}
}
// Build updates
updates := map[string]interface{}{
"name": name,
"parser_id": req.ParserID,
}
if req.Description != nil {
updates["description"] = *req.Description
}
if req.Permission != nil {
updates["permission"] = *req.Permission
}
if req.Language != nil {
updates["language"] = *req.Language
}
if req.Avatar != nil {
updates["avatar"] = *req.Avatar
}
if req.Pagerank != nil {
updates["pagerank"] = *req.Pagerank
}
if req.ParserConfig != nil {
updates["parser_config"] = req.ParserConfig
}
now := time.Now().Unix()
nowDate := time.Now()
updates["update_time"] = now
updates["update_date"] = nowDate
// Update in database
if err := s.kbDAO.UpdateByID(req.KBID, updates); err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to update knowledge base: %w", err)
}
// Get updated KB
updatedKB, err := s.kbDAO.GetByID(req.KBID)
if err != nil {
return nil, common.CodeDataError, errors.New("database error (knowledgebase rename)")
}
result := updatedKB.ToMap()
result["connectors"] = req.Connectors
return result, common.CodeSuccess, nil
}
// UpdateMetadataSetting updates the metadata settings for a knowledge base
func (s *KnowledgebaseService) UpdateMetadataSetting(req *UpdateMetadataSettingRequest) (map[string]interface{}, common.ErrorCode, error) {
kb, err := s.kbDAO.GetByID(req.KBID)
if err != nil {
return nil, common.CodeDataError, errors.New("database error (knowledgebase not found)")
}
parserConfig := kb.ParserConfig
if parserConfig == nil {
parserConfig = make(map[string]interface{})
}
parserConfig["metadata"] = req.Metadata
enableMetadata := true
if req.EnableMetadata != nil {
enableMetadata = *req.EnableMetadata
}
parserConfig["enable_metadata"] = enableMetadata
if err := s.kbDAO.UpdateParserConfig(req.KBID, parserConfig); err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to update metadata setting: %w", err)
}
result := kb.ToMap()
result["parser_config"] = parserConfig
return result, common.CodeSuccess, nil
}
// GetDetail retrieves detailed information about a knowledge base
// This matches the Python kb_detail endpoint in kb_app.py
func (s *KnowledgebaseService) GetDetail(kbID, userID string) (*model.KnowledgebaseDetail, common.ErrorCode, error) {
// Check authorization
if !s.kbDAO.Accessible(kbID, userID) {
return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
}
// Get detail
detail, err := s.kbDAO.GetDetail(kbID)
if err != nil {
return nil, common.CodeDataError, errors.New("can't find this dataset")
}
// Set connectors (empty for now)
detail.Connectors = []string{}
return detail, common.CodeSuccess, nil
}
// ListKbs lists knowledge bases with pagination and filtering
// This matches the Python list endpoint in kb_app.py
func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, common.ErrorCode, error) {
var kbs []*model.KnowledgebaseListItem
var total int64
var err error
// If owner IDs are provided, filter by them
if ownerIDs != nil && len(ownerIDs) > 0 {
kbs, total, err = s.kbDAO.ListByOwnerIDs(ownerIDs, page, pageSize, orderby, desc, keywords, parserID)
if len(ownerIDs) > 0 {
// List by owner IDs
kbs, total, err = s.kbDAO.GetByTenantIDs(ownerIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
} else {
// Get tenant IDs by user ID
// Get tenant IDs for user
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
return nil, err
return nil, common.CodeServerError, err
}
kbs, total, err = s.kbDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
kbs, total, err = s.kbDAO.GetByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
}
if err != nil {
return nil, err
return nil, common.CodeServerError, err
}
// Convert to map slice
kbMaps := make([]map[string]interface{}, len(kbs))
for i, kb := range kbs {
kbMaps[i] = map[string]interface{}{
"id": kb.ID,
"avatar": kb.Avatar,
"name": kb.Name,
"language": kb.Language,
"description": kb.Description,
"tenant_id": kb.TenantID,
"permission": kb.Permission,
"doc_num": kb.DocNum,
"token_num": kb.TokenNum,
"chunk_num": kb.ChunkNum,
"parser_id": kb.ParserID,
"embd_id": kb.EmbdID,
"nickname": kb.Nickname,
"tenant_avatar": kb.TenantAvatar,
"update_time": kb.UpdateTime,
}
}
return &ListKbsResponse{
KBs: kbs,
KBs: kbMaps,
Total: total,
}, nil
}, common.CodeSuccess, nil
}
// DeleteKB soft deletes a knowledge base
// This matches the Python rm endpoint in kb_app.py
func (s *KnowledgebaseService) DeleteKB(kbID, userID string) (common.ErrorCode, error) {
// Check authorization
if !s.kbDAO.Accessible4Deletion(kbID, userID) {
return common.CodeAuthenticationError, errors.New("No authorization.")
}
// Verify ownership
kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": kbID})
if err != nil || len(kbs) == 0 {
return common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
}
// Soft delete
if err := s.kbDAO.Delete(kbID); err != nil {
return common.CodeServerError, fmt.Errorf("database error (knowledgebase removal): %w", err)
}
return common.CodeSuccess, nil
}
// Accessible checks if a knowledge base is accessible by a user
func (s *KnowledgebaseService) Accessible(kbID, userID string) bool {
return s.kbDAO.Accessible(kbID, userID)
}
// GetByID retrieves a knowledge base by ID
func (s *KnowledgebaseService) GetByID(kbID string) (*model.Knowledgebase, error) {
return s.kbDAO.GetByID(kbID)
}
// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant
func (s *KnowledgebaseService) GetKBIDsByTenantID(tenantID string) ([]string, error) {
return s.kbDAO.GetKBIDsByTenantID(tenantID)
}
// isValidString checks if a value is a non-empty string
func isValidString(v interface{}) bool {
str, ok := v.(string)
return ok && str != ""
}
// getParserConfig returns the parser configuration with defaults
// This matches the Python get_parser_config function
func getParserConfig(parserID string, customConfig map[string]interface{}) map[string]interface{} {
config := map[string]interface{}{
"pages": [][]int{{1, 1000000}},
"table_context_size": 0,
"image_context_size": 0,
}
switch parserID {
case "table":
config["layout_recognize"] = false
config["chunk_token_num"] = 128
config["delimiter"] = "\n!?;。;!?"
config["html4excel"] = false
case "naive":
config["chunk_token_num"] = 128
config["delimiter"] = "\n!?;。;!?"
config["html4excel"] = false
default:
config["raptor"] = map[string]interface{}{
"use_raptor": false,
}
}
// Merge custom config if provided
if customConfig != nil {
config = mergeParserConfig(config, customConfig)
}
return config
}
// mergeParserConfig merges two parser configurations
func mergeParserConfig(base, override map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for k, v := range base {
result[k] = v
}
for k, v := range override {
if existing, ok := result[k]; ok {
if existingMap, ok := existing.(map[string]interface{}); ok {
if newMap, ok := v.(map[string]interface{}); ok {
result[k] = mergeParserConfig(existingMap, newMap)
continue
}
}
}
result[k] = v
}
return result
}
// GenerateUUID generates a UUID string without dashes
func GenerateUUID() string {
return strings.ReplaceAll(uuid.New().String(), "-", "")
}
// GetUserByToken gets user by authorization token
func (s *KnowledgebaseService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) {
userService := NewUserService()
return userService.GetUserByToken(authorization)
}
// GetUserByID gets user by ID
func (s *KnowledgebaseService) GetUserByID(id string) (*model.User, error) {
return s.userDAO.GetByAccessToken(id)
}
// GetTenantIDsByUserID gets tenant IDs for a user
func (s *KnowledgebaseService) GetTenantIDsByUserID(userID string) ([]string, error) {
return s.userTenantDAO.GetTenantIDsByUserID(userID)
}
// GetConnectorsByTenantID gets connectors for a tenant
func (s *KnowledgebaseService) GetConnectorsByTenantID(tenantID string) ([]*dao.ConnectorListItem, error) {
return s.connectorDAO.ListByTenantID(tenantID)
}
// GetKBList retrieves knowledge bases with ID and name filtering
func (s *KnowledgebaseService) GetKBList(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, common.ErrorCode, error) {
kbs, total, err := s.kbDAO.GetList(tenantIDs, userID, page, pageSize, orderby, desc, id, name)
if err != nil {
return nil, 0, common.CodeServerError, err
}
return kbs, total, common.CodeSuccess, nil
}
// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID
func (s *KnowledgebaseService) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) {
return s.kbDAO.GetKBByIDAndUserID(kbID, userID)
}
// GetKBByNameAndUserID retrieves a knowledge base by name and user ID
func (s *KnowledgebaseService) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) {
return s.kbDAO.GetKBByNameAndUserID(kbName, userID)
}
// AtomicIncreaseDocNumByID atomically increments the document count
func (s *KnowledgebaseService) AtomicIncreaseDocNumByID(kbID string) error {
return s.kbDAO.AtomicIncreaseDocNumByID(kbID)
}
// DecreaseDocumentNum decreases document, chunk, and token counts
func (s *KnowledgebaseService) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error {
return s.kbDAO.DecreaseDocumentNum(kbID, docNum, chunkNum, tokenNum)
}
// UpdateParserConfig updates the parser configuration
func (s *KnowledgebaseService) UpdateParserConfig(id string, config map[string]interface{}) error {
return s.kbDAO.UpdateParserConfig(id, config)
}
// DeleteFieldMap removes the field_map from parser_config
func (s *KnowledgebaseService) DeleteFieldMap(id string) error {
return s.kbDAO.DeleteFieldMap(id)
}
// GetFieldMap retrieves field mappings from multiple knowledge bases
func (s *KnowledgebaseService) GetFieldMap(ids []string) (map[string]interface{}, error) {
return s.kbDAO.GetFieldMap(ids)
}
// GetAllIDs retrieves all knowledge base IDs
func (s *KnowledgebaseService) GetAllIDs() ([]string, error) {
return s.kbDAO.GetAllIDs()
}
// ExtractAccessToken extracts access token from authorization header
func ExtractAccessToken(authorization, secretKey string) (string, error) {
return utility.ExtractAccessToken(authorization, secretKey)
}

View File

@ -17,11 +17,16 @@
package service
import (
"fmt"
"strconv"
"strings"
"ragflow/internal/dao"
"ragflow/internal/model"
)
var DB = dao.DB
// LLMService LLM service
type LLMService struct {
tenantLLMDAO *dao.TenantLLMDAO
@ -38,6 +43,7 @@ func NewLLMService() *LLMService {
// MyLLMItem represents a single LLM item in the response
type MyLLMItem struct {
ID string `json:"id"`
Type string `json:"type"`
Name string `json:"name"`
UsedToken int64 `json:"used_token"`
@ -46,67 +52,89 @@ type MyLLMItem struct {
MaxTokens int64 `json:"max_tokens,omitempty"`
}
// MyLLMResponse represents the response structure for my LLMs
type MyLLMResponse struct {
// MyLLMFactory represents the response structure for a factory in my LLMs
type MyLLMFactory struct {
Tags string `json:"tags"`
LLM []MyLLMItem `json:"llm"`
}
// GetMyLLMs get my LLMs for a tenant
func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMResponse, error) {
// Get LLM list from database
myLLMs, err := s.tenantLLMDAO.GetMyLLMs(tenantID, includeDetails)
if err != nil {
return nil, err
}
func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMFactory, error) {
result := make(map[string]MyLLMFactory)
// Group by factory
result := make(map[string]MyLLMResponse)
providerDAO := dao.NewModelProviderDAO()
for _, llm := range myLLMs {
// Get or create factory entry
resp, exists := result[llm.LLMFactory]
if !exists {
resp = MyLLMResponse{
Tags: llm.Tags,
LLM: []MyLLMItem{},
if includeDetails {
objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID)
if err != nil {
return nil, err
}
factoryDAO := dao.NewLLMFactoryDAO()
factories, err := factoryDAO.GetAllValid()
if err != nil {
return nil, err
}
factoryTagsMap := make(map[string]string)
for _, f := range factories {
if f.Tags != "" {
factoryTagsMap[f.Name] = f.Tags
}
}
// Create LLM item
item := MyLLMItem{
Type: llm.ModelType,
Name: llm.LLMName,
UsedToken: llm.UsedTokens,
Status: llm.Status,
}
// Add detailed fields if requested
if includeDetails {
item.APIBase = llm.APIBase
item.MaxTokens = llm.MaxTokens
// If APIBase is empty, try to get from model provider configuration
if item.APIBase == "" {
provider := providerDAO.GetProviderByName(llm.LLMFactory)
if provider != nil {
// Determine appropriate API base URL based on model type
switch llm.ModelType {
case "embedding":
if provider.DefaultEmbeddingURL != "" {
item.APIBase = provider.DefaultEmbeddingURL
}
// Add other model types here if needed
// case "chat":
// case "rerank":
// etc.
}
for _, o := range objs {
llmFactory := o.LLMFactory
if _, exists := result[llmFactory]; !exists {
tags := factoryTagsMap[llmFactory]
result[llmFactory] = MyLLMFactory{
Tags: tags,
LLM: []MyLLMItem{},
}
}
item := MyLLMItem{
ID: int64ToString(o.ID),
Type: getStringValue(o.ModelType),
Name: getStringValue(o.LLMName),
UsedToken: o.UsedTokens,
Status: getValidStatus(o.Status),
}
if includeDetails {
item.APIBase = getStringValueDefault(o.APIBase, "")
item.MaxTokens = o.MaxTokens
}
factory := result[llmFactory]
factory.LLM = append(factory.LLM, item)
result[llmFactory] = factory
}
} else {
objs, err := s.tenantLLMDAO.GetMyLLMs(tenantID)
if err != nil {
return nil, err
}
resp.LLM = append(resp.LLM, item)
result[llm.LLMFactory] = resp
for _, o := range objs {
llmFactory := o.LLMFactory
if _, exists := result[llmFactory]; !exists {
result[llmFactory] = MyLLMFactory{
Tags: getStringValue(o.Tags),
LLM: []MyLLMItem{},
}
}
item := MyLLMItem{
ID: o.ID,
Type: getStringValue(o.ModelType),
Name: getStringValue(o.LLMName),
UsedToken: getInt64Value(o.UsedTokens),
Status: getStringValueDefault(o.Status, "1"),
}
factory := result[llmFactory]
factory.LLM = append(factory.LLM, item)
result[llmFactory] = factory
}
}
return result, nil
@ -114,6 +142,7 @@ func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string
// LLMListItem represents a single LLM item in the list response
type LLMListItem struct {
ID string `json:"id"`
LLMName string `json:"llm_name"`
ModelType string `json:"model_type"`
FID string `json:"fid"`
@ -142,37 +171,32 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
"GPUStack": true,
}
// Get tenant LLMs
tenantLLMs, err := s.tenantLLMDAO.ListAllByTenant(tenantID)
objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID)
if err != nil {
return nil, err
}
// Build set of factories with valid API keys
facts := make(map[string]bool)
// Build set of valid LLM names@factories
status := make(map[string]bool)
for _, tl := range tenantLLMs {
if tl.APIKey != nil && *tl.APIKey != "" && tl.Status == "1" {
facts[tl.LLMFactory] = true
tenantLLMMapping := make(map[string]string)
for _, o := range objs {
if o.APIKey != nil && *o.APIKey != "" && getValidStatus(o.Status) == "1" {
facts[o.LLMFactory] = true
}
llmName := ""
if tl.LLMName != nil {
llmName = *tl.LLMName
}
key := llmName + "@" + tl.LLMFactory
if tl.Status == "1" {
llmName := getStringValue(o.LLMName)
key := llmName + "@" + o.LLMFactory
if getValidStatus(o.Status) == "1" {
status[key] = true
}
tenantLLMMapping[key] = int64ToString(o.ID)
}
// Get all valid LLMs
allLLMs, err := s.llmDAO.GetAllValid()
if err != nil {
return nil, err
}
// Filter and build result
llmSet := make(map[string]bool)
result := make(ListLLMsResponse)
@ -183,20 +207,18 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
key := llm.LLMName + "@" + llm.FID
// Check if valid (Builtin factory or in status set)
if llm.FID != "Builtin" && !status[key] {
continue
}
// Filter by model type if specified
if modelType != "" && !strings.Contains(llm.ModelType, modelType) {
continue
}
// Determine availability
available := facts[llm.FID] || selfDeployed[llm.FID] || llm.LLMName == "flag-embedding"
available := facts[llm.FID] || selfDeployed[llm.FID] || strings.ToLower(llm.LLMName) == "flag-embedding"
item := LLMListItem{
ID: tenantLLMMapping[key],
LLMName: llm.LLMName,
ModelType: llm.ModelType,
FID: llm.FID,
@ -207,7 +229,6 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
Tags: llm.Tags,
}
// Add BaseModel fields
if llm.CreateDate != nil {
createDateStr := llm.CreateDate.Format("2006-01-02T15:04:05")
item.CreateDate = &createDateStr
@ -225,36 +246,160 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
llmSet[key] = true
}
// Add tenant LLMs that are not in the global list
for _, tl := range tenantLLMs {
llmName := ""
if tl.LLMName != nil {
llmName = *tl.LLMName
}
key := llmName + "@" + tl.LLMFactory
for _, o := range objs {
llmName := getStringValue(o.LLMName)
key := llmName + "@" + o.LLMFactory
if llmSet[key] {
continue
}
// Filter by model type if specified
modelTypeValue := ""
if tl.ModelType != nil {
modelTypeValue = *tl.ModelType
}
modelTypeValue := getStringValue(o.ModelType)
if modelType != "" && !strings.Contains(modelTypeValue, modelType) {
continue
}
item := LLMListItem{
ID: int64ToString(o.ID),
LLMName: llmName,
ModelType: modelTypeValue,
FID: tl.LLMFactory,
FID: o.LLMFactory,
Available: true,
Status: tl.Status,
Status: getValidStatus(o.Status),
}
result[tl.LLMFactory] = append(result[tl.LLMFactory], item)
result[o.LLMFactory] = append(result[o.LLMFactory], item)
}
return result, nil
}
func getStringValue(s *string) string {
if s == nil {
return ""
}
return *s
}
func getStringValueDefault(s *string, defaultVal string) string {
if s == nil || *s == "" {
return defaultVal
}
return *s
}
func getValidStatus(status string) string {
if status == "" {
return "1"
}
return status
}
func getInt64Value(i *int64) int64 {
if i == nil {
return 0
}
return *i
}
func getInt64ValueDefault(i *int64, defaultVal int64) int64 {
if i == nil || *i == 0 {
return defaultVal
}
return *i
}
func getBoolValue(b *bool) bool {
if b == nil {
return false
}
return *b
}
func int64ToString(n int64) string {
return strconv.FormatInt(n, 10)
}
// SetAPIKeyRequest represents the request for setting API key
type SetAPIKeyRequest struct {
LLMFactory string `json:"llm_factory"`
APIKey string `json:"api_key"`
BaseURL string `json:"base_url"`
SourceFID string `json:"source_fid"`
ModelType string `json:"model_type"`
LLMName string `json:"llm_name"`
Verify bool `json:"verify"`
MaxTokens int64 `json:"max_tokens"`
}
// SetAPIKeyResult represents the result of setting API key
type SetAPIKeyResult struct {
Message string `json:"message"`
Success bool `json:"success"`
}
// SetAPIKey sets API key for a LLM factory
func (s *LLMService) SetAPIKey(tenantID string, req *SetAPIKeyRequest) (*SetAPIKeyResult, error) {
factory := req.LLMFactory
baseURL := req.BaseURL
sourceFactory := req.SourceFID
if sourceFactory == "" {
sourceFactory = factory
}
sourceLLMs, err := s.llmDAO.GetByFactory(sourceFactory)
if err != nil || len(sourceLLMs) == 0 {
msg := "No models configured for " + factory + " (source: " + sourceFactory + ")."
if req.Verify {
return &SetAPIKeyResult{Message: msg, Success: false}, nil
}
return nil, fmt.Errorf(msg)
}
llmConfig := map[string]interface{}{
"api_key": req.APIKey,
"api_base": baseURL,
}
if req.ModelType != "" {
llmConfig["model_type"] = req.ModelType
}
if req.LLMName != "" {
llmConfig["llm_name"] = req.LLMName
}
for _, llm := range sourceLLMs {
maxTokens := llm.MaxTokens
if maxTokens == 0 {
maxTokens = 8192
}
llmConfig["max_tokens"] = maxTokens
existingLLM, _ := s.tenantLLMDAO.GetByTenantFactoryAndModelName(tenantID, factory, llm.LLMName)
if existingLLM != nil {
updates := map[string]interface{}{
"api_key": req.APIKey,
"api_base": baseURL,
"max_tokens": maxTokens,
}
DB.Model(&model.TenantLLM{}).
Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, llm.LLMName).
Updates(updates)
} else {
modelType := llm.ModelType
llmName := llm.LLMName
tenantLLM := &model.TenantLLM{
TenantID: tenantID,
LLMFactory: factory,
ModelType: &modelType,
LLMName: &llmName,
APIKey: &req.APIKey,
APIBase: &baseURL,
MaxTokens: maxTokens,
Status: "1",
}
s.tenantLLMDAO.Create(tenantLLM)
}
}
return &SetAPIKeyResult{Message: "", Success: true}, nil
}

View File

@ -151,15 +151,38 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
user.LastLoginTime = &now_date
tenantName := req.Nickname + "'s Kingdom"
llmID := cfg.UserDefaultLLM.DefaultModels.ChatModel.Name
if llmID == "" {
llmID = ""
}
embdID := cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name
if embdID == "" {
embdID = ""
}
asrID := cfg.UserDefaultLLM.DefaultModels.ASRModel.Name
if asrID == "" {
asrID = ""
}
img2txtID := cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name
if img2txtID == "" {
img2txtID = ""
}
rerankID := cfg.UserDefaultLLM.DefaultModels.RerankModel.Name
if rerankID == "" {
rerankID = ""
}
tenant := &model.Tenant{
ID: userID,
Name: &tenantName,
LLMID: cfg.Server.Mode,
EmbdID: cfg.Server.Mode,
ASRID: cfg.Server.Mode,
Img2TxtID: cfg.Server.Mode,
RerankID: cfg.Server.Mode,
LLMID: llmID,
EmbdID: embdID,
ASRID: asrID,
Img2TxtID: img2txtID,
RerankID: rerankID,
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
Status: &status,
}
tenant.CreateTime = &now
tenant.UpdateTime = &now
@ -753,3 +776,52 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, err
return channels, common.CodeSuccess, nil
}
// SetTenantInfoRequest represents the request for setting tenant info
type SetTenantInfoRequest struct {
TenantID string `json:"tenant_id"`
ASRID string `json:"asr_id"`
EmbdID string `json:"embd_id"`
Img2TxtID string `json:"img2txt_id"`
LLMID string `json:"llm_id"`
RerankID string `json:"rerank_id"`
TTSID string `json:"tts_id"`
}
// SetTenantInfo updates tenant model configuration
func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) error {
tenantDAO := dao.NewTenantDAO()
_, err := tenantDAO.GetByID(req.TenantID)
if err != nil {
return fmt.Errorf("tenant not found: %w", err)
}
updates := make(map[string]interface{})
if req.LLMID != "" {
updates["llm_id"] = req.LLMID
}
if req.EmbdID != "" {
updates["embd_id"] = req.EmbdID
}
if req.ASRID != "" {
updates["asr_id"] = req.ASRID
}
if req.Img2TxtID != "" {
updates["img2txt_id"] = req.Img2TxtID
}
if req.RerankID != "" {
updates["rerank_id"] = req.RerankID
}
if req.TTSID != "" {
updates["tts_id"] = req.TTSID
}
if len(updates) > 0 {
if err := tenantDAO.Update(req.TenantID, updates); err != nil {
return fmt.Errorf("failed to update tenant: %w", err)
}
}
return nil
}