mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-04-20 02:37:26 +08:00
feat: Added LLM factory initialization functionality and knowledge base related API interfaces (#13472)
### What problem does this PR solve? feat: Added LLM factory initialization functionality and knowledge base related API interfaces refactor(dao): Refactored the TenantLLMDAO query method feat(handler): Implemented knowledge base related API endpoints feat(service): Added LLM API key setting functionality feat(model): Extended the knowledge base model definition feat(config): Added default user LLM configuration ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
10
.gitignore
vendored
10
.gitignore
vendored
@ -220,7 +220,13 @@ uv-aarch64*.tar.gz
|
||||
uv-aarch64-unknown-linux-gnu.tar.gz
|
||||
docker/launch_backend_service_windows.sh
|
||||
|
||||
# C++ build directories
|
||||
internal/cpp/build/
|
||||
internal/cpp/cmake-build-release/
|
||||
internal/cpp/cmake-build-debug/
|
||||
|
||||
# Trae IDE config
|
||||
.trae/
|
||||
|
||||
# Go server build output
|
||||
bin/
|
||||
internal/cpp/cmake-build-release/
|
||||
internal/cpp/cmake-build-debug/
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"ragflow/internal/init_data"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/utility"
|
||||
"strings"
|
||||
@ -71,6 +72,13 @@ func main() {
|
||||
logger.Fatal("Failed to initialize database", zap.Error(err))
|
||||
}
|
||||
|
||||
// Initialize LLM factory data models from configuration file
|
||||
if err := init_data.InitLLMFactory(); err != nil {
|
||||
logger.Error("Failed to initialize LLM factory", err)
|
||||
} else {
|
||||
logger.Info("LLM factory initialized successfully")
|
||||
}
|
||||
|
||||
// Initialize doc engine
|
||||
if err := engine.Init(&cfg.DocEngine); err != nil {
|
||||
logger.Fatal("Failed to initialize doc engine", zap.Error(err))
|
||||
|
||||
@ -62,8 +62,13 @@ func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pag
|
||||
user.nickname,
|
||||
user.avatar as tenant_avatar
|
||||
`).
|
||||
Joins("LEFT JOIN user ON dialog.tenant_id = user.id").
|
||||
Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1")
|
||||
Joins("LEFT JOIN user ON dialog.tenant_id = user.id")
|
||||
|
||||
if len(tenantIDs) > 0 {
|
||||
query = query.Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1")
|
||||
} else {
|
||||
query = query.Where("dialog.tenant_id = ? AND dialog.status = ?", userID, "1")
|
||||
}
|
||||
|
||||
// Apply keyword filter
|
||||
if keywords != "" {
|
||||
|
||||
@ -19,6 +19,7 @@ package dao
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KnowledgebaseDAO knowledge base data access object
|
||||
@ -29,15 +30,133 @@ func NewKnowledgebaseDAO() *KnowledgebaseDAO {
|
||||
return &KnowledgebaseDAO{}
|
||||
}
|
||||
|
||||
// ListByTenantIDs list knowledge bases by tenant IDs
|
||||
func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) {
|
||||
// Create creates a new knowledge base record
|
||||
func (dao *KnowledgebaseDAO) Create(kb *model.Knowledgebase) error {
|
||||
return DB.Create(kb).Error
|
||||
}
|
||||
|
||||
// Update updates a knowledge base record
|
||||
func (dao *KnowledgebaseDAO) Update(kb *model.Knowledgebase) error {
|
||||
return DB.Save(kb).Error
|
||||
}
|
||||
|
||||
// UpdateByID updates a knowledge base by ID with the given fields
|
||||
func (dao *KnowledgebaseDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Delete soft deletes a knowledge base by setting status to invalid
|
||||
func (dao *KnowledgebaseDAO) Delete(id string) error {
|
||||
return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Update("status", string(model.StatusInvalid)).Error
|
||||
}
|
||||
|
||||
// GetByID retrieves a knowledge base by ID
|
||||
func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kb, nil
|
||||
}
|
||||
|
||||
// GetByIDAndTenantID retrieves a knowledge base by ID and tenant ID
|
||||
func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, string(model.StatusValid)).First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kb, nil
|
||||
}
|
||||
|
||||
// GetByIDs retrieves multiple knowledge bases by IDs
|
||||
func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Where("id IN ? AND status = ?", ids, string(model.StatusValid)).Find(&kbs).Error
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// GetByName retrieves a knowledge base by name and tenant ID
|
||||
func (dao *KnowledgebaseDAO) GetByName(name, tenantID string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, string(model.StatusValid)).First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kb, nil
|
||||
}
|
||||
|
||||
// GetByCreatedBy retrieves knowledge bases created by a specific user
|
||||
func (dao *KnowledgebaseDAO) GetByCreatedBy(createdBy string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Where("created_by = ? AND status = ?", createdBy, string(model.StatusValid)).Find(&kbs).Error
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// Query retrieves knowledge bases with filters
|
||||
func (dao *KnowledgebaseDAO) Query(filters map[string]interface{}) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
query := DB.Where("status = ?", string(model.StatusValid))
|
||||
|
||||
for key, value := range filters {
|
||||
if value != nil && value != "" {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
err := query.Find(&kbs).Error
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// QueryOne retrieves a single knowledge base with filters
|
||||
func (dao *KnowledgebaseDAO) QueryOne(filters map[string]interface{}) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
query := DB.Where("status = ?", string(model.StatusValid))
|
||||
|
||||
for key, value := range filters {
|
||||
if value != nil && value != "" {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
err := query.First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kb, nil
|
||||
}
|
||||
|
||||
// Count returns the count of knowledge bases matching the filters
|
||||
func (dao *KnowledgebaseDAO) Count(filters map[string]interface{}) (int64, error) {
|
||||
var count int64
|
||||
query := DB.Model(&model.Knowledgebase{}).Where("status = ?", string(model.StatusValid))
|
||||
|
||||
for key, value := range filters {
|
||||
if value != nil && value != "" {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetByTenantIDs retrieves knowledge bases by tenant IDs with pagination
|
||||
// This matches the Python get_by_tenant_ids method
|
||||
func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, keywords, parserID string) ([]*model.KnowledgebaseListItem, int64, error) {
|
||||
var kbs []*model.KnowledgebaseListItem
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&model.Knowledgebase{}).
|
||||
Select(`knowledgebase.id, knowledgebase.avatar, knowledgebase.name,
|
||||
knowledgebase.language, knowledgebase.description, knowledgebase.tenant_id,
|
||||
knowledgebase.permission, knowledgebase.doc_num, knowledgebase.token_num,
|
||||
knowledgebase.chunk_num, knowledgebase.parser_id, knowledgebase.embd_id,
|
||||
user.nickname, user.avatar as tenant_avatar, knowledgebase.update_time`).
|
||||
Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id").
|
||||
Where("(knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?", tenantIDs, "team", userID).
|
||||
Where("knowledgebase.status = ?", "1")
|
||||
Where("((knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?) AND knowledgebase.status = ?",
|
||||
tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid))
|
||||
|
||||
if keywords != "" {
|
||||
query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%")
|
||||
@ -47,22 +166,287 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string,
|
||||
query = query.Where("knowledgebase.parser_id = ?", parserID)
|
||||
}
|
||||
|
||||
// Order
|
||||
if desc {
|
||||
query = query.Order("knowledgebase." + orderby + " DESC")
|
||||
} else {
|
||||
query = query.Order("knowledgebase." + orderby + " ASC")
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if pageNumber > 0 && itemsPerPage > 0 {
|
||||
offset := (pageNumber - 1) * itemsPerPage
|
||||
if err := query.Offset(offset).Limit(itemsPerPage).Scan(&kbs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
} else {
|
||||
if err := query.Scan(&kbs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return kbs, total, nil
|
||||
}
|
||||
|
||||
// GetAllByTenantIDs retrieves all permitted knowledge bases by tenant IDs
|
||||
// This matches the Python get_all_kb_by_tenant_ids method
|
||||
func (dao *KnowledgebaseDAO) GetAllByTenantIDs(tenantIDs []string, userID string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
|
||||
err := DB.Where(
|
||||
"(tenant_id IN ? AND permission = ?) OR tenant_id = ?",
|
||||
tenantIDs, string(model.TenantPermissionTeam), userID,
|
||||
).Order("create_time ASC").Find(&kbs).Error
|
||||
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// GetDetail retrieves detailed knowledge base information with joined pipeline data
|
||||
// This matches the Python get_detail method
|
||||
func (dao *KnowledgebaseDAO) GetDetail(kbID string) (*model.KnowledgebaseDetail, error) {
|
||||
var detail model.KnowledgebaseDetail
|
||||
|
||||
err := DB.Table("knowledgebase").
|
||||
Select(`knowledgebase.id, knowledgebase.embd_id, knowledgebase.avatar, knowledgebase.name,
|
||||
knowledgebase.language, knowledgebase.description, knowledgebase.permission,
|
||||
knowledgebase.doc_num, knowledgebase.token_num, knowledgebase.chunk_num,
|
||||
knowledgebase.parser_id, knowledgebase.pipeline_id,
|
||||
user_canvas.title as pipeline_name, user_canvas.avatar as pipeline_avatar,
|
||||
knowledgebase.parser_config, knowledgebase.pagerank,
|
||||
knowledgebase.graphrag_task_id, knowledgebase.graphrag_task_finish_at,
|
||||
knowledgebase.raptor_task_id, knowledgebase.raptor_task_finish_at,
|
||||
knowledgebase.mindmap_task_id, knowledgebase.mindmap_task_finish_at,
|
||||
knowledgebase.create_time, knowledgebase.update_time`).
|
||||
Joins("LEFT JOIN user_canvas ON knowledgebase.pipeline_id = user_canvas.id").
|
||||
Where("knowledgebase.id = ? AND knowledgebase.status = ?", kbID, string(model.StatusValid)).
|
||||
Scan(&detail).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &detail, nil
|
||||
}
|
||||
|
||||
// Accessible checks if a knowledge base is accessible by a user
|
||||
// This matches the Python accessible method
|
||||
func (dao *KnowledgebaseDAO) Accessible(kbID, userID string) bool {
|
||||
var count int64
|
||||
err := DB.Table("knowledgebase").
|
||||
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
|
||||
Where("knowledgebase.id = ? AND user_tenant.user_id = ? AND knowledgebase.status = ?",
|
||||
kbID, userID, string(model.StatusValid)).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// Accessible4Deletion checks if a knowledge base can be deleted by a user
|
||||
// This matches the Python accessible4deletion method
|
||||
func (dao *KnowledgebaseDAO) Accessible4Deletion(kbID, userID string) bool {
|
||||
var count int64
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Where("id = ? AND created_by = ? AND status = ?", kbID, userID, string(model.StatusValid)).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// DuplicateName generates a unique name by appending parentheses if name already exists
|
||||
// This matches the Python duplicate_name function behavior
|
||||
func (dao *KnowledgebaseDAO) DuplicateName(name, tenantID string) string {
|
||||
var existingNames []string
|
||||
DB.Model(&model.Knowledgebase{}).
|
||||
Where("name LIKE ? AND tenant_id = ? AND status = ?", name+"%", tenantID, string(model.StatusValid)).
|
||||
Pluck("name", &existingNames)
|
||||
|
||||
if len(existingNames) == 0 {
|
||||
return name
|
||||
}
|
||||
|
||||
nameSet := make(map[string]bool)
|
||||
for _, n := range existingNames {
|
||||
nameSet[strings.ToLower(n)] = true
|
||||
}
|
||||
|
||||
if !nameSet[strings.ToLower(name)] {
|
||||
return name
|
||||
}
|
||||
|
||||
for i := 1; ; i++ {
|
||||
newName := name + " " + strings.Repeat("(", i) + strings.Repeat(")", i)
|
||||
if !nameSet[strings.ToLower(newName)] {
|
||||
return newName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AtomicIncreaseDocNumByID atomically increments the document count
|
||||
// This matches the Python atomic_increase_doc_num_by_id method
|
||||
func (dao *KnowledgebaseDAO) AtomicIncreaseDocNumByID(kbID string) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now()
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
Where("id = ?", kbID).
|
||||
Updates(map[string]interface{}{
|
||||
"doc_num": DB.Raw("doc_num + 1"),
|
||||
"update_time": now,
|
||||
"update_date": nowDate,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// DecreaseDocumentNum decreases document, chunk, and token counts
|
||||
// This matches the Python decrease_document_num_in_delete method
|
||||
func (dao *KnowledgebaseDAO) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now()
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
Where("id = ?", kbID).
|
||||
Updates(map[string]interface{}{
|
||||
"doc_num": DB.Raw("doc_num - ?", docNum),
|
||||
"chunk_num": DB.Raw("chunk_num - ?", chunkNum),
|
||||
"token_num": DB.Raw("token_num - ?", tokenNum),
|
||||
"update_time": now,
|
||||
"update_date": nowDate,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant
|
||||
// This matches the Python get_kb_ids method
|
||||
func (dao *KnowledgebaseDAO) GetKBIDsByTenantID(tenantID string) ([]string, error) {
|
||||
var kbIDs []string
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Where("tenant_id = ? AND status = ?", tenantID, string(model.StatusValid)).
|
||||
Pluck("id", &kbIDs).Error
|
||||
return kbIDs, err
|
||||
}
|
||||
|
||||
// GetAllIDs retrieves all knowledge base IDs
|
||||
// This matches the Python get_all_ids method
|
||||
func (dao *KnowledgebaseDAO) GetAllIDs() ([]string, error) {
|
||||
var kbIDs []string
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Where("status = ?", string(model.StatusValid)).
|
||||
Pluck("id", &kbIDs).Error
|
||||
return kbIDs, err
|
||||
}
|
||||
|
||||
// UpdateParserConfig updates the parser configuration with deep merge
|
||||
// This matches the Python update_parser_config method
|
||||
func (dao *KnowledgebaseDAO) UpdateParserConfig(id string, config map[string]interface{}) error {
|
||||
var kb model.Knowledgebase
|
||||
if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mergedConfig := mergeConfig(kb.ParserConfig, config)
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
Where("id = ?", id).
|
||||
Update("parser_config", mergedConfig).Error
|
||||
}
|
||||
|
||||
// DeleteFieldMap removes the field_map from parser_config
|
||||
// This matches the Python delete_field_map method
|
||||
func (dao *KnowledgebaseDAO) DeleteFieldMap(id string) error {
|
||||
var kb model.Knowledgebase
|
||||
if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if kb.ParserConfig != nil {
|
||||
delete(kb.ParserConfig, "field_map")
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
Where("id = ?", id).
|
||||
Update("parser_config", kb.ParserConfig).Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFieldMap retrieves field mappings from multiple knowledge bases
|
||||
// This matches the Python get_field_map method
|
||||
func (dao *KnowledgebaseDAO) GetFieldMap(ids []string) (map[string]interface{}, error) {
|
||||
conf := make(map[string]interface{})
|
||||
kbs, err := dao.GetByIDs(ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, kb := range kbs {
|
||||
if kb.ParserConfig != nil {
|
||||
if fieldMap, ok := kb.ParserConfig["field_map"]; ok {
|
||||
if fm, ok := fieldMap.(map[string]interface{}); ok {
|
||||
for k, v := range fm {
|
||||
conf[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID with tenant join
|
||||
// This matches the Python get_kb_by_id method
|
||||
func (dao *KnowledgebaseDAO) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
|
||||
Where("knowledgebase.id = ? AND user_tenant.user_id = ?", kbID, userID).
|
||||
Limit(1).
|
||||
Find(&kbs).Error
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// GetKBByNameAndUserID retrieves a knowledge base by name and user ID with tenant join
|
||||
// This matches the Python get_kb_by_name method
|
||||
func (dao *KnowledgebaseDAO) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
|
||||
Where("knowledgebase.name = ? AND user_tenant.user_id = ?", kbName, userID).
|
||||
Limit(1).
|
||||
Find(&kbs).Error
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// GetList retrieves knowledge bases with filtering by ID and name
|
||||
// This matches the Python get_list method
|
||||
func (dao *KnowledgebaseDAO) GetList(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&model.Knowledgebase{}).
|
||||
Where("((tenant_id IN ? AND permission = ?) OR tenant_id = ?) AND status = ?",
|
||||
tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid))
|
||||
|
||||
if id != "" {
|
||||
query = query.Where("id = ?", id)
|
||||
}
|
||||
if name != "" {
|
||||
query = query.Where("name = ?", name)
|
||||
}
|
||||
|
||||
if desc {
|
||||
query = query.Order(orderby + " DESC")
|
||||
} else {
|
||||
query = query.Order(orderby + " ASC")
|
||||
}
|
||||
|
||||
// Count
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Pagination
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Offset(offset).Limit(pageSize).Find(&kbs).Error; err != nil {
|
||||
if pageNumber > 0 && itemsPerPage > 0 {
|
||||
offset := (pageNumber - 1) * itemsPerPage
|
||||
if err := query.Offset(offset).Limit(itemsPerPage).Find(&kbs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
} else {
|
||||
@ -74,76 +458,39 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string,
|
||||
return kbs, total, nil
|
||||
}
|
||||
|
||||
// ListByOwnerIDs list knowledge bases by owner IDs
|
||||
func (dao *KnowledgebaseDAO) ListByOwnerIDs(ownerIDs []string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
|
||||
query := DB.Model(&model.Knowledgebase{}).
|
||||
Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id").
|
||||
Where("knowledgebase.tenant_id IN ?", ownerIDs).
|
||||
Where("knowledgebase.status = ?", "1")
|
||||
|
||||
if keywords != "" {
|
||||
query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%")
|
||||
// mergeConfig performs a deep merge of configuration maps
|
||||
func mergeConfig(old, new map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range old {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
if parserID != "" {
|
||||
query = query.Where("knowledgebase.parser_id = ?", parserID)
|
||||
}
|
||||
|
||||
// Order
|
||||
if desc {
|
||||
query = query.Order(orderby + " DESC")
|
||||
} else {
|
||||
query = query.Order(orderby + " ASC")
|
||||
}
|
||||
|
||||
if err := query.Find(&kbs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
total := int64(len(kbs))
|
||||
|
||||
// Manual pagination
|
||||
if page > 0 && pageSize > 0 {
|
||||
start := (page - 1) * pageSize
|
||||
end := start + pageSize
|
||||
if end > int(total) {
|
||||
end = int(total)
|
||||
}
|
||||
if start < end {
|
||||
kbs = kbs[start:end]
|
||||
} else {
|
||||
kbs = []*model.Knowledgebase{}
|
||||
for k, v := range new {
|
||||
if existing, ok := result[k]; ok {
|
||||
if existingMap, ok := existing.(map[string]interface{}); ok {
|
||||
if newMap, ok := v.(map[string]interface{}); ok {
|
||||
result[k] = mergeConfig(existingMap, newMap)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if existingSlice, ok := existing.([]interface{}); ok {
|
||||
if newSlice, ok := v.([]interface{}); ok {
|
||||
merged := append(existingSlice, newSlice...)
|
||||
seen := make(map[interface{}]bool)
|
||||
unique := make([]interface{}, 0)
|
||||
for _, item := range merged {
|
||||
if !seen[item] {
|
||||
seen[item] = true
|
||||
unique = append(unique, item)
|
||||
}
|
||||
}
|
||||
result[k] = unique
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
return kbs, total, nil
|
||||
}
|
||||
|
||||
// GetByID gets knowledge base by ID
|
||||
func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("id = ? AND status = ?", id, "1").First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kb, nil
|
||||
}
|
||||
|
||||
// GetByIDAndTenantID gets knowledge base by ID and tenant ID
|
||||
func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, "1").First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kb, nil
|
||||
}
|
||||
|
||||
// GetByIDs gets knowledge bases by IDs
|
||||
func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Where("id IN ? AND status = ?", ids, "1").Find(&kbs).Error
|
||||
return kbs, err
|
||||
return result
|
||||
}
|
||||
|
||||
@ -67,3 +67,31 @@ func (dao *LLMDAO) GetByFactoryAndName(factory, name string) (*model.LLM, error)
|
||||
}
|
||||
return &llm, nil
|
||||
}
|
||||
|
||||
// LLMFactoryDAO LLM factory data access object
|
||||
type LLMFactoryDAO struct{}
|
||||
|
||||
// NewLLMFactoryDAO create LLM factory DAO
|
||||
func NewLLMFactoryDAO() *LLMFactoryDAO {
|
||||
return &LLMFactoryDAO{}
|
||||
}
|
||||
|
||||
// GetAllValid gets all valid LLM factories
|
||||
func (dao *LLMFactoryDAO) GetAllValid() ([]*model.LLMFactories, error) {
|
||||
var factories []*model.LLMFactories
|
||||
err := DB.Where("status = ?", "1").Find(&factories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return factories, nil
|
||||
}
|
||||
|
||||
// GetByName gets LLM factory by name
|
||||
func (dao *LLMFactoryDAO) GetByName(name string) (*model.LLMFactories, error) {
|
||||
var factory model.LLMFactories
|
||||
err := DB.Where("name = ?", name).First(&factory).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &factory, nil
|
||||
}
|
||||
|
||||
@ -75,7 +75,7 @@ func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) {
|
||||
Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id").
|
||||
Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1").
|
||||
Scan(&results).Error
|
||||
|
||||
|
||||
return results, err
|
||||
}
|
||||
|
||||
@ -98,3 +98,8 @@ func (dao *TenantDAO) Create(tenant *model.Tenant) error {
|
||||
func (dao *TenantDAO) Delete(id string) error {
|
||||
return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error
|
||||
}
|
||||
|
||||
// Update updates a tenant by ID
|
||||
func (dao *TenantDAO) Update(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&model.Tenant{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
@ -94,21 +94,14 @@ func (dao *TenantLLMDAO) Delete(tenantID, factory, modelName string) error {
|
||||
}
|
||||
|
||||
// GetMyLLMs get tenant LLMs with factory details
|
||||
func (dao *TenantLLMDAO) GetMyLLMs(tenantID string, includeDetails bool) ([]model.MyLLM, error) {
|
||||
func (dao *TenantLLMDAO) GetMyLLMs(tenantID string) ([]model.MyLLM, error) {
|
||||
var myLLMs []model.MyLLM
|
||||
|
||||
// Base query
|
||||
query := DB.Table("tenant_llm tl").
|
||||
Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status").
|
||||
err := DB.Table("tenant_llm tl").
|
||||
Select("tl.id, tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status").
|
||||
Joins("JOIN llm_factories lf ON tl.llm_factory = lf.name").
|
||||
Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID)
|
||||
|
||||
// Add detailed fields if requested
|
||||
if includeDetails {
|
||||
query = query.Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status, tl.api_base, tl.max_tokens")
|
||||
}
|
||||
|
||||
err := query.Find(&myLLMs).Error
|
||||
Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID).
|
||||
Find(&myLLMs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -18,20 +18,21 @@ package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
// KnowledgebaseHandler knowledge base handler
|
||||
// KnowledgebaseHandler handles knowledge base HTTP requests
|
||||
type KnowledgebaseHandler struct {
|
||||
kbService *service.KnowledgebaseService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewKnowledgebaseHandler create knowledge base handler
|
||||
// NewKnowledgebaseHandler creates a new knowledge base handler
|
||||
func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userService *service.UserService) *KnowledgebaseHandler {
|
||||
return &KnowledgebaseHandler{
|
||||
kbService: kbService,
|
||||
@ -39,35 +40,227 @@ func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userServic
|
||||
}
|
||||
}
|
||||
|
||||
// ListKbs list knowledge bases
|
||||
// @Summary List Knowledge Bases
|
||||
// @Description Get list of knowledge bases with filtering and pagination
|
||||
// getUserID extracts user ID from authorization header
|
||||
// It validates the authorization token and returns the user ID
|
||||
// Parameters:
|
||||
// - c: gin.Context - the HTTP request context
|
||||
//
|
||||
// Returns:
|
||||
// - string: the user ID
|
||||
// - common.ErrorCode: the error code
|
||||
// - error: any error that occurred
|
||||
func (h *KnowledgebaseHandler) getUserID(c *gin.Context) (string, common.ErrorCode, error) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
return "", common.CodeUnauthorized, ErrMissingAuth
|
||||
}
|
||||
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
return "", code, err
|
||||
}
|
||||
|
||||
return user.ID, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// jsonResponse sends a JSON response with code and message
|
||||
func jsonResponse(c *gin.Context, code common.ErrorCode, data interface{}, message string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": data,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// jsonError sends a JSON error response
|
||||
func jsonError(c *gin.Context, code common.ErrorCode, message string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": nil,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// HTTPError represents an HTTP error
|
||||
type HTTPError struct {
|
||||
Code common.ErrorCode
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *HTTPError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrMissingAuth indicates missing authorization header
|
||||
ErrMissingAuth = &HTTPError{Code: common.CodeUnauthorized, Message: "Missing Authorization header"}
|
||||
// ErrInvalidToken indicates invalid access token
|
||||
ErrInvalidToken = &HTTPError{Code: common.CodeUnauthorized, Message: "Invalid access token"}
|
||||
)
|
||||
|
||||
// CreateKB handles the create knowledge base request
|
||||
// @Summary Create Knowledge Base
|
||||
// @Description Create a new knowledge base (dataset)
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param keywords query string false "search keywords"
|
||||
// @Param page query int false "page number"
|
||||
// @Param page_size query int false "items per page"
|
||||
// @Param parser_id query string false "parser ID filter"
|
||||
// @Param orderby query string false "order by field"
|
||||
// @Param desc query bool false "descending order"
|
||||
// @Param request body service.ListKbsRequest true "filter options"
|
||||
// @Success 200 {object} service.ListKbsResponse
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body service.CreateKBRequest true "knowledge base info"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/create [post]
|
||||
func (h *KnowledgebaseHandler) CreateKB(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req service.CreateKBRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.kbService.CreateKB(&req, userID)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// UpdateKB handles the update knowledge base request
|
||||
// @Summary Update Knowledge Base
|
||||
// @Description Update an existing knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body service.UpdateKBRequest true "knowledge base update info"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/update [post]
|
||||
func (h *KnowledgebaseHandler) UpdateKB(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req service.UpdateKBRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.kbService.UpdateKB(&req, userID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "authorization") {
|
||||
jsonError(c, common.CodeAuthenticationError, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// UpdateMetadataSetting handles the update metadata setting request
|
||||
// @Summary Update Metadata Setting
|
||||
// @Description Update metadata settings for a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body service.UpdateMetadataSettingRequest true "metadata setting info"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/update_metadata_setting [post]
|
||||
func (h *KnowledgebaseHandler) UpdateMetadataSetting(c *gin.Context) {
|
||||
_, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req service.UpdateMetadataSettingRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.kbService.UpdateMetadataSetting(&req)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// GetDetail handles the get knowledge base detail request
|
||||
// @Summary Get Knowledge Base Detail
|
||||
// @Description Get detailed information about a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_id query string true "Knowledge Base ID"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/detail [get]
|
||||
func (h *KnowledgebaseHandler) GetDetail(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbID := c.Query("kb_id")
|
||||
if kbID == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.kbService.GetDetail(kbID, userID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "authorized") {
|
||||
jsonError(c, common.CodeOperatingError, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// ListKbs handles the list knowledge bases request
|
||||
// @Summary List Knowledge Bases
|
||||
// @Description List knowledge bases with pagination and filtering
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body service.ListKbsRequest true "list options"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/list [post]
|
||||
func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) {
|
||||
// Parse request body - allow empty body
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req service.ListKbsRequest
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
})
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Extract parameters from query or request body with defaults
|
||||
// Get parameters from request or query string
|
||||
keywords := ""
|
||||
if req.Keywords != nil {
|
||||
keywords = *req.Keywords
|
||||
@ -111,7 +304,7 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) {
|
||||
if req.Desc != nil {
|
||||
desc = *req.Desc
|
||||
} else if descStr := c.Query("desc"); descStr != "" {
|
||||
desc = descStr == "true"
|
||||
desc = strings.ToLower(descStr) == "true"
|
||||
}
|
||||
|
||||
var ownerIDs []string
|
||||
@ -119,40 +312,327 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) {
|
||||
ownerIDs = *req.OwnerIDs
|
||||
}
|
||||
|
||||
// Get access token from Authorization header
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Missing Authorization header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
result, code, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// List knowledge bases
|
||||
result, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": err.Error(),
|
||||
})
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": result,
|
||||
"message": "success",
|
||||
})
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// DeleteKB handles the delete knowledge base request
|
||||
// @Summary Delete Knowledge Base
|
||||
// @Description Soft delete a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body object{kb_id string} true "knowledge base id"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/rm [post]
|
||||
func (h *KnowledgebaseHandler) DeleteKB(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
KBID string `json:"kb_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
code, err = h.kbService.DeleteKB(req.KBID, userID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "authorization") {
|
||||
jsonError(c, common.CodeAuthenticationError, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, true, "success")
|
||||
}
|
||||
|
||||
// ListTags handles the list tags request for a knowledge base
|
||||
// @Summary List Tags
|
||||
// @Description List tags for a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_id path string true "Knowledge Base ID"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/{kb_id}/tags [get]
|
||||
func (h *KnowledgebaseHandler) ListTags(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbID := c.Param("kb_id")
|
||||
if kbID == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, []string{}, "success")
|
||||
}
|
||||
|
||||
// ListTagsFromKbs handles the list tags from multiple knowledge bases request
|
||||
// @Summary List Tags from Knowledge Bases
|
||||
// @Description List tags from multiple knowledge bases
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_ids query string true "Comma-separated Knowledge Base IDs"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/tags [get]
|
||||
func (h *KnowledgebaseHandler) ListTagsFromKbs(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbIDsStr := c.Query("kb_ids")
|
||||
if kbIDsStr == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
kbIDs := strings.Split(kbIDsStr, ",")
|
||||
for _, kbID := range kbIDs {
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, []string{}, "success")
|
||||
}
|
||||
|
||||
// RemoveTags handles the remove tags request
|
||||
// @Summary Remove Tags
|
||||
// @Description Remove tags from a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_id path string true "Knowledge Base ID"
|
||||
// @Param request body object{tags []string} true "tags to remove"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/{kb_id}/rm_tags [post]
|
||||
func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbID := c.Param("kb_id")
|
||||
if kbID == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Tags []string `json:"tags" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, true, "success")
|
||||
}
|
||||
|
||||
// RenameTag handles the rename tag request
|
||||
// @Summary Rename Tag
|
||||
// @Description Rename a tag in a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_id path string true "Knowledge Base ID"
|
||||
// @Param request body object{from_tag string, to_tag string} true "tag rename info"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/{kb_id}/rename_tag [post]
|
||||
func (h *KnowledgebaseHandler) RenameTag(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbID := c.Param("kb_id")
|
||||
if kbID == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
FromTag string `json:"from_tag" binding:"required"`
|
||||
ToTag string `json:"to_tag" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, true, "success")
|
||||
}
|
||||
|
||||
// KnowledgeGraph handles the get knowledge graph request
|
||||
// @Summary Get Knowledge Graph
|
||||
// @Description Get knowledge graph for a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_id path string true "Knowledge Base ID"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/{kb_id}/knowledge_graph [get]
|
||||
func (h *KnowledgebaseHandler) KnowledgeGraph(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbID := c.Param("kb_id")
|
||||
if kbID == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"graph": map[string]interface{}{},
|
||||
"mind_map": map[string]interface{}{},
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// DeleteKnowledgeGraph handles the delete knowledge graph request
|
||||
// @Summary Delete Knowledge Graph
|
||||
// @Description Delete knowledge graph for a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_id path string true "Knowledge Base ID"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/{kb_id}/knowledge_graph [delete]
|
||||
func (h *KnowledgebaseHandler) DeleteKnowledgeGraph(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbID := c.Param("kb_id")
|
||||
if kbID == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, true, "success")
|
||||
}
|
||||
|
||||
// GetMeta handles the get metadata request
|
||||
// @Summary Get Metadata
|
||||
// @Description Get metadata for knowledge bases
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_ids query string true "Comma-separated Knowledge Base IDs"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/get_meta [get]
|
||||
func (h *KnowledgebaseHandler) GetMeta(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbIDsStr := c.Query("kb_ids")
|
||||
if kbIDsStr == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
kbIDs := strings.Split(kbIDsStr, ",")
|
||||
for _, kbID := range kbIDs {
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success")
|
||||
}
|
||||
|
||||
// GetBasicInfo handles the get basic info request
|
||||
// @Summary Get Basic Info
|
||||
// @Description Get basic information for a knowledge base
|
||||
// @Tags knowledgebase
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param kb_id query string true "Knowledge Base ID"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/kb/basic_info [get]
|
||||
func (h *KnowledgebaseHandler) GetBasicInfo(c *gin.Context) {
|
||||
userID, code, err := h.getUserID(c)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
kbID := c.Query("kb_id")
|
||||
if kbID == "" {
|
||||
jsonError(c, common.CodeDataError, "kb_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.kbService.Accessible(kbID, userID) {
|
||||
jsonError(c, common.CodeAuthenticationError, "No authorization.")
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success")
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
@ -60,50 +61,112 @@ func NewLLMHandler(llmService *service.LLMService, userService *service.UserServ
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/llm/my_llms [get]
|
||||
func (h *LLMHandler) GetMyLLMs(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Missing Authorization header",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Unauthorized!",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get tenant ID from user
|
||||
tenantID := user.ID
|
||||
if tenantID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "User has no tenant ID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse include_details query parameter
|
||||
includeDetailsStr := c.DefaultQuery("include_details", "false")
|
||||
includeDetails := includeDetailsStr == "true"
|
||||
|
||||
// Get LLMs for tenant
|
||||
llms, err := h.llmService.GetMyLLMs(tenantID, includeDetails)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to get LLMs",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": llms,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": llms,
|
||||
})
|
||||
}
|
||||
|
||||
// SetAPIKey set API key for a LLM factory
|
||||
// @Summary Set API Key
|
||||
// @Description Set API key for a LLM factory and test connectivity
|
||||
// @Tags llm
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body service.SetAPIKeyRequest true "API Key configuration"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/llm/set_api_key [post]
|
||||
func (h *LLMHandler) SetAPIKey(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Unauthorized!",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req service.SetAPIKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "Invalid request: " + err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tenantID := user.ID
|
||||
result, err := h.llmService.SetAPIKey(tenantID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeDataError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Verify {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": result,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
|
||||
@ -198,52 +261,43 @@ func (h *LLMHandler) Factories(c *gin.Context) {
|
||||
// @Success 200 {object} map[string][]service.LLMListItem
|
||||
// @Router /v1/llm/list [get]
|
||||
func (h *LLMHandler) ListApp(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Missing Authorization header",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Unauthorized!",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get tenant ID from user
|
||||
tenantID := user.ID
|
||||
if tenantID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "User has no tenant ID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse model_type query parameter
|
||||
modelType := c.Query("model_type")
|
||||
|
||||
// Get LLM list
|
||||
llms, err := h.llmService.ListLLMs(tenantID, modelType)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": llms,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": llms,
|
||||
})
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
@ -48,44 +49,49 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/user/tenant_info [get]
|
||||
func (h *TenantHandler) TenantInfo(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Missing Authorization header",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Get user by token
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Unauthorized!",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get tenant info
|
||||
tenantInfo, err := h.tenantService.GetTenantInfo(user.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to get tenant information",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if tenantInfo == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "Tenant not found",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeDataError,
|
||||
"message": "Tenant not found!",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": tenantInfo,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": tenantInfo,
|
||||
})
|
||||
}
|
||||
|
||||
@ -99,38 +105,39 @@ func (h *TenantHandler) TenantInfo(c *gin.Context) {
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/tenant/list [get]
|
||||
func (h *TenantHandler) TenantList(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Missing Authorization header",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Unauthorized!",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get tenant list
|
||||
tenantList, err := h.tenantService.GetTenantList(user.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "Failed to get tenant list",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": tenantList,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": tenantList,
|
||||
})
|
||||
}
|
||||
|
||||
@ -522,3 +522,61 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) {
|
||||
"data": channels,
|
||||
})
|
||||
}
|
||||
|
||||
// SetTenantInfo update tenant information
|
||||
// @Summary Set Tenant Info
|
||||
// @Description Update tenant model configuration
|
||||
// @Tags users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param request body service.SetTenantInfoRequest true "tenant info"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/user/set_tenant_info [post]
|
||||
func (h *UserHandler) SetTenantInfo(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Unauthorized!",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req service.SetTenantInfoRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.userService.SetTenantInfo(user.ID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeDataError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
|
||||
157
internal/init_data/llm_init.go
Normal file
157
internal/init_data/llm_init.go
Normal file
@ -0,0 +1,157 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package init_data
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
// LLMFactoryConfig represents a single LLM factory configuration
|
||||
type LLMFactoryConfig struct {
|
||||
Name string `json:"name"`
|
||||
Logo string `json:"logo"`
|
||||
Tags string `json:"tags"`
|
||||
Status string `json:"status"`
|
||||
Rank string `json:"rank"`
|
||||
LLM []LLMConfig `json:"llm"`
|
||||
}
|
||||
|
||||
// LLMConfig represents a single LLM model configuration
|
||||
type LLMConfig struct {
|
||||
LLMName string `json:"llm_name"`
|
||||
Tags string `json:"tags"`
|
||||
MaxTokens int64 `json:"max_tokens"`
|
||||
ModelType string `json:"model_type"`
|
||||
IsTools bool `json:"is_tools"`
|
||||
}
|
||||
|
||||
// LLMFactoriesFile represents the structure of llm_factories.json
|
||||
type LLMFactoriesFile struct {
|
||||
FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"`
|
||||
}
|
||||
|
||||
// InitLLMFactory initializes LLM factories and models from JSON file
|
||||
func InitLLMFactory() error {
|
||||
configPath := filepath.Join(getProjectBaseDirectory(), "conf", "llm_factories.json")
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read llm_factories.json: %w", err)
|
||||
}
|
||||
|
||||
var fileData LLMFactoriesFile
|
||||
if err := json.Unmarshal(data, &fileData); err != nil {
|
||||
return fmt.Errorf("failed to parse llm_factories.json: %w", err)
|
||||
}
|
||||
|
||||
db := dao.DB
|
||||
|
||||
for _, factory := range fileData.FactoryLLMInfos {
|
||||
status := factory.Status
|
||||
if status == "" {
|
||||
status = "1"
|
||||
}
|
||||
|
||||
llmFactory := &model.LLMFactories{
|
||||
Name: factory.Name,
|
||||
Logo: stringPtr(factory.Logo),
|
||||
Tags: factory.Tags,
|
||||
Rank: parseInt64(factory.Rank),
|
||||
Status: &status,
|
||||
}
|
||||
|
||||
var existingFactory model.LLMFactories
|
||||
result := db.Where("name = ?", factory.Name).First(&existingFactory)
|
||||
if result.Error != nil {
|
||||
if err := db.Create(llmFactory).Error; err != nil {
|
||||
log.Printf("Failed to create LLM factory %s: %v", factory.Name, err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{
|
||||
"logo": llmFactory.Logo,
|
||||
"tags": llmFactory.Tags,
|
||||
"rank": llmFactory.Rank,
|
||||
"status": llmFactory.Status,
|
||||
}).Error; err != nil {
|
||||
log.Printf("Failed to update LLM factory %s: %v", factory.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, llm := range factory.LLM {
|
||||
llmStatus := "1"
|
||||
llmModel := &model.LLM{
|
||||
LLMName: llm.LLMName,
|
||||
ModelType: llm.ModelType,
|
||||
FID: factory.Name,
|
||||
MaxTokens: llm.MaxTokens,
|
||||
Tags: llm.Tags,
|
||||
IsTools: llm.IsTools,
|
||||
Status: &llmStatus,
|
||||
}
|
||||
|
||||
var existingLLM model.LLM
|
||||
result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM)
|
||||
if result.Error != nil {
|
||||
if err := db.Create(llmModel).Error; err != nil {
|
||||
log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err)
|
||||
}
|
||||
} else {
|
||||
if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{
|
||||
"model_type": llmModel.ModelType,
|
||||
"max_tokens": llmModel.MaxTokens,
|
||||
"tags": llmModel.Tags,
|
||||
"is_tools": llmModel.IsTools,
|
||||
"status": llmModel.Status,
|
||||
}).Error; err != nil {
|
||||
log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("LLM factories initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func getProjectBaseDirectory() string {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "."
|
||||
}
|
||||
return cwd
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func parseInt64(s string) int64 {
|
||||
var result int64
|
||||
fmt.Sscanf(s, "%d", &result)
|
||||
return result
|
||||
}
|
||||
@ -18,7 +18,84 @@ package model
|
||||
|
||||
import "time"
|
||||
|
||||
// Knowledgebase knowledge base model
|
||||
// DatasetNameLimit is the maximum length for dataset name
|
||||
const DatasetNameLimit = 128
|
||||
|
||||
// Status represents the status enum values
|
||||
type Status string
|
||||
|
||||
const (
|
||||
// StatusValid indicates a valid/active record
|
||||
StatusValid Status = "1"
|
||||
// StatusInvalid indicates a deleted/inactive record
|
||||
StatusInvalid Status = "0"
|
||||
)
|
||||
|
||||
// TenantPermission represents the permission level for tenant access
|
||||
type TenantPermission string
|
||||
|
||||
const (
|
||||
// TenantPermissionMe indicates only the creator can access
|
||||
TenantPermissionMe TenantPermission = "me"
|
||||
// TenantPermissionTeam indicates all team members can access
|
||||
TenantPermissionTeam TenantPermission = "team"
|
||||
)
|
||||
|
||||
// ParserType represents the document parser type
|
||||
type ParserType string
|
||||
|
||||
const (
|
||||
ParserTypePresentation ParserType = "presentation"
|
||||
ParserTypeLaws ParserType = "laws"
|
||||
ParserTypeManual ParserType = "manual"
|
||||
ParserTypePaper ParserType = "paper"
|
||||
ParserTypeResume ParserType = "resume"
|
||||
ParserTypeBook ParserType = "book"
|
||||
ParserTypeQA ParserType = "qa"
|
||||
ParserTypeTable ParserType = "table"
|
||||
ParserTypeNaive ParserType = "naive"
|
||||
ParserTypePicture ParserType = "picture"
|
||||
ParserTypeOne ParserType = "one"
|
||||
ParserTypeAudio ParserType = "audio"
|
||||
ParserTypeEmail ParserType = "email"
|
||||
ParserTypeKG ParserType = "knowledge_graph"
|
||||
ParserTypeTag ParserType = "tag"
|
||||
)
|
||||
|
||||
// TaskStatus represents the status of a processing task
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusUnstart TaskStatus = "0"
|
||||
TaskStatusRunning TaskStatus = "1"
|
||||
TaskStatusCancel TaskStatus = "2"
|
||||
TaskStatusDone TaskStatus = "3"
|
||||
TaskStatusFail TaskStatus = "4"
|
||||
TaskStatusSchedule TaskStatus = "5"
|
||||
)
|
||||
|
||||
// PipelineTaskType represents the type of pipeline task
|
||||
type PipelineTaskType string
|
||||
|
||||
const (
|
||||
PipelineTaskTypeParse PipelineTaskType = "Parse"
|
||||
PipelineTaskTypeDownload PipelineTaskType = "Download"
|
||||
PipelineTaskTypeRAPTOR PipelineTaskType = "RAPTOR"
|
||||
PipelineTaskTypeGraphRAG PipelineTaskType = "GraphRAG"
|
||||
PipelineTaskTypeMindmap PipelineTaskType = "Mindmap"
|
||||
PipelineTaskTypeMemory PipelineTaskType = "Memory"
|
||||
)
|
||||
|
||||
// FileSource represents the source of a file
|
||||
type FileSource string
|
||||
|
||||
const (
|
||||
FileSourceLocal FileSource = ""
|
||||
FileSourceKnowledgebase FileSource = "knowledgebase"
|
||||
FileSourceS3 FileSource = "s3"
|
||||
)
|
||||
|
||||
// Knowledgebase represents the knowledge base model
|
||||
type Knowledgebase struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"`
|
||||
@ -27,7 +104,6 @@ type Knowledgebase struct {
|
||||
Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"`
|
||||
Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"`
|
||||
EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"`
|
||||
TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"`
|
||||
Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"`
|
||||
CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"`
|
||||
DocNum int64 `gorm:"column:doc_num;default:0;index" json:"doc_num"`
|
||||
@ -37,7 +113,7 @@ type Knowledgebase struct {
|
||||
VectorSimilarityWeight float64 `gorm:"column:vector_similarity_weight;default:0.3;index" json:"vector_similarity_weight"`
|
||||
ParserID string `gorm:"column:parser_id;size:32;not null;default:naive;index" json:"parser_id"`
|
||||
PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"`
|
||||
ParserConfig JSONMap `gorm:"column:parser_config;type:longtext;not null" json:"parser_config"`
|
||||
ParserConfig JSONMap `gorm:"column:parser_config;type:json" json:"parser_config"`
|
||||
Pagerank int64 `gorm:"column:pagerank;default:0" json:"pagerank"`
|
||||
GraphragTaskID *string `gorm:"column:graphrag_task_id;size:32;index" json:"graphrag_task_id,omitempty"`
|
||||
GraphragTaskFinishAt *time.Time `gorm:"column:graphrag_task_finish_at" json:"graphrag_task_finish_at,omitempty"`
|
||||
@ -49,12 +125,118 @@ type Knowledgebase struct {
|
||||
BaseModel
|
||||
}
|
||||
|
||||
// TableName specify table name
|
||||
// TableName returns the table name for Knowledgebase model
|
||||
func (Knowledgebase) TableName() string {
|
||||
return "knowledgebase"
|
||||
}
|
||||
|
||||
// InvitationCode invitation code model
|
||||
// ToMap converts Knowledgebase to a map for JSON response
|
||||
func (kb *Knowledgebase) ToMap() map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
"id": kb.ID,
|
||||
"tenant_id": kb.TenantID,
|
||||
"name": kb.Name,
|
||||
"embd_id": kb.EmbdID,
|
||||
"permission": kb.Permission,
|
||||
"created_by": kb.CreatedBy,
|
||||
"doc_num": kb.DocNum,
|
||||
"token_num": kb.TokenNum,
|
||||
"chunk_num": kb.ChunkNum,
|
||||
"similarity_threshold": kb.SimilarityThreshold,
|
||||
"vector_similarity_weight": kb.VectorSimilarityWeight,
|
||||
"parser_id": kb.ParserID,
|
||||
"parser_config": kb.ParserConfig,
|
||||
"pagerank": kb.Pagerank,
|
||||
"create_time": kb.CreateTime,
|
||||
}
|
||||
|
||||
if kb.Avatar != nil {
|
||||
result["avatar"] = *kb.Avatar
|
||||
}
|
||||
if kb.Language != nil {
|
||||
result["language"] = *kb.Language
|
||||
}
|
||||
if kb.Description != nil {
|
||||
result["description"] = *kb.Description
|
||||
}
|
||||
if kb.PipelineID != nil {
|
||||
result["pipeline_id"] = *kb.PipelineID
|
||||
}
|
||||
if kb.GraphragTaskID != nil {
|
||||
result["graphrag_task_id"] = *kb.GraphragTaskID
|
||||
}
|
||||
if kb.GraphragTaskFinishAt != nil {
|
||||
result["graphrag_task_finish_at"] = kb.GraphragTaskFinishAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if kb.RaptorTaskID != nil {
|
||||
result["raptor_task_id"] = *kb.RaptorTaskID
|
||||
}
|
||||
if kb.RaptorTaskFinishAt != nil {
|
||||
result["raptor_task_finish_at"] = kb.RaptorTaskFinishAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if kb.MindmapTaskID != nil {
|
||||
result["mindmap_task_id"] = *kb.MindmapTaskID
|
||||
}
|
||||
if kb.MindmapTaskFinishAt != nil {
|
||||
result["mindmap_task_finish_at"] = kb.MindmapTaskFinishAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if kb.UpdateTime != nil {
|
||||
result["update_time"] = *kb.UpdateTime
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// KnowledgebaseDetail represents detailed knowledge base information with joined data
|
||||
type KnowledgebaseDetail struct {
|
||||
ID string `json:"id"`
|
||||
EmbdID string `json:"embd_id"`
|
||||
Avatar *string `json:"avatar,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Language *string `json:"language,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Permission string `json:"permission"`
|
||||
DocNum int64 `json:"doc_num"`
|
||||
TokenNum int64 `json:"token_num"`
|
||||
ChunkNum int64 `json:"chunk_num"`
|
||||
ParserID string `json:"parser_id"`
|
||||
PipelineID *string `json:"pipeline_id,omitempty"`
|
||||
PipelineName *string `json:"pipeline_name,omitempty"`
|
||||
PipelineAvatar *string `json:"pipeline_avatar,omitempty"`
|
||||
ParserConfig JSONMap `json:"parser_config"`
|
||||
Pagerank int64 `json:"pagerank"`
|
||||
GraphragTaskID *string `json:"graphrag_task_id,omitempty"`
|
||||
GraphragTaskFinishAt *string `json:"graphrag_task_finish_at,omitempty"`
|
||||
RaptorTaskID *string `json:"raptor_task_id,omitempty"`
|
||||
RaptorTaskFinishAt *string `json:"raptor_task_finish_at,omitempty"`
|
||||
MindmapTaskID *string `json:"mindmap_task_id,omitempty"`
|
||||
MindmapTaskFinishAt *string `json:"mindmap_task_finish_at,omitempty"`
|
||||
CreateTime *int64 `json:"create_time,omitempty"`
|
||||
UpdateTime *int64 `json:"update_time,omitempty"`
|
||||
Size int64 `json:"size"`
|
||||
Connectors []string `json:"connectors"`
|
||||
}
|
||||
|
||||
// KnowledgebaseListItem represents a knowledge base item in list responses
|
||||
type KnowledgebaseListItem struct {
|
||||
ID string `json:"id"`
|
||||
Avatar *string `json:"avatar,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Language *string `json:"language,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
Permission string `json:"permission"`
|
||||
DocNum int64 `json:"doc_num"`
|
||||
TokenNum int64 `json:"token_num"`
|
||||
ChunkNum int64 `json:"chunk_num"`
|
||||
ParserID string `json:"parser_id"`
|
||||
EmbdID string `json:"embd_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
TenantAvatar *string `json:"tenant_avatar,omitempty"`
|
||||
UpdateTime *int64 `json:"update_time,omitempty"`
|
||||
}
|
||||
|
||||
// InvitationCode represents the invitation code model
|
||||
type InvitationCode struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Code string `gorm:"column:code;size:32;not null;index" json:"code"`
|
||||
@ -65,7 +247,7 @@ type InvitationCode struct {
|
||||
BaseModel
|
||||
}
|
||||
|
||||
// TableName specify table name
|
||||
// TableName returns the table name for InvitationCode model
|
||||
func (InvitationCode) TableName() string {
|
||||
return "invitation_code"
|
||||
}
|
||||
|
||||
@ -64,13 +64,14 @@ func (TenantLangfuse) TableName() string {
|
||||
|
||||
// MyLLM represents LLM information for a tenant with factory details
|
||||
type MyLLM struct {
|
||||
ID string `gorm:"column:id" json:"id"`
|
||||
LLMFactory string `gorm:"column:llm_factory" json:"llm_factory"`
|
||||
Logo *string `gorm:"column:logo" json:"logo,omitempty"`
|
||||
Tags string `gorm:"column:tags" json:"tags"`
|
||||
ModelType string `gorm:"column:model_type" json:"model_type"`
|
||||
LLMName string `gorm:"column:llm_name" json:"llm_name"`
|
||||
UsedTokens int64 `gorm:"column:used_tokens" json:"used_tokens"`
|
||||
Status string `gorm:"column:status" json:"status"`
|
||||
APIBase string `gorm:"column:api_base" json:"api_base,omitempty"`
|
||||
MaxTokens int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"`
|
||||
Tags *string `gorm:"column:tags" json:"tags"`
|
||||
ModelType *string `gorm:"column:model_type" json:"model_type"`
|
||||
LLMName *string `gorm:"column:llm_name" json:"llm_name"`
|
||||
UsedTokens *int64 `gorm:"column:used_tokens" json:"used_tokens"`
|
||||
Status *string `gorm:"column:status" json:"status"`
|
||||
APIBase *string `gorm:"column:api_base" json:"api_base,omitempty"`
|
||||
MaxTokens *int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
@ -101,6 +101,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
engine.POST("/v1/user/setting", r.userHandler.Setting)
|
||||
// User change password endpoint
|
||||
engine.POST("/v1/user/setting/password", r.userHandler.ChangePassword)
|
||||
// User set tenant info endpoint
|
||||
engine.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo)
|
||||
|
||||
// API v1 route group
|
||||
v1 := engine.Group("/api/v1")
|
||||
@ -134,7 +136,25 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// Knowledge base routes
|
||||
kb := engine.Group("/v1/kb")
|
||||
{
|
||||
kb.POST("/create", r.knowledgebaseHandler.CreateKB)
|
||||
kb.POST("/update", r.knowledgebaseHandler.UpdateKB)
|
||||
kb.POST("/update_metadata_setting", r.knowledgebaseHandler.UpdateMetadataSetting)
|
||||
kb.GET("/detail", r.knowledgebaseHandler.GetDetail)
|
||||
kb.POST("/list", r.knowledgebaseHandler.ListKbs)
|
||||
kb.POST("/rm", r.knowledgebaseHandler.DeleteKB)
|
||||
kb.GET("/tags", r.knowledgebaseHandler.ListTagsFromKbs)
|
||||
kb.GET("/get_meta", r.knowledgebaseHandler.GetMeta)
|
||||
kb.GET("/basic_info", r.knowledgebaseHandler.GetBasicInfo)
|
||||
|
||||
// KB ID specific routes
|
||||
kbByID := kb.Group("/:kb_id")
|
||||
{
|
||||
kbByID.GET("/tags", r.knowledgebaseHandler.ListTags)
|
||||
kbByID.POST("/rm_tags", r.knowledgebaseHandler.RemoveTags)
|
||||
kbByID.POST("/rename_tag", r.knowledgebaseHandler.RenameTag)
|
||||
kbByID.GET("/knowledge_graph", r.knowledgebaseHandler.KnowledgeGraph)
|
||||
kbByID.DELETE("/knowledge_graph", r.knowledgebaseHandler.DeleteKnowledgeGraph)
|
||||
}
|
||||
}
|
||||
|
||||
// Chunk routes
|
||||
@ -149,6 +169,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
llm.GET("/my_llms", r.llmHandler.GetMyLLMs)
|
||||
llm.GET("/factories", r.llmHandler.Factories)
|
||||
llm.GET("/list", r.llmHandler.ListApp)
|
||||
llm.POST("/set_api_key", r.llmHandler.SetAPIKey)
|
||||
}
|
||||
|
||||
// Chat routes
|
||||
|
||||
@ -40,6 +40,29 @@ type Config struct {
|
||||
DocEngine DocEngineConfig `mapstructure:"doc_engine"`
|
||||
RegisterEnabled int `mapstructure:"register_enabled"`
|
||||
OAuth map[string]OAuthConfig `mapstructure:"oauth"`
|
||||
UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"`
|
||||
}
|
||||
|
||||
// UserDefaultLLMConfig user default LLM configuration
|
||||
type UserDefaultLLMConfig struct {
|
||||
DefaultModels DefaultModelsConfig `mapstructure:"default_models"`
|
||||
}
|
||||
|
||||
// DefaultModelsConfig default models configuration
|
||||
type DefaultModelsConfig struct {
|
||||
ChatModel ModelConfig `mapstructure:"chat_model"`
|
||||
EmbeddingModel ModelConfig `mapstructure:"embedding_model"`
|
||||
RerankModel ModelConfig `mapstructure:"rerank_model"`
|
||||
ASRModel ModelConfig `mapstructure:"asr_model"`
|
||||
Image2TextModel ModelConfig `mapstructure:"image2text_model"`
|
||||
}
|
||||
|
||||
// ModelConfig model configuration
|
||||
type ModelConfig struct {
|
||||
Name string `mapstructure:"name"`
|
||||
APIKey string `mapstructure:"api_key"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Factory string `mapstructure:"factory"`
|
||||
}
|
||||
|
||||
// OAuthConfig OAuth configuration for a channel
|
||||
@ -414,6 +437,45 @@ func Init(configPath string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Map user_default_llm section to UserDefaultLLMConfig
|
||||
if v.IsSet("user_default_llm") {
|
||||
userDefaultLLMConfig := v.Sub("user_default_llm")
|
||||
if userDefaultLLMConfig != nil {
|
||||
if defaultModels := userDefaultLLMConfig.Sub("default_models"); defaultModels != nil {
|
||||
globalConfig.UserDefaultLLM.DefaultModels.ChatModel = ModelConfig{
|
||||
Name: defaultModels.GetString("chat_model.name"),
|
||||
APIKey: defaultModels.GetString("chat_model.api_key"),
|
||||
BaseURL: defaultModels.GetString("chat_model.base_url"),
|
||||
Factory: defaultModels.GetString("chat_model.factory"),
|
||||
}
|
||||
globalConfig.UserDefaultLLM.DefaultModels.EmbeddingModel = ModelConfig{
|
||||
Name: defaultModels.GetString("embedding_model.name"),
|
||||
APIKey: defaultModels.GetString("embedding_model.api_key"),
|
||||
BaseURL: defaultModels.GetString("embedding_model.base_url"),
|
||||
Factory: defaultModels.GetString("embedding_model.factory"),
|
||||
}
|
||||
globalConfig.UserDefaultLLM.DefaultModels.RerankModel = ModelConfig{
|
||||
Name: defaultModels.GetString("rerank_model.name"),
|
||||
APIKey: defaultModels.GetString("rerank_model.api_key"),
|
||||
BaseURL: defaultModels.GetString("rerank_model.base_url"),
|
||||
Factory: defaultModels.GetString("rerank_model.factory"),
|
||||
}
|
||||
globalConfig.UserDefaultLLM.DefaultModels.ASRModel = ModelConfig{
|
||||
Name: defaultModels.GetString("asr_model.name"),
|
||||
APIKey: defaultModels.GetString("asr_model.api_key"),
|
||||
BaseURL: defaultModels.GetString("asr_model.base_url"),
|
||||
Factory: defaultModels.GetString("asr_model.factory"),
|
||||
}
|
||||
globalConfig.UserDefaultLLM.DefaultModels.Image2TextModel = ModelConfig{
|
||||
Name: defaultModels.GetString("image2text_model.name"),
|
||||
APIKey: defaultModels.GetString("image2text_model.api_key"),
|
||||
BaseURL: defaultModels.GetString("image2text_model.base_url"),
|
||||
Factory: defaultModels.GetString("image2text_model.factory"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ func (s *ChatService) ListChats(userID string, status string) (*ListChatsRespons
|
||||
}
|
||||
|
||||
// Enrich with knowledge base names
|
||||
var chatsWithKBNames []*ChatWithKBNames
|
||||
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
kbNames := s.getKBNames(chat.KBIDs)
|
||||
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
|
||||
@ -148,7 +148,7 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
|
||||
}
|
||||
|
||||
// Enrich with knowledge base names
|
||||
var chatsWithKBNames []*ChatWithKBNames
|
||||
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
kbNames := s.getKBNames(chat.KBIDs)
|
||||
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
|
||||
|
||||
@ -17,25 +17,76 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/utility"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// KnowledgebaseService knowledge base service
|
||||
// KnowledgebaseService service class for managing dataset operations
|
||||
type KnowledgebaseService struct {
|
||||
kbDAO *dao.KnowledgebaseDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
userDAO *dao.UserDAO
|
||||
tenantDAO *dao.TenantDAO
|
||||
connectorDAO *dao.ConnectorDAO
|
||||
}
|
||||
|
||||
// NewKnowledgebaseService create knowledge base service
|
||||
// NewKnowledgebaseService creates a new knowledge base service
|
||||
func NewKnowledgebaseService() *KnowledgebaseService {
|
||||
return &KnowledgebaseService{
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
userDAO: dao.NewUserDAO(),
|
||||
tenantDAO: dao.NewTenantDAO(),
|
||||
connectorDAO: dao.NewConnectorDAO(),
|
||||
}
|
||||
}
|
||||
|
||||
// ListKbsRequest list knowledge bases request
|
||||
// CreateKBRequest represents the request for creating a knowledge base
|
||||
type CreateKBRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
ParserID *string `json:"parser_id,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Language *string `json:"language,omitempty"`
|
||||
Permission *string `json:"permission,omitempty"`
|
||||
Avatar *string `json:"avatar,omitempty"`
|
||||
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
|
||||
}
|
||||
|
||||
// CreateKBResponse represents the response for creating a knowledge base
|
||||
type CreateKBResponse struct {
|
||||
KBID string `json:"kb_id"`
|
||||
}
|
||||
|
||||
// UpdateKBRequest represents the request for updating a knowledge base
|
||||
type UpdateKBRequest struct {
|
||||
KBID string `json:"kb_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description *string `json:"description"`
|
||||
ParserID string `json:"parser_id" binding:"required"`
|
||||
Permission *string `json:"permission,omitempty"`
|
||||
Language *string `json:"language,omitempty"`
|
||||
Avatar *string `json:"avatar,omitempty"`
|
||||
Pagerank *int64 `json:"pagerank,omitempty"`
|
||||
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
|
||||
Connectors []string `json:"connectors,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateMetadataSettingRequest represents the request for updating metadata settings
|
||||
type UpdateMetadataSettingRequest struct {
|
||||
KBID string `json:"kb_id" binding:"required"`
|
||||
Metadata map[string]interface{} `json:"metadata" binding:"required"`
|
||||
EnableMetadata *bool `json:"enable_metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ListKbsRequest represents the request for listing knowledge bases
|
||||
type ListKbsRequest struct {
|
||||
Keywords *string `json:"keywords,omitempty"`
|
||||
Page *int `json:"page,omitempty"`
|
||||
@ -46,37 +97,461 @@ type ListKbsRequest struct {
|
||||
OwnerIDs *[]string `json:"owner_ids,omitempty"`
|
||||
}
|
||||
|
||||
// ListKbsResponse list knowledge bases response
|
||||
// ListKbsResponse represents the response for listing knowledge bases
|
||||
type ListKbsResponse struct {
|
||||
KBs []*model.Knowledgebase `json:"kbs"`
|
||||
Total int64 `json:"total"`
|
||||
KBs []map[string]interface{} `json:"kbs"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
// ListKbs list knowledge bases
|
||||
func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
// CreateKB creates a new knowledge base
|
||||
// This matches the Python create endpoint in kb_app.py
|
||||
func (s *KnowledgebaseService) CreateKB(req *CreateKBRequest, tenantID string) (*CreateKBResponse, common.ErrorCode, error) {
|
||||
// Validate name is a string
|
||||
if !isValidString(req.Name) {
|
||||
return nil, common.CodeDataError, errors.New("Dataset name must be string.")
|
||||
}
|
||||
|
||||
// Trim and validate name
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
|
||||
}
|
||||
|
||||
// Check name length (using UTF-8 byte length like Python)
|
||||
if len(name) > model.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit)
|
||||
}
|
||||
|
||||
// Verify tenant exists
|
||||
tenant, err := s.tenantDAO.GetByID(tenantID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("Tenant not found.")
|
||||
}
|
||||
|
||||
// Deduplicate name within tenant
|
||||
duplicateName := s.kbDAO.DuplicateName(name, tenantID)
|
||||
|
||||
// Get parser ID (default to "naive")
|
||||
parserID := "naive"
|
||||
if req.ParserID != nil && *req.ParserID != "" {
|
||||
parserID = *req.ParserID
|
||||
}
|
||||
|
||||
// Get parser config with defaults
|
||||
parserConfig := getParserConfig(parserID, req.ParserConfig)
|
||||
parserConfig["llm_id"] = tenant.LLMID
|
||||
|
||||
// Generate KB ID
|
||||
kbID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// Create knowledge base model
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now()
|
||||
kb := &model.Knowledgebase{
|
||||
ID: kbID,
|
||||
Name: duplicateName,
|
||||
TenantID: tenantID,
|
||||
CreatedBy: tenantID,
|
||||
ParserID: parserID,
|
||||
ParserConfig: parserConfig,
|
||||
Permission: "me",
|
||||
EmbdID: "",
|
||||
}
|
||||
kb.CreateTime = &now
|
||||
kb.UpdateTime = &now
|
||||
kb.CreateDate = &nowDate
|
||||
kb.UpdateDate = &nowDate
|
||||
status := string(model.StatusValid)
|
||||
kb.Status = &status
|
||||
|
||||
// Set optional fields
|
||||
if req.Description != nil {
|
||||
kb.Description = req.Description
|
||||
}
|
||||
if req.Language != nil {
|
||||
kb.Language = req.Language
|
||||
}
|
||||
if req.Permission != nil {
|
||||
kb.Permission = *req.Permission
|
||||
}
|
||||
if req.Avatar != nil {
|
||||
kb.Avatar = req.Avatar
|
||||
}
|
||||
|
||||
// Create in database
|
||||
if err := s.kbDAO.Create(kb); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create knowledge base: %w", err)
|
||||
}
|
||||
|
||||
return &CreateKBResponse{KBID: kbID}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// UpdateKB updates an existing knowledge base
|
||||
// This matches the Python update endpoint in kb_app.py
|
||||
func (s *KnowledgebaseService) UpdateKB(req *UpdateKBRequest, userID string) (map[string]interface{}, common.ErrorCode, error) {
|
||||
// Validate name is a string
|
||||
if !isValidString(req.Name) {
|
||||
return nil, common.CodeDataError, errors.New("Dataset name must be string.")
|
||||
}
|
||||
|
||||
// Trim and validate name
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
|
||||
}
|
||||
|
||||
// Check name length
|
||||
if len(name) > model.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit)
|
||||
}
|
||||
|
||||
// Check authorization
|
||||
if !s.kbDAO.Accessible4Deletion(req.KBID, userID) {
|
||||
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": req.KBID})
|
||||
if err != nil || len(kbs) == 0 {
|
||||
return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
|
||||
}
|
||||
|
||||
// Get existing KB
|
||||
kb, err := s.kbDAO.GetByID(req.KBID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("can't find this dataset")
|
||||
}
|
||||
|
||||
// Check for duplicate name
|
||||
if strings.ToLower(name) != strings.ToLower(kb.Name) {
|
||||
existingKB, _ := s.kbDAO.GetByName(name, userID)
|
||||
if existingKB != nil {
|
||||
return nil, common.CodeDataError, errors.New("duplicated dataset name")
|
||||
}
|
||||
}
|
||||
|
||||
// Build updates
|
||||
updates := map[string]interface{}{
|
||||
"name": name,
|
||||
"parser_id": req.ParserID,
|
||||
}
|
||||
|
||||
if req.Description != nil {
|
||||
updates["description"] = *req.Description
|
||||
}
|
||||
if req.Permission != nil {
|
||||
updates["permission"] = *req.Permission
|
||||
}
|
||||
if req.Language != nil {
|
||||
updates["language"] = *req.Language
|
||||
}
|
||||
if req.Avatar != nil {
|
||||
updates["avatar"] = *req.Avatar
|
||||
}
|
||||
if req.Pagerank != nil {
|
||||
updates["pagerank"] = *req.Pagerank
|
||||
}
|
||||
if req.ParserConfig != nil {
|
||||
updates["parser_config"] = req.ParserConfig
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now()
|
||||
updates["update_time"] = now
|
||||
updates["update_date"] = nowDate
|
||||
|
||||
// Update in database
|
||||
if err := s.kbDAO.UpdateByID(req.KBID, updates); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to update knowledge base: %w", err)
|
||||
}
|
||||
|
||||
// Get updated KB
|
||||
updatedKB, err := s.kbDAO.GetByID(req.KBID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("database error (knowledgebase rename)")
|
||||
}
|
||||
|
||||
result := updatedKB.ToMap()
|
||||
result["connectors"] = req.Connectors
|
||||
|
||||
return result, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// UpdateMetadataSetting updates the metadata settings for a knowledge base
|
||||
func (s *KnowledgebaseService) UpdateMetadataSetting(req *UpdateMetadataSettingRequest) (map[string]interface{}, common.ErrorCode, error) {
|
||||
kb, err := s.kbDAO.GetByID(req.KBID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("database error (knowledgebase not found)")
|
||||
}
|
||||
|
||||
parserConfig := kb.ParserConfig
|
||||
if parserConfig == nil {
|
||||
parserConfig = make(map[string]interface{})
|
||||
}
|
||||
|
||||
parserConfig["metadata"] = req.Metadata
|
||||
enableMetadata := true
|
||||
if req.EnableMetadata != nil {
|
||||
enableMetadata = *req.EnableMetadata
|
||||
}
|
||||
parserConfig["enable_metadata"] = enableMetadata
|
||||
|
||||
if err := s.kbDAO.UpdateParserConfig(req.KBID, parserConfig); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to update metadata setting: %w", err)
|
||||
}
|
||||
|
||||
result := kb.ToMap()
|
||||
result["parser_config"] = parserConfig
|
||||
|
||||
return result, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// GetDetail retrieves detailed information about a knowledge base
|
||||
// This matches the Python kb_detail endpoint in kb_app.py
|
||||
func (s *KnowledgebaseService) GetDetail(kbID, userID string) (*model.KnowledgebaseDetail, common.ErrorCode, error) {
|
||||
// Check authorization
|
||||
if !s.kbDAO.Accessible(kbID, userID) {
|
||||
return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
|
||||
}
|
||||
|
||||
// Get detail
|
||||
detail, err := s.kbDAO.GetDetail(kbID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("can't find this dataset")
|
||||
}
|
||||
|
||||
// Set connectors (empty for now)
|
||||
detail.Connectors = []string{}
|
||||
|
||||
return detail, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// ListKbs lists knowledge bases with pagination and filtering
|
||||
// This matches the Python list endpoint in kb_app.py
|
||||
func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, common.ErrorCode, error) {
|
||||
var kbs []*model.KnowledgebaseListItem
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
// If owner IDs are provided, filter by them
|
||||
if ownerIDs != nil && len(ownerIDs) > 0 {
|
||||
kbs, total, err = s.kbDAO.ListByOwnerIDs(ownerIDs, page, pageSize, orderby, desc, keywords, parserID)
|
||||
if len(ownerIDs) > 0 {
|
||||
// List by owner IDs
|
||||
kbs, total, err = s.kbDAO.GetByTenantIDs(ownerIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
|
||||
} else {
|
||||
// Get tenant IDs by user ID
|
||||
// Get tenant IDs for user
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
kbs, total, err = s.kbDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
|
||||
kbs, total, err = s.kbDAO.GetByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
// Convert to map slice
|
||||
kbMaps := make([]map[string]interface{}, len(kbs))
|
||||
for i, kb := range kbs {
|
||||
kbMaps[i] = map[string]interface{}{
|
||||
"id": kb.ID,
|
||||
"avatar": kb.Avatar,
|
||||
"name": kb.Name,
|
||||
"language": kb.Language,
|
||||
"description": kb.Description,
|
||||
"tenant_id": kb.TenantID,
|
||||
"permission": kb.Permission,
|
||||
"doc_num": kb.DocNum,
|
||||
"token_num": kb.TokenNum,
|
||||
"chunk_num": kb.ChunkNum,
|
||||
"parser_id": kb.ParserID,
|
||||
"embd_id": kb.EmbdID,
|
||||
"nickname": kb.Nickname,
|
||||
"tenant_avatar": kb.TenantAvatar,
|
||||
"update_time": kb.UpdateTime,
|
||||
}
|
||||
}
|
||||
|
||||
return &ListKbsResponse{
|
||||
KBs: kbs,
|
||||
KBs: kbMaps,
|
||||
Total: total,
|
||||
}, nil
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// DeleteKB soft deletes a knowledge base
|
||||
// This matches the Python rm endpoint in kb_app.py
|
||||
func (s *KnowledgebaseService) DeleteKB(kbID, userID string) (common.ErrorCode, error) {
|
||||
// Check authorization
|
||||
if !s.kbDAO.Accessible4Deletion(kbID, userID) {
|
||||
return common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": kbID})
|
||||
if err != nil || len(kbs) == 0 {
|
||||
return common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
|
||||
}
|
||||
|
||||
// Soft delete
|
||||
if err := s.kbDAO.Delete(kbID); err != nil {
|
||||
return common.CodeServerError, fmt.Errorf("database error (knowledgebase removal): %w", err)
|
||||
}
|
||||
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// Accessible checks if a knowledge base is accessible by a user
|
||||
func (s *KnowledgebaseService) Accessible(kbID, userID string) bool {
|
||||
return s.kbDAO.Accessible(kbID, userID)
|
||||
}
|
||||
|
||||
// GetByID retrieves a knowledge base by ID
|
||||
func (s *KnowledgebaseService) GetByID(kbID string) (*model.Knowledgebase, error) {
|
||||
return s.kbDAO.GetByID(kbID)
|
||||
}
|
||||
|
||||
// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant
|
||||
func (s *KnowledgebaseService) GetKBIDsByTenantID(tenantID string) ([]string, error) {
|
||||
return s.kbDAO.GetKBIDsByTenantID(tenantID)
|
||||
}
|
||||
|
||||
// isValidString checks if a value is a non-empty string
|
||||
func isValidString(v interface{}) bool {
|
||||
str, ok := v.(string)
|
||||
return ok && str != ""
|
||||
}
|
||||
|
||||
// getParserConfig returns the parser configuration with defaults
|
||||
// This matches the Python get_parser_config function
|
||||
func getParserConfig(parserID string, customConfig map[string]interface{}) map[string]interface{} {
|
||||
config := map[string]interface{}{
|
||||
"pages": [][]int{{1, 1000000}},
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
}
|
||||
|
||||
switch parserID {
|
||||
case "table":
|
||||
config["layout_recognize"] = false
|
||||
config["chunk_token_num"] = 128
|
||||
config["delimiter"] = "\n!?;。;!?"
|
||||
config["html4excel"] = false
|
||||
case "naive":
|
||||
config["chunk_token_num"] = 128
|
||||
config["delimiter"] = "\n!?;。;!?"
|
||||
config["html4excel"] = false
|
||||
default:
|
||||
config["raptor"] = map[string]interface{}{
|
||||
"use_raptor": false,
|
||||
}
|
||||
}
|
||||
|
||||
// Merge custom config if provided
|
||||
if customConfig != nil {
|
||||
config = mergeParserConfig(config, customConfig)
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// mergeParserConfig merges two parser configurations
|
||||
func mergeParserConfig(base, override map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range base {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
for k, v := range override {
|
||||
if existing, ok := result[k]; ok {
|
||||
if existingMap, ok := existing.(map[string]interface{}); ok {
|
||||
if newMap, ok := v.(map[string]interface{}); ok {
|
||||
result[k] = mergeParserConfig(existingMap, newMap)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GenerateUUID generates a UUID string without dashes
|
||||
func GenerateUUID() string {
|
||||
return strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
|
||||
// GetUserByToken gets user by authorization token
|
||||
func (s *KnowledgebaseService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) {
|
||||
userService := NewUserService()
|
||||
return userService.GetUserByToken(authorization)
|
||||
}
|
||||
|
||||
// GetUserByID gets user by ID
|
||||
func (s *KnowledgebaseService) GetUserByID(id string) (*model.User, error) {
|
||||
return s.userDAO.GetByAccessToken(id)
|
||||
}
|
||||
|
||||
// GetTenantIDsByUserID gets tenant IDs for a user
|
||||
func (s *KnowledgebaseService) GetTenantIDsByUserID(userID string) ([]string, error) {
|
||||
return s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
}
|
||||
|
||||
// GetConnectorsByTenantID gets connectors for a tenant
|
||||
func (s *KnowledgebaseService) GetConnectorsByTenantID(tenantID string) ([]*dao.ConnectorListItem, error) {
|
||||
return s.connectorDAO.ListByTenantID(tenantID)
|
||||
}
|
||||
|
||||
// GetKBList retrieves knowledge bases with ID and name filtering
|
||||
func (s *KnowledgebaseService) GetKBList(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, common.ErrorCode, error) {
|
||||
kbs, total, err := s.kbDAO.GetList(tenantIDs, userID, page, pageSize, orderby, desc, id, name)
|
||||
if err != nil {
|
||||
return nil, 0, common.CodeServerError, err
|
||||
}
|
||||
return kbs, total, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID
|
||||
func (s *KnowledgebaseService) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) {
|
||||
return s.kbDAO.GetKBByIDAndUserID(kbID, userID)
|
||||
}
|
||||
|
||||
// GetKBByNameAndUserID retrieves a knowledge base by name and user ID
|
||||
func (s *KnowledgebaseService) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) {
|
||||
return s.kbDAO.GetKBByNameAndUserID(kbName, userID)
|
||||
}
|
||||
|
||||
// AtomicIncreaseDocNumByID atomically increments the document count
|
||||
func (s *KnowledgebaseService) AtomicIncreaseDocNumByID(kbID string) error {
|
||||
return s.kbDAO.AtomicIncreaseDocNumByID(kbID)
|
||||
}
|
||||
|
||||
// DecreaseDocumentNum decreases document, chunk, and token counts
|
||||
func (s *KnowledgebaseService) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error {
|
||||
return s.kbDAO.DecreaseDocumentNum(kbID, docNum, chunkNum, tokenNum)
|
||||
}
|
||||
|
||||
// UpdateParserConfig updates the parser configuration
|
||||
func (s *KnowledgebaseService) UpdateParserConfig(id string, config map[string]interface{}) error {
|
||||
return s.kbDAO.UpdateParserConfig(id, config)
|
||||
}
|
||||
|
||||
// DeleteFieldMap removes the field_map from parser_config
|
||||
func (s *KnowledgebaseService) DeleteFieldMap(id string) error {
|
||||
return s.kbDAO.DeleteFieldMap(id)
|
||||
}
|
||||
|
||||
// GetFieldMap retrieves field mappings from multiple knowledge bases
|
||||
func (s *KnowledgebaseService) GetFieldMap(ids []string) (map[string]interface{}, error) {
|
||||
return s.kbDAO.GetFieldMap(ids)
|
||||
}
|
||||
|
||||
// GetAllIDs retrieves all knowledge base IDs
|
||||
func (s *KnowledgebaseService) GetAllIDs() ([]string, error) {
|
||||
return s.kbDAO.GetAllIDs()
|
||||
}
|
||||
|
||||
// ExtractAccessToken extracts access token from authorization header
|
||||
func ExtractAccessToken(authorization, secretKey string) (string, error) {
|
||||
return utility.ExtractAccessToken(authorization, secretKey)
|
||||
}
|
||||
|
||||
@ -17,11 +17,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
var DB = dao.DB
|
||||
|
||||
// LLMService LLM service
|
||||
type LLMService struct {
|
||||
tenantLLMDAO *dao.TenantLLMDAO
|
||||
@ -38,6 +43,7 @@ func NewLLMService() *LLMService {
|
||||
|
||||
// MyLLMItem represents a single LLM item in the response
|
||||
type MyLLMItem struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
UsedToken int64 `json:"used_token"`
|
||||
@ -46,67 +52,89 @@ type MyLLMItem struct {
|
||||
MaxTokens int64 `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// MyLLMResponse represents the response structure for my LLMs
|
||||
type MyLLMResponse struct {
|
||||
// MyLLMFactory represents the response structure for a factory in my LLMs
|
||||
type MyLLMFactory struct {
|
||||
Tags string `json:"tags"`
|
||||
LLM []MyLLMItem `json:"llm"`
|
||||
}
|
||||
|
||||
// GetMyLLMs get my LLMs for a tenant
|
||||
func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMResponse, error) {
|
||||
// Get LLM list from database
|
||||
myLLMs, err := s.tenantLLMDAO.GetMyLLMs(tenantID, includeDetails)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMFactory, error) {
|
||||
result := make(map[string]MyLLMFactory)
|
||||
|
||||
// Group by factory
|
||||
result := make(map[string]MyLLMResponse)
|
||||
providerDAO := dao.NewModelProviderDAO()
|
||||
for _, llm := range myLLMs {
|
||||
// Get or create factory entry
|
||||
resp, exists := result[llm.LLMFactory]
|
||||
if !exists {
|
||||
resp = MyLLMResponse{
|
||||
Tags: llm.Tags,
|
||||
LLM: []MyLLMItem{},
|
||||
if includeDetails {
|
||||
objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
factoryDAO := dao.NewLLMFactoryDAO()
|
||||
factories, err := factoryDAO.GetAllValid()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
factoryTagsMap := make(map[string]string)
|
||||
for _, f := range factories {
|
||||
if f.Tags != "" {
|
||||
factoryTagsMap[f.Name] = f.Tags
|
||||
}
|
||||
}
|
||||
|
||||
// Create LLM item
|
||||
item := MyLLMItem{
|
||||
Type: llm.ModelType,
|
||||
Name: llm.LLMName,
|
||||
UsedToken: llm.UsedTokens,
|
||||
Status: llm.Status,
|
||||
}
|
||||
|
||||
// Add detailed fields if requested
|
||||
if includeDetails {
|
||||
item.APIBase = llm.APIBase
|
||||
item.MaxTokens = llm.MaxTokens
|
||||
|
||||
// If APIBase is empty, try to get from model provider configuration
|
||||
if item.APIBase == "" {
|
||||
provider := providerDAO.GetProviderByName(llm.LLMFactory)
|
||||
if provider != nil {
|
||||
// Determine appropriate API base URL based on model type
|
||||
switch llm.ModelType {
|
||||
case "embedding":
|
||||
if provider.DefaultEmbeddingURL != "" {
|
||||
item.APIBase = provider.DefaultEmbeddingURL
|
||||
}
|
||||
// Add other model types here if needed
|
||||
// case "chat":
|
||||
// case "rerank":
|
||||
// etc.
|
||||
}
|
||||
for _, o := range objs {
|
||||
llmFactory := o.LLMFactory
|
||||
if _, exists := result[llmFactory]; !exists {
|
||||
tags := factoryTagsMap[llmFactory]
|
||||
result[llmFactory] = MyLLMFactory{
|
||||
Tags: tags,
|
||||
LLM: []MyLLMItem{},
|
||||
}
|
||||
}
|
||||
|
||||
item := MyLLMItem{
|
||||
ID: int64ToString(o.ID),
|
||||
Type: getStringValue(o.ModelType),
|
||||
Name: getStringValue(o.LLMName),
|
||||
UsedToken: o.UsedTokens,
|
||||
Status: getValidStatus(o.Status),
|
||||
}
|
||||
|
||||
if includeDetails {
|
||||
item.APIBase = getStringValueDefault(o.APIBase, "")
|
||||
item.MaxTokens = o.MaxTokens
|
||||
}
|
||||
|
||||
factory := result[llmFactory]
|
||||
factory.LLM = append(factory.LLM, item)
|
||||
result[llmFactory] = factory
|
||||
}
|
||||
} else {
|
||||
objs, err := s.tenantLLMDAO.GetMyLLMs(tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.LLM = append(resp.LLM, item)
|
||||
result[llm.LLMFactory] = resp
|
||||
for _, o := range objs {
|
||||
llmFactory := o.LLMFactory
|
||||
if _, exists := result[llmFactory]; !exists {
|
||||
result[llmFactory] = MyLLMFactory{
|
||||
Tags: getStringValue(o.Tags),
|
||||
LLM: []MyLLMItem{},
|
||||
}
|
||||
}
|
||||
|
||||
item := MyLLMItem{
|
||||
ID: o.ID,
|
||||
Type: getStringValue(o.ModelType),
|
||||
Name: getStringValue(o.LLMName),
|
||||
UsedToken: getInt64Value(o.UsedTokens),
|
||||
Status: getStringValueDefault(o.Status, "1"),
|
||||
}
|
||||
|
||||
factory := result[llmFactory]
|
||||
factory.LLM = append(factory.LLM, item)
|
||||
result[llmFactory] = factory
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@ -114,6 +142,7 @@ func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string
|
||||
|
||||
// LLMListItem represents a single LLM item in the list response
|
||||
type LLMListItem struct {
|
||||
ID string `json:"id"`
|
||||
LLMName string `json:"llm_name"`
|
||||
ModelType string `json:"model_type"`
|
||||
FID string `json:"fid"`
|
||||
@ -142,37 +171,32 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
|
||||
"GPUStack": true,
|
||||
}
|
||||
|
||||
// Get tenant LLMs
|
||||
tenantLLMs, err := s.tenantLLMDAO.ListAllByTenant(tenantID)
|
||||
objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build set of factories with valid API keys
|
||||
facts := make(map[string]bool)
|
||||
// Build set of valid LLM names@factories
|
||||
status := make(map[string]bool)
|
||||
for _, tl := range tenantLLMs {
|
||||
if tl.APIKey != nil && *tl.APIKey != "" && tl.Status == "1" {
|
||||
facts[tl.LLMFactory] = true
|
||||
tenantLLMMapping := make(map[string]string)
|
||||
|
||||
for _, o := range objs {
|
||||
if o.APIKey != nil && *o.APIKey != "" && getValidStatus(o.Status) == "1" {
|
||||
facts[o.LLMFactory] = true
|
||||
}
|
||||
llmName := ""
|
||||
if tl.LLMName != nil {
|
||||
llmName = *tl.LLMName
|
||||
}
|
||||
key := llmName + "@" + tl.LLMFactory
|
||||
if tl.Status == "1" {
|
||||
llmName := getStringValue(o.LLMName)
|
||||
key := llmName + "@" + o.LLMFactory
|
||||
if getValidStatus(o.Status) == "1" {
|
||||
status[key] = true
|
||||
}
|
||||
tenantLLMMapping[key] = int64ToString(o.ID)
|
||||
}
|
||||
|
||||
// Get all valid LLMs
|
||||
allLLMs, err := s.llmDAO.GetAllValid()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter and build result
|
||||
llmSet := make(map[string]bool)
|
||||
result := make(ListLLMsResponse)
|
||||
|
||||
@ -183,20 +207,18 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
|
||||
|
||||
key := llm.LLMName + "@" + llm.FID
|
||||
|
||||
// Check if valid (Builtin factory or in status set)
|
||||
if llm.FID != "Builtin" && !status[key] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter by model type if specified
|
||||
if modelType != "" && !strings.Contains(llm.ModelType, modelType) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine availability
|
||||
available := facts[llm.FID] || selfDeployed[llm.FID] || llm.LLMName == "flag-embedding"
|
||||
available := facts[llm.FID] || selfDeployed[llm.FID] || strings.ToLower(llm.LLMName) == "flag-embedding"
|
||||
|
||||
item := LLMListItem{
|
||||
ID: tenantLLMMapping[key],
|
||||
LLMName: llm.LLMName,
|
||||
ModelType: llm.ModelType,
|
||||
FID: llm.FID,
|
||||
@ -207,7 +229,6 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
|
||||
Tags: llm.Tags,
|
||||
}
|
||||
|
||||
// Add BaseModel fields
|
||||
if llm.CreateDate != nil {
|
||||
createDateStr := llm.CreateDate.Format("2006-01-02T15:04:05")
|
||||
item.CreateDate = &createDateStr
|
||||
@ -225,36 +246,160 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon
|
||||
llmSet[key] = true
|
||||
}
|
||||
|
||||
// Add tenant LLMs that are not in the global list
|
||||
for _, tl := range tenantLLMs {
|
||||
llmName := ""
|
||||
if tl.LLMName != nil {
|
||||
llmName = *tl.LLMName
|
||||
}
|
||||
key := llmName + "@" + tl.LLMFactory
|
||||
for _, o := range objs {
|
||||
llmName := getStringValue(o.LLMName)
|
||||
key := llmName + "@" + o.LLMFactory
|
||||
if llmSet[key] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter by model type if specified
|
||||
modelTypeValue := ""
|
||||
if tl.ModelType != nil {
|
||||
modelTypeValue = *tl.ModelType
|
||||
}
|
||||
modelTypeValue := getStringValue(o.ModelType)
|
||||
if modelType != "" && !strings.Contains(modelTypeValue, modelType) {
|
||||
continue
|
||||
}
|
||||
|
||||
item := LLMListItem{
|
||||
ID: int64ToString(o.ID),
|
||||
LLMName: llmName,
|
||||
ModelType: modelTypeValue,
|
||||
FID: tl.LLMFactory,
|
||||
FID: o.LLMFactory,
|
||||
Available: true,
|
||||
Status: tl.Status,
|
||||
Status: getValidStatus(o.Status),
|
||||
}
|
||||
|
||||
result[tl.LLMFactory] = append(result[tl.LLMFactory], item)
|
||||
result[o.LLMFactory] = append(result[o.LLMFactory], item)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func getStringValue(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
func getStringValueDefault(s *string, defaultVal string) string {
|
||||
if s == nil || *s == "" {
|
||||
return defaultVal
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
func getValidStatus(status string) string {
|
||||
if status == "" {
|
||||
return "1"
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
func getInt64Value(i *int64) int64 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
return *i
|
||||
}
|
||||
|
||||
func getInt64ValueDefault(i *int64, defaultVal int64) int64 {
|
||||
if i == nil || *i == 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return *i
|
||||
}
|
||||
|
||||
func getBoolValue(b *bool) bool {
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return *b
|
||||
}
|
||||
|
||||
func int64ToString(n int64) string {
|
||||
return strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
// SetAPIKeyRequest represents the request for setting API key
|
||||
type SetAPIKeyRequest struct {
|
||||
LLMFactory string `json:"llm_factory"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
SourceFID string `json:"source_fid"`
|
||||
ModelType string `json:"model_type"`
|
||||
LLMName string `json:"llm_name"`
|
||||
Verify bool `json:"verify"`
|
||||
MaxTokens int64 `json:"max_tokens"`
|
||||
}
|
||||
|
||||
// SetAPIKeyResult represents the result of setting API key
|
||||
type SetAPIKeyResult struct {
|
||||
Message string `json:"message"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// SetAPIKey sets API key for a LLM factory
|
||||
func (s *LLMService) SetAPIKey(tenantID string, req *SetAPIKeyRequest) (*SetAPIKeyResult, error) {
|
||||
factory := req.LLMFactory
|
||||
baseURL := req.BaseURL
|
||||
sourceFactory := req.SourceFID
|
||||
if sourceFactory == "" {
|
||||
sourceFactory = factory
|
||||
}
|
||||
|
||||
sourceLLMs, err := s.llmDAO.GetByFactory(sourceFactory)
|
||||
if err != nil || len(sourceLLMs) == 0 {
|
||||
msg := "No models configured for " + factory + " (source: " + sourceFactory + ")."
|
||||
if req.Verify {
|
||||
return &SetAPIKeyResult{Message: msg, Success: false}, nil
|
||||
}
|
||||
return nil, fmt.Errorf(msg)
|
||||
}
|
||||
|
||||
llmConfig := map[string]interface{}{
|
||||
"api_key": req.APIKey,
|
||||
"api_base": baseURL,
|
||||
}
|
||||
|
||||
if req.ModelType != "" {
|
||||
llmConfig["model_type"] = req.ModelType
|
||||
}
|
||||
if req.LLMName != "" {
|
||||
llmConfig["llm_name"] = req.LLMName
|
||||
}
|
||||
|
||||
for _, llm := range sourceLLMs {
|
||||
maxTokens := llm.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 8192
|
||||
}
|
||||
llmConfig["max_tokens"] = maxTokens
|
||||
|
||||
existingLLM, _ := s.tenantLLMDAO.GetByTenantFactoryAndModelName(tenantID, factory, llm.LLMName)
|
||||
if existingLLM != nil {
|
||||
updates := map[string]interface{}{
|
||||
"api_key": req.APIKey,
|
||||
"api_base": baseURL,
|
||||
"max_tokens": maxTokens,
|
||||
}
|
||||
DB.Model(&model.TenantLLM{}).
|
||||
Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, llm.LLMName).
|
||||
Updates(updates)
|
||||
} else {
|
||||
modelType := llm.ModelType
|
||||
llmName := llm.LLMName
|
||||
tenantLLM := &model.TenantLLM{
|
||||
TenantID: tenantID,
|
||||
LLMFactory: factory,
|
||||
ModelType: &modelType,
|
||||
LLMName: &llmName,
|
||||
APIKey: &req.APIKey,
|
||||
APIBase: &baseURL,
|
||||
MaxTokens: maxTokens,
|
||||
Status: "1",
|
||||
}
|
||||
s.tenantLLMDAO.Create(tenantLLM)
|
||||
}
|
||||
}
|
||||
|
||||
return &SetAPIKeyResult{Message: "", Success: true}, nil
|
||||
}
|
||||
|
||||
@ -151,15 +151,38 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
user.LastLoginTime = &now_date
|
||||
|
||||
tenantName := req.Nickname + "'s Kingdom"
|
||||
|
||||
llmID := cfg.UserDefaultLLM.DefaultModels.ChatModel.Name
|
||||
if llmID == "" {
|
||||
llmID = ""
|
||||
}
|
||||
embdID := cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name
|
||||
if embdID == "" {
|
||||
embdID = ""
|
||||
}
|
||||
asrID := cfg.UserDefaultLLM.DefaultModels.ASRModel.Name
|
||||
if asrID == "" {
|
||||
asrID = ""
|
||||
}
|
||||
img2txtID := cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name
|
||||
if img2txtID == "" {
|
||||
img2txtID = ""
|
||||
}
|
||||
rerankID := cfg.UserDefaultLLM.DefaultModels.RerankModel.Name
|
||||
if rerankID == "" {
|
||||
rerankID = ""
|
||||
}
|
||||
|
||||
tenant := &model.Tenant{
|
||||
ID: userID,
|
||||
Name: &tenantName,
|
||||
LLMID: cfg.Server.Mode,
|
||||
EmbdID: cfg.Server.Mode,
|
||||
ASRID: cfg.Server.Mode,
|
||||
Img2TxtID: cfg.Server.Mode,
|
||||
RerankID: cfg.Server.Mode,
|
||||
LLMID: llmID,
|
||||
EmbdID: embdID,
|
||||
ASRID: asrID,
|
||||
Img2TxtID: img2txtID,
|
||||
RerankID: rerankID,
|
||||
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
|
||||
Status: &status,
|
||||
}
|
||||
tenant.CreateTime = &now
|
||||
tenant.UpdateTime = &now
|
||||
@ -753,3 +776,52 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, err
|
||||
|
||||
return channels, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// SetTenantInfoRequest represents the request for setting tenant info
|
||||
type SetTenantInfoRequest struct {
|
||||
TenantID string `json:"tenant_id"`
|
||||
ASRID string `json:"asr_id"`
|
||||
EmbdID string `json:"embd_id"`
|
||||
Img2TxtID string `json:"img2txt_id"`
|
||||
LLMID string `json:"llm_id"`
|
||||
RerankID string `json:"rerank_id"`
|
||||
TTSID string `json:"tts_id"`
|
||||
}
|
||||
|
||||
// SetTenantInfo updates tenant model configuration
|
||||
func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) error {
|
||||
tenantDAO := dao.NewTenantDAO()
|
||||
|
||||
_, err := tenantDAO.GetByID(req.TenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tenant not found: %w", err)
|
||||
}
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
if req.LLMID != "" {
|
||||
updates["llm_id"] = req.LLMID
|
||||
}
|
||||
if req.EmbdID != "" {
|
||||
updates["embd_id"] = req.EmbdID
|
||||
}
|
||||
if req.ASRID != "" {
|
||||
updates["asr_id"] = req.ASRID
|
||||
}
|
||||
if req.Img2TxtID != "" {
|
||||
updates["img2txt_id"] = req.Img2TxtID
|
||||
}
|
||||
if req.RerankID != "" {
|
||||
updates["rerank_id"] = req.RerankID
|
||||
}
|
||||
if req.TTSID != "" {
|
||||
updates["tts_id"] = req.TTSID
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := tenantDAO.Update(req.TenantID, updates); err != nil {
|
||||
return fmt.Errorf("failed to update tenant: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user