mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-04 09:17:48 +08:00
Add rename model directory to entity to avoid name misunderstanding (#13829)
### What problem does this PR solve? Model-> entity ### Type of change - [x] Refactoring Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -29,8 +29,9 @@ import (
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine/elasticsearch"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/logger"
|
||||
"ragflow/internal/model"
|
||||
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/utility"
|
||||
"regexp"
|
||||
@ -94,7 +95,7 @@ func NewService() *Service {
|
||||
// Logout user logout
|
||||
func (s *Service) Logout(user interface{}) error {
|
||||
// Invalidate token by setting it to INVALID_ prefix
|
||||
if u, ok := user.(*model.User); ok {
|
||||
if u, ok := user.(*entity.User); ok {
|
||||
invalidToken := "INVALID_" + generateRandomHex(16)
|
||||
return s.userDAO.UpdateAccessToken(u, invalidToken)
|
||||
}
|
||||
@ -102,7 +103,7 @@ func (s *Service) Logout(user interface{}) error {
|
||||
}
|
||||
|
||||
// GetUserByToken get user by access token
|
||||
func (s *Service) GetUserByToken(token string) (*model.User, error) {
|
||||
func (s *Service) GetUserByToken(token string) (*entity.User, error) {
|
||||
user, err := s.userDAO.GetByAccessToken(token)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
@ -185,7 +186,7 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
|
||||
user := &model.User{
|
||||
user := &entity.User{
|
||||
ID: userID,
|
||||
AccessToken: &accessToken,
|
||||
Email: username,
|
||||
@ -197,7 +198,7 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
IsAnonymous: "0",
|
||||
LoginChannel: &loginChannel,
|
||||
IsSuperuser: &isSuperuser,
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
@ -246,7 +247,7 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
}
|
||||
|
||||
tenantStatus := "1"
|
||||
tenant := &model.Tenant{
|
||||
tenant := &entity.Tenant{
|
||||
ID: userID,
|
||||
Name: &tenantName,
|
||||
LLMID: chatMdl,
|
||||
@ -257,7 +258,7 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
ParserIDs: parserIDs,
|
||||
Credit: 512,
|
||||
Status: &tenantStatus,
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
@ -271,14 +272,14 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
|
||||
// 3. Create user-tenant relation
|
||||
userTenantStatus := "1"
|
||||
userTenant := &model.UserTenant{
|
||||
userTenant := &entity.UserTenant{
|
||||
ID: utility.GenerateToken(),
|
||||
UserID: userID,
|
||||
TenantID: userID,
|
||||
Role: "owner",
|
||||
InvitedBy: userID,
|
||||
Status: &userTenantStatus,
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
@ -305,7 +306,7 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
// 5. Create root file folder
|
||||
fileID := utility.GenerateToken()
|
||||
fileLocation := ""
|
||||
file := &model.File{
|
||||
file := &entity.File{
|
||||
ID: fileID,
|
||||
ParentID: fileID,
|
||||
TenantID: userID,
|
||||
@ -314,7 +315,7 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
Type: "folder",
|
||||
Size: 0,
|
||||
Location: &fileLocation,
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
@ -345,13 +346,13 @@ func (s *Service) CreateUser(username, password, role string) (map[string]interf
|
||||
|
||||
// getInitTenantLLM gets initial tenant LLM configurations
|
||||
// This matches Python's get_init_tenant_llm function
|
||||
func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
func (s *Service) getInitTenantLLM(userID string) ([]*entity.TenantLLM, error) {
|
||||
cfg := server.GetConfig()
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config not initialized")
|
||||
}
|
||||
|
||||
var tenantLLMs []*model.TenantLLM
|
||||
var tenantLLMs []*entity.TenantLLM
|
||||
|
||||
// Get model configs from configuration
|
||||
modelConfigs := []server.ModelConfig{
|
||||
@ -388,10 +389,10 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
// Determine API key and base URL based on model type
|
||||
var apiKey, apiBase string
|
||||
switch llm.ModelType {
|
||||
case string(model.ModelTypeChat):
|
||||
case string(entity.ModelTypeChat):
|
||||
apiKey = factoryConfig.APIKey
|
||||
apiBase = factoryConfig.BaseURL
|
||||
case string(model.ModelTypeEmbedding):
|
||||
case string(entity.ModelTypeEmbedding):
|
||||
apiKey = cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.APIKey
|
||||
apiBase = cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.BaseURL
|
||||
if apiKey == "" {
|
||||
@ -400,7 +401,7 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = factoryConfig.BaseURL
|
||||
}
|
||||
case string(model.ModelTypeRerank):
|
||||
case string(entity.ModelTypeRerank):
|
||||
apiKey = cfg.UserDefaultLLM.DefaultModels.RerankModel.APIKey
|
||||
apiBase = cfg.UserDefaultLLM.DefaultModels.RerankModel.BaseURL
|
||||
if apiKey == "" {
|
||||
@ -409,7 +410,7 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = factoryConfig.BaseURL
|
||||
}
|
||||
case string(model.ModelTypeSpeech2Text):
|
||||
case string(entity.ModelTypeSpeech2Text):
|
||||
apiKey = cfg.UserDefaultLLM.DefaultModels.ASRModel.APIKey
|
||||
apiBase = cfg.UserDefaultLLM.DefaultModels.ASRModel.BaseURL
|
||||
if apiKey == "" {
|
||||
@ -418,7 +419,7 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = factoryConfig.BaseURL
|
||||
}
|
||||
case string(model.ModelTypeImage2Text):
|
||||
case string(entity.ModelTypeImage2Text):
|
||||
apiKey = cfg.UserDefaultLLM.DefaultModels.Image2TextModel.APIKey
|
||||
apiBase = cfg.UserDefaultLLM.DefaultModels.Image2TextModel.BaseURL
|
||||
if apiKey == "" {
|
||||
@ -442,7 +443,7 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
|
||||
tenantLLM := &model.TenantLLM{
|
||||
tenantLLM := &entity.TenantLLM{
|
||||
TenantID: userID,
|
||||
LLMFactory: factoryConfig.Factory,
|
||||
LLMName: &llmName,
|
||||
@ -451,7 +452,7 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
APIBase: &apiBase,
|
||||
MaxTokens: maxTokens,
|
||||
Status: "1",
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
@ -464,7 +465,7 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
|
||||
// Remove duplicates based on (tenant_id, llm_factory, llm_name)
|
||||
seen := make(map[string]bool)
|
||||
var uniqueLLMs []*model.TenantLLM
|
||||
var uniqueLLMs []*entity.TenantLLM
|
||||
for _, tllm := range tenantLLMs {
|
||||
key := fmt.Sprintf("%s|%s|%s", tllm.TenantID, tllm.LLMFactory, *tllm.LLMName)
|
||||
if !seen[key] {
|
||||
@ -479,7 +480,7 @@ func (s *Service) getInitTenantLLM(userID string) ([]*model.TenantLLM, error) {
|
||||
// GetUserDetails get user details
|
||||
func (s *Service) GetUserDetails(username string) (map[string]interface{}, error) {
|
||||
// Query user by email/username
|
||||
var user model.User
|
||||
var user entity.User
|
||||
err := dao.DB.Where("email = ?", username).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
@ -590,64 +591,64 @@ func (s *Service) DeleteUser(username string) (*DeleteUserResult, error) {
|
||||
for i, d := range docIDs {
|
||||
docIDList[i] = d["id"]
|
||||
}
|
||||
if delErr := tx.Unscoped().Where("doc_id IN ?", docIDList).Delete(&model.Task{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("doc_id IN ?", docIDList).Delete(&entity.Task{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete tasks", zap.Error(delErr.Error))
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Delete documents
|
||||
if delErr := tx.Unscoped().Where("kb_id IN ?", kbIDs).Delete(&model.Document{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("kb_id IN ?", kbIDs).Delete(&entity.Document{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete documents", zap.Error(delErr.Error))
|
||||
}
|
||||
|
||||
// 5. Delete knowledge bases
|
||||
if delErr := tx.Unscoped().Where("id IN ?", kbIDs).Delete(&model.Knowledgebase{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("id IN ?", kbIDs).Delete(&entity.Knowledgebase{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete knowledge bases", zap.Error(delErr.Error))
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Delete files
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&model.File{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&entity.File{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete files", zap.Error(delErr.Error))
|
||||
}
|
||||
|
||||
// 7. Delete user canvas (agents)
|
||||
if delErr := tx.Unscoped().Where("user_id = ?", ownedTenantID).Delete(&model.UserCanvas{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("user_id = ?", ownedTenantID).Delete(&entity.UserCanvas{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete user canvas", zap.Error(delErr.Error))
|
||||
}
|
||||
|
||||
// 8. Get dialog IDs
|
||||
var dialogIDs []string
|
||||
if pluckErr := tx.Model(&model.Chat{}).Where("tenant_id = ?", ownedTenantID).Pluck("id", &dialogIDs); pluckErr.Error != nil {
|
||||
if pluckErr := tx.Model(&entity.Chat{}).Where("tenant_id = ?", ownedTenantID).Pluck("id", &dialogIDs); pluckErr.Error != nil {
|
||||
logger.Warn("failed to get dialog IDs", zap.Error(pluckErr.Error))
|
||||
}
|
||||
|
||||
// 9. Delete chat sessions
|
||||
if len(dialogIDs) > 0 {
|
||||
if delErr := tx.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&model.ChatSession{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&entity.ChatSession{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete chat sessions", zap.Error(delErr.Error))
|
||||
}
|
||||
}
|
||||
|
||||
// 10. Delete chats/dialogs
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&model.Chat{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&entity.Chat{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete chats", zap.Error(delErr.Error))
|
||||
}
|
||||
|
||||
// 11. Delete API tokens
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&model.APIToken{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&entity.APIToken{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete API tokens", zap.Error(delErr.Error))
|
||||
}
|
||||
|
||||
// 12. Delete API4Conversations
|
||||
if len(dialogIDs) > 0 {
|
||||
if delErr := tx.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&model.API4Conversation{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&entity.API4Conversation{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete API4Conversations", zap.Error(delErr.Error))
|
||||
}
|
||||
}
|
||||
|
||||
var tenantLLMCount int64
|
||||
tx.Model(&model.TenantLLM{}).Where("tenant_id = ?", ownedTenantID).Count(&tenantLLMCount)
|
||||
tx.Model(&entity.TenantLLM{}).Where("tenant_id = ?", ownedTenantID).Count(&tenantLLMCount)
|
||||
result.TenantLLMCount = int(tenantLLMCount)
|
||||
result.DeletedDetails = append(result.DeletedDetails, fmt.Sprintf("- Deleted %d tenant-LLM records.", tenantLLMCount))
|
||||
|
||||
@ -659,33 +660,33 @@ func (s *Service) DeleteUser(username string) (*DeleteUserResult, error) {
|
||||
result.DeletedDetails = append(result.DeletedDetails, fmt.Sprintf("- Deleted metadata table %s.", metadataTableName))
|
||||
|
||||
// 13. Delete tenant LLM configurations
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&model.TenantLLM{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("tenant_id = ?", ownedTenantID).Delete(&entity.TenantLLM{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete tenant LLM", zap.Error(delErr.Error))
|
||||
}
|
||||
|
||||
var tenantCount int64
|
||||
tx.Model(&model.Tenant{}).Where("id = ?", ownedTenantID).Count(&tenantCount)
|
||||
tx.Model(&entity.Tenant{}).Where("id = ?", ownedTenantID).Count(&tenantCount)
|
||||
result.TenantCount = int(tenantCount)
|
||||
// 14. Delete tenant
|
||||
if delErr := tx.Unscoped().Where("id = ?", ownedTenantID).Delete(&model.Tenant{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("id = ?", ownedTenantID).Delete(&entity.Tenant{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete tenant", zap.Error(delErr.Error))
|
||||
}
|
||||
result.DeletedDetails = append(result.DeletedDetails, fmt.Sprintf("- Deleted %d tenant.", result.TenantCount))
|
||||
}
|
||||
|
||||
var userTenantCount int64
|
||||
tx.Model(&model.UserTenant{}).Where("user_id = ?", user.ID).Count(&userTenantCount)
|
||||
tx.Model(&entity.UserTenant{}).Where("user_id = ?", user.ID).Count(&userTenantCount)
|
||||
result.UserTenantCount = int(userTenantCount)
|
||||
|
||||
// 15. Delete user-tenant relations
|
||||
if delErr := tx.Unscoped().Where("user_id = ?", user.ID).Delete(&model.UserTenant{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("user_id = ?", user.ID).Delete(&entity.UserTenant{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete user-tenant relations", zap.Error(delErr.Error))
|
||||
}
|
||||
result.DeletedDetails = append(result.DeletedDetails, fmt.Sprintf("- Deleted %d user-tenant records.", result.UserTenantCount))
|
||||
|
||||
result.UserCount = 1
|
||||
// 16. Finally, hard delete user
|
||||
if delErr := tx.Unscoped().Where("id = ?", user.ID).Delete(&model.User{}); delErr.Error != nil {
|
||||
if delErr := tx.Unscoped().Where("id = ?", user.ID).Delete(&entity.User{}); delErr.Error != nil {
|
||||
rollbackTx()
|
||||
return nil, fmt.Errorf("failed to delete user: %w", delErr.Error)
|
||||
}
|
||||
@ -931,7 +932,7 @@ func (s *Service) GenerateUserAPIToken(username string) (map[string]interface{},
|
||||
now := time.Now()
|
||||
nowUnix := now.Unix()
|
||||
|
||||
apiToken := &model.APIToken{
|
||||
apiToken := &entity.APIToken{
|
||||
TenantID: tenantID,
|
||||
Token: key,
|
||||
Beta: &beta,
|
||||
@ -1539,7 +1540,7 @@ func (s *Service) SetVariable(varName, varValue string) error {
|
||||
dataType = "boolean"
|
||||
}
|
||||
|
||||
newSetting := &model.SystemSettings{
|
||||
newSetting := &entity.SystemSettings{
|
||||
Name: varName,
|
||||
Value: varValue,
|
||||
Source: "admin",
|
||||
@ -1692,7 +1693,7 @@ func (s *Service) InitDefaultAdmin() error {
|
||||
defaultPassword := "admin"
|
||||
|
||||
// Query superusers
|
||||
var users []*model.User
|
||||
var users []*entity.User
|
||||
err := dao.DB.Where("is_superuser = ? AND status = ?", true, "1").Find(&users).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query superusers: %w", err)
|
||||
@ -1715,7 +1716,7 @@ func (s *Service) InitDefaultAdmin() error {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
user := &model.User{
|
||||
user := &entity.User{
|
||||
ID: userID,
|
||||
Email: defaultEmail,
|
||||
Nickname: defaultNickname,
|
||||
@ -1727,7 +1728,7 @@ func (s *Service) InitDefaultAdmin() error {
|
||||
IsAnonymous: "0",
|
||||
LoginChannel: &loginChannel,
|
||||
IsSuperuser: &isSuperuser,
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
@ -1756,7 +1757,7 @@ func (s *Service) InitDefaultAdmin() error {
|
||||
if user.Email == defaultEmail {
|
||||
// Check if tenant exists
|
||||
var count int64
|
||||
dao.DB.Model(&model.UserTenant{}).Where("user_id = ? AND status = ?", user.ID, "1").Count(&count)
|
||||
dao.DB.Model(&entity.UserTenant{}).Where("user_id = ? AND status = ?", user.ID, "1").Count(&count)
|
||||
if count == 0 {
|
||||
nickname := defaultNickname
|
||||
if user.Nickname != "" {
|
||||
@ -1781,10 +1782,10 @@ func (s *Service) addTenantForAdmin(userID, nickname string) error {
|
||||
role := "owner"
|
||||
tenantName := nickname + "'s Kingdom"
|
||||
|
||||
tenant := &model.Tenant{
|
||||
tenant := &entity.Tenant{
|
||||
ID: userID,
|
||||
Name: &tenantName,
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
@ -1796,13 +1797,13 @@ func (s *Service) addTenantForAdmin(userID, nickname string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
userTenant := &model.UserTenant{
|
||||
userTenant := &entity.UserTenant{
|
||||
TenantID: userID,
|
||||
UserID: userID,
|
||||
InvitedBy: userID,
|
||||
Role: role,
|
||||
Status: &status,
|
||||
BaseModel: model.BaseModel{
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// APITokenDAO API token data access object
|
||||
@ -29,26 +29,26 @@ func NewAPITokenDAO() *APITokenDAO {
|
||||
}
|
||||
|
||||
// Create creates a new API token
|
||||
func (dao *APITokenDAO) Create(apiToken *model.APIToken) error {
|
||||
func (dao *APITokenDAO) Create(apiToken *entity.APIToken) error {
|
||||
return DB.Create(apiToken).Error
|
||||
}
|
||||
|
||||
// GetByTenantID gets API tokens by tenant ID
|
||||
func (dao *APITokenDAO) GetByTenantID(tenantID string) ([]*model.APIToken, error) {
|
||||
var tokens []*model.APIToken
|
||||
func (dao *APITokenDAO) GetByTenantID(tenantID string) ([]*entity.APIToken, error) {
|
||||
var tokens []*entity.APIToken
|
||||
err := DB.Where("tenant_id = ?", tenantID).Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
// DeleteByTenantID deletes all API tokens by tenant ID (hard delete)
|
||||
func (dao *APITokenDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.APIToken{})
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.APIToken{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetByToken gets API token by access key
|
||||
func (dao *APITokenDAO) GetUserByAPIToken(token string) (*model.APIToken, error) {
|
||||
var apiToken model.APIToken
|
||||
func (dao *APITokenDAO) GetUserByAPIToken(token string) (*entity.APIToken, error) {
|
||||
var apiToken entity.APIToken
|
||||
err := DB.Where("token = ?", token).First(&apiToken).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -61,13 +61,13 @@ func (dao *APITokenDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error) {
|
||||
if len(dialogIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&model.APIToken{})
|
||||
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&entity.APIToken{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteByTenantIDAndToken deletes a specific API token by tenant ID and token value
|
||||
func (dao *APITokenDAO) DeleteByTenantIDAndToken(tenantID, token string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ? AND token = ?", tenantID, token).Delete(&model.APIToken{})
|
||||
result := DB.Unscoped().Where("tenant_id = ? AND token = ?", tenantID, token).Delete(&entity.APIToken{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@ -84,6 +84,6 @@ func (dao *API4ConversationDAO) DeleteByDialogIDs(dialogIDs []string) (int64, er
|
||||
if len(dialogIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&model.API4Conversation{})
|
||||
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&entity.API4Conversation{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@ -18,9 +18,8 @@ package dao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
// ChatDAO chat data access object
|
||||
@ -32,10 +31,10 @@ func NewChatDAO() *ChatDAO {
|
||||
}
|
||||
|
||||
// ListByTenantID list chats by tenant ID
|
||||
func (dao *ChatDAO) ListByTenantID(tenantID string, status string) ([]*model.Chat, error) {
|
||||
var chats []*model.Chat
|
||||
func (dao *ChatDAO) ListByTenantID(tenantID string, status string) ([]*entity.Chat, error) {
|
||||
var chats []*entity.Chat
|
||||
|
||||
query := DB.Model(&model.Chat{}).
|
||||
query := DB.Model(&entity.Chat{}).
|
||||
Where("tenant_id = ?", tenantID)
|
||||
|
||||
if status != "" {
|
||||
@ -51,12 +50,12 @@ func (dao *ChatDAO) ListByTenantID(tenantID string, status string) ([]*model.Cha
|
||||
}
|
||||
|
||||
// ListByTenantIDs list chats by tenant IDs with pagination and filtering
|
||||
func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*model.Chat, int64, error) {
|
||||
var chats []*model.Chat
|
||||
func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*entity.Chat, int64, error) {
|
||||
var chats []*entity.Chat
|
||||
var total int64
|
||||
|
||||
// Build query with join to user table for nickname and avatar
|
||||
query := DB.Model(&model.Chat{}).
|
||||
query := DB.Model(&entity.Chat{}).
|
||||
Select(`
|
||||
dialog.*,
|
||||
user.nickname,
|
||||
@ -103,11 +102,11 @@ func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pag
|
||||
}
|
||||
|
||||
// ListByOwnerIDs list chats by owner IDs with filtering (manual pagination)
|
||||
func (dao *ChatDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby string, desc bool, keywords string) ([]*model.Chat, int64, error) {
|
||||
var chats []*model.Chat
|
||||
func (dao *ChatDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby string, desc bool, keywords string) ([]*entity.Chat, int64, error) {
|
||||
var chats []*entity.Chat
|
||||
|
||||
// Build query with join to user table
|
||||
query := DB.Model(&model.Chat{}).
|
||||
query := DB.Model(&entity.Chat{}).
|
||||
Select(`
|
||||
dialog.*,
|
||||
user.nickname,
|
||||
@ -142,8 +141,8 @@ func (dao *ChatDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby str
|
||||
}
|
||||
|
||||
// GetByID gets chat by ID
|
||||
func (dao *ChatDAO) GetByID(id string) (*model.Chat, error) {
|
||||
var chat model.Chat
|
||||
func (dao *ChatDAO) GetByID(id string) (*entity.Chat, error) {
|
||||
var chat entity.Chat
|
||||
err := DB.Where("id = ?", id).First(&chat).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -152,8 +151,8 @@ func (dao *ChatDAO) GetByID(id string) (*model.Chat, error) {
|
||||
}
|
||||
|
||||
// GetByIDAndStatus gets chat by ID and status
|
||||
func (dao *ChatDAO) GetByIDAndStatus(id string, status string) (*model.Chat, error) {
|
||||
var chat model.Chat
|
||||
func (dao *ChatDAO) GetByIDAndStatus(id string, status string) (*entity.Chat, error) {
|
||||
var chat entity.Chat
|
||||
err := DB.Where("id = ? AND status = ?", id, status).First(&chat).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -164,20 +163,20 @@ func (dao *ChatDAO) GetByIDAndStatus(id string, status string) (*model.Chat, err
|
||||
// GetExistingNames gets existing dialog names for a tenant
|
||||
func (dao *ChatDAO) GetExistingNames(tenantID string, status string) ([]string, error) {
|
||||
var names []string
|
||||
err := DB.Model(&model.Chat{}).
|
||||
err := DB.Model(&entity.Chat{}).
|
||||
Where("tenant_id = ? AND status = ?", tenantID, status).
|
||||
Pluck("name", &names).Error
|
||||
return names, err
|
||||
}
|
||||
|
||||
// Create creates a new chat/dialog
|
||||
func (dao *ChatDAO) Create(chat *model.Chat) error {
|
||||
func (dao *ChatDAO) Create(chat *entity.Chat) error {
|
||||
return DB.Create(chat).Error
|
||||
}
|
||||
|
||||
// UpdateByID updates a chat by ID
|
||||
func (dao *ChatDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&model.Chat{}).Where("id = ?", id).Updates(updates).Error
|
||||
return DB.Model(&entity.Chat{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateManyByID updates multiple chats by ID (batch update)
|
||||
@ -207,7 +206,7 @@ func (dao *ChatDAO) UpdateManyByID(updates []map[string]interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(&model.Chat{}).Where("id = ?", id).Updates(updatesWithoutID).Error; err != nil {
|
||||
if err := tx.Model(&entity.Chat{}).Where("id = ?", id).Updates(updatesWithoutID).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
@ -218,14 +217,14 @@ func (dao *ChatDAO) UpdateManyByID(updates []map[string]interface{}) error {
|
||||
|
||||
// DeleteByTenantID deletes all chats by tenant ID (hard delete)
|
||||
func (dao *ChatDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.Chat{})
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.Chat{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetAllDialogIDsByTenantID gets all dialog IDs by tenant ID
|
||||
func (dao *ChatDAO) GetAllDialogIDsByTenantID(tenantID string) ([]string, error) {
|
||||
var dialogIDs []string
|
||||
err := DB.Model(&model.Chat{}).
|
||||
err := DB.Model(&entity.Chat{}).
|
||||
Where("tenant_id = ?", tenantID).
|
||||
Pluck("id", &dialogIDs).Error
|
||||
return dialogIDs, err
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// ChatSessionDAO chat session data access object
|
||||
@ -29,8 +29,8 @@ func NewChatSessionDAO() *ChatSessionDAO {
|
||||
}
|
||||
|
||||
// GetByID gets chat session by ID
|
||||
func (dao *ChatSessionDAO) GetByID(id string) (*model.ChatSession, error) {
|
||||
var conv model.ChatSession
|
||||
func (dao *ChatSessionDAO) GetByID(id string) (*entity.ChatSession, error) {
|
||||
var conv entity.ChatSession
|
||||
err := DB.Where("id = ?", id).First(&conv).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -39,23 +39,23 @@ func (dao *ChatSessionDAO) GetByID(id string) (*model.ChatSession, error) {
|
||||
}
|
||||
|
||||
// Create creates a new chat session
|
||||
func (dao *ChatSessionDAO) Create(conv *model.ChatSession) error {
|
||||
func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error {
|
||||
return DB.Create(conv).Error
|
||||
}
|
||||
|
||||
// UpdateByID updates a chat session by ID
|
||||
func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&model.ChatSession{}).Where("id = ?", id).Updates(updates).Error
|
||||
return DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteByID deletes a chat session by ID (hard delete)
|
||||
func (dao *ChatSessionDAO) DeleteByID(id string) error {
|
||||
return DB.Where("id = ?", id).Delete(&model.ChatSession{}).Error
|
||||
return DB.Where("id = ?", id).Delete(&entity.ChatSession{}).Error
|
||||
}
|
||||
|
||||
// ListByDialogID lists chat sessions by dialog ID
|
||||
func (dao *ChatSessionDAO) ListByDialogID(dialogID string) ([]*model.ChatSession, error) {
|
||||
var convs []*model.ChatSession
|
||||
func (dao *ChatSessionDAO) ListByDialogID(dialogID string) ([]*entity.ChatSession, error) {
|
||||
var convs []*entity.ChatSession
|
||||
err := DB.Where("dialog_id = ?", dialogID).
|
||||
Order("create_time DESC").
|
||||
Find(&convs).Error
|
||||
@ -65,7 +65,7 @@ func (dao *ChatSessionDAO) ListByDialogID(dialogID string) ([]*model.ChatSession
|
||||
// CheckDialogExists checks if a dialog exists with given tenant_id and dialog_id
|
||||
func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, error) {
|
||||
var count int64
|
||||
err := DB.Model(&model.Chat{}).
|
||||
err := DB.Model(&entity.Chat{}).
|
||||
Where("tenant_id = ? AND id = ? AND status = ?", tenantID, dialogID, "1").
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
@ -75,8 +75,8 @@ func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, e
|
||||
}
|
||||
|
||||
// GetDialogByID gets dialog by ID
|
||||
func (dao *ChatSessionDAO) GetDialogByID(dialogID string) (*model.Chat, error) {
|
||||
var dialog model.Chat
|
||||
func (dao *ChatSessionDAO) GetDialogByID(dialogID string) (*entity.Chat, error) {
|
||||
var dialog entity.Chat
|
||||
err := DB.Where("id = ? AND status = ?", dialogID, "1").First(&dialog).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -89,6 +89,6 @@ func (dao *ChatSessionDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error)
|
||||
if len(dialogIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&model.ChatSession{})
|
||||
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&entity.ChatSession{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// ConnectorDAO connector data access object
|
||||
@ -41,7 +41,7 @@ type ConnectorListItem struct {
|
||||
func (dao *ConnectorDAO) ListByTenantID(tenantID string) ([]*ConnectorListItem, error) {
|
||||
var connectors []*ConnectorListItem
|
||||
|
||||
err := DB.Model(&model.Connector{}).
|
||||
err := DB.Model(&entity.Connector{}).
|
||||
Select("id", "name", "source", "status").
|
||||
Where("tenant_id = ?", tenantID).
|
||||
Find(&connectors).Error
|
||||
@ -54,8 +54,8 @@ func (dao *ConnectorDAO) ListByTenantID(tenantID string) ([]*ConnectorListItem,
|
||||
}
|
||||
|
||||
// GetByID get connector by ID
|
||||
func (dao *ConnectorDAO) GetByID(id string) (*model.Connector, error) {
|
||||
var connector model.Connector
|
||||
func (dao *ConnectorDAO) GetByID(id string) (*entity.Connector, error) {
|
||||
var connector entity.Connector
|
||||
err := DB.Where("id = ?", id).First(&connector).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -64,16 +64,16 @@ func (dao *ConnectorDAO) GetByID(id string) (*model.Connector, error) {
|
||||
}
|
||||
|
||||
// Create create a new connector
|
||||
func (dao *ConnectorDAO) Create(connector *model.Connector) error {
|
||||
func (dao *ConnectorDAO) Create(connector *entity.Connector) error {
|
||||
return DB.Create(connector).Error
|
||||
}
|
||||
|
||||
// UpdateByID update connector by ID
|
||||
func (dao *ConnectorDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&model.Connector{}).Where("id = ?", id).Updates(updates).Error
|
||||
return DB.Model(&entity.Connector{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteByID delete connector by ID
|
||||
func (dao *ConnectorDAO) DeleteByID(id string) error {
|
||||
return DB.Where("id = ?", id).Delete(&model.Connector{}).Error
|
||||
return DB.Where("id = ?", id).Delete(&entity.Connector{}).Error
|
||||
}
|
||||
|
||||
@ -22,11 +22,12 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/logger"
|
||||
"ragflow/internal/model"
|
||||
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/utility"
|
||||
|
||||
@ -111,40 +112,40 @@ func InitDB() error {
|
||||
|
||||
// Auto migrate all models
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Tenant{},
|
||||
&model.UserTenant{},
|
||||
&model.File{},
|
||||
&model.File2Document{},
|
||||
&model.TenantLLM{},
|
||||
&model.Chat{},
|
||||
&model.ChatSession{},
|
||||
&model.Task{},
|
||||
&model.APIToken{},
|
||||
&model.API4Conversation{},
|
||||
&model.Knowledgebase{},
|
||||
&model.InvitationCode{},
|
||||
&model.Document{},
|
||||
&model.UserCanvas{},
|
||||
&model.CanvasTemplate{},
|
||||
&model.UserCanvasVersion{},
|
||||
&model.LLMFactories{},
|
||||
&model.LLM{},
|
||||
&model.TenantLangfuse{},
|
||||
&model.SystemSettings{},
|
||||
&model.Connector{},
|
||||
&model.Connector2Kb{},
|
||||
&model.SyncLogs{},
|
||||
&model.MCPServer{},
|
||||
&model.Memory{},
|
||||
&model.Search{},
|
||||
&model.PipelineOperationLog{},
|
||||
&model.EvaluationDataset{},
|
||||
&model.EvaluationCase{},
|
||||
&model.EvaluationRun{},
|
||||
&model.EvaluationResult{},
|
||||
&model.TimeRecord{},
|
||||
&model.License{},
|
||||
&entity.User{},
|
||||
&entity.Tenant{},
|
||||
&entity.UserTenant{},
|
||||
&entity.File{},
|
||||
&entity.File2Document{},
|
||||
&entity.TenantLLM{},
|
||||
&entity.Chat{},
|
||||
&entity.ChatSession{},
|
||||
&entity.Task{},
|
||||
&entity.APIToken{},
|
||||
&entity.API4Conversation{},
|
||||
&entity.Knowledgebase{},
|
||||
&entity.InvitationCode{},
|
||||
&entity.Document{},
|
||||
&entity.UserCanvas{},
|
||||
&entity.CanvasTemplate{},
|
||||
&entity.UserCanvasVersion{},
|
||||
&entity.LLMFactories{},
|
||||
&entity.LLM{},
|
||||
&entity.TenantLangfuse{},
|
||||
&entity.SystemSettings{},
|
||||
&entity.Connector{},
|
||||
&entity.Connector2Kb{},
|
||||
&entity.SyncLogs{},
|
||||
&entity.MCPServer{},
|
||||
&entity.Memory{},
|
||||
&entity.Search{},
|
||||
&entity.PipelineOperationLog{},
|
||||
&entity.EvaluationDataset{},
|
||||
&entity.EvaluationCase{},
|
||||
&entity.EvaluationRun{},
|
||||
&entity.EvaluationResult{},
|
||||
&entity.TimeRecord{},
|
||||
&entity.License{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
@ -223,7 +224,7 @@ func InitLLMFactory() error {
|
||||
status = "1"
|
||||
}
|
||||
|
||||
llmFactory := &model.LLMFactories{
|
||||
llmFactory := &entity.LLMFactories{
|
||||
Name: factory.Name,
|
||||
Logo: utility.StringPtr(factory.Logo),
|
||||
Tags: factory.Tags,
|
||||
@ -231,7 +232,7 @@ func InitLLMFactory() error {
|
||||
Status: &status,
|
||||
}
|
||||
|
||||
var existingFactory model.LLMFactories
|
||||
var existingFactory entity.LLMFactories
|
||||
result := db.Where("name = ?", factory.Name).First(&existingFactory)
|
||||
if result.Error != nil {
|
||||
if err := db.Create(llmFactory).Error; err != nil {
|
||||
@ -239,7 +240,7 @@ func InitLLMFactory() error {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{
|
||||
if err := db.Model(&entity.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{
|
||||
"logo": llmFactory.Logo,
|
||||
"tags": llmFactory.Tags,
|
||||
"rank": llmFactory.Rank,
|
||||
@ -251,7 +252,7 @@ func InitLLMFactory() error {
|
||||
|
||||
for _, llm := range factory.LLM {
|
||||
llmStatus := "1"
|
||||
llmModel := &model.LLM{
|
||||
llmModel := &entity.LLM{
|
||||
LLMName: llm.LLMName,
|
||||
ModelType: llm.ModelType,
|
||||
FID: factory.Name,
|
||||
@ -261,14 +262,14 @@ func InitLLMFactory() error {
|
||||
Status: &llmStatus,
|
||||
}
|
||||
|
||||
var existingLLM model.LLM
|
||||
var existingLLM entity.LLM
|
||||
result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM)
|
||||
if result.Error != nil {
|
||||
if err := db.Create(llmModel).Error; err != nil {
|
||||
log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err)
|
||||
}
|
||||
} else {
|
||||
if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{
|
||||
if err := db.Model(&entity.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{
|
||||
"model_type": llmModel.ModelType,
|
||||
"max_tokens": llmModel.MaxTokens,
|
||||
"tags": llmModel.Tags,
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// DocumentDAO document data access object
|
||||
@ -29,13 +29,13 @@ func NewDocumentDAO() *DocumentDAO {
|
||||
}
|
||||
|
||||
// Create create document
|
||||
func (dao *DocumentDAO) Create(document *model.Document) error {
|
||||
func (dao *DocumentDAO) Create(document *entity.Document) error {
|
||||
return DB.Create(document).Error
|
||||
}
|
||||
|
||||
// GetByID get document by ID
|
||||
func (dao *DocumentDAO) GetByID(id string) (*model.Document, error) {
|
||||
var document model.Document
|
||||
func (dao *DocumentDAO) GetByID(id string) (*entity.Document, error) {
|
||||
var document entity.Document
|
||||
err := DB.First(&document, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -44,11 +44,11 @@ func (dao *DocumentDAO) GetByID(id string) (*model.Document, error) {
|
||||
}
|
||||
|
||||
// GetByAuthorID get documents by author ID
|
||||
func (dao *DocumentDAO) GetByAuthorID(authorID string, offset, limit int) ([]*model.Document, int64, error) {
|
||||
var documents []*model.Document
|
||||
func (dao *DocumentDAO) GetByAuthorID(authorID string, offset, limit int) ([]*entity.Document, int64, error) {
|
||||
var documents []*entity.Document
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&model.Document{}).Where("created_by = ?", authorID)
|
||||
query := DB.Model(&entity.Document{}).Where("created_by = ?", authorID)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -58,21 +58,21 @@ func (dao *DocumentDAO) GetByAuthorID(authorID string, offset, limit int) ([]*mo
|
||||
}
|
||||
|
||||
// Update update document
|
||||
func (dao *DocumentDAO) Update(document *model.Document) error {
|
||||
func (dao *DocumentDAO) Update(document *entity.Document) error {
|
||||
return DB.Save(document).Error
|
||||
}
|
||||
|
||||
// Delete delete document
|
||||
func (dao *DocumentDAO) Delete(id string) error {
|
||||
return DB.Delete(&model.Document{}, "id = ?", id).Error
|
||||
return DB.Delete(&entity.Document{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// List list documents
|
||||
func (dao *DocumentDAO) List(offset, limit int) ([]*model.Document, int64, error) {
|
||||
var documents []*model.Document
|
||||
func (dao *DocumentDAO) List(offset, limit int) ([]*entity.Document, int64, error) {
|
||||
var documents []*entity.Document
|
||||
var total int64
|
||||
|
||||
if err := DB.Model(&model.Document{}).Count(&total).Error; err != nil {
|
||||
if err := DB.Model(&entity.Document{}).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@ -81,11 +81,11 @@ func (dao *DocumentDAO) List(offset, limit int) ([]*model.Document, int64, error
|
||||
}
|
||||
|
||||
// ListByKBID list documents by knowledge base ID
|
||||
func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*model.Document, int64, error) {
|
||||
var documents []*model.Document
|
||||
func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*entity.Document, int64, error) {
|
||||
var documents []*entity.Document
|
||||
var total int64
|
||||
|
||||
if err := DB.Model(&model.Document{}).Where("kb_id = ?", kbID).Count(&total).Error; err != nil {
|
||||
if err := DB.Model(&entity.Document{}).Where("kb_id = ?", kbID).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@ -95,21 +95,21 @@ func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*model.Doc
|
||||
|
||||
// DeleteByTenantID deletes all documents by tenant ID (hard delete)
|
||||
func (dao *DocumentDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.Document{})
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.Document{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetAllDocIDsByKBIDs gets all document IDs by knowledge base IDs
|
||||
func (dao *DocumentDAO) GetAllDocIDsByKBIDs(kbIDs []string) ([]map[string]string, error) {
|
||||
var docs []struct {
|
||||
ID string `gorm:"column:id"`
|
||||
ID string `gorm:"column:id"`
|
||||
KbID string `gorm:"column:kb_id"`
|
||||
}
|
||||
err := DB.Model(&model.Document{}).Select("id, kb_id").Where("kb_id IN ?", kbIDs).Find(&docs).Error
|
||||
err := DB.Model(&entity.Document{}).Select("id, kb_id").Where("kb_id IN ?", kbIDs).Find(&docs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
result := make([]map[string]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
result[i] = map[string]string{"id": doc.ID, "kb_id": doc.KbID}
|
||||
|
||||
@ -17,11 +17,10 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
// FileDAO file data access object
|
||||
@ -33,8 +32,8 @@ func NewFileDAO() *FileDAO {
|
||||
}
|
||||
|
||||
// GetByID gets file by ID
|
||||
func (dao *FileDAO) GetByID(id string) (*model.File, error) {
|
||||
var file model.File
|
||||
func (dao *FileDAO) GetByID(id string) (*entity.File, error) {
|
||||
var file entity.File
|
||||
err := DB.Where("id = ?", id).First(&file).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -43,11 +42,11 @@ func (dao *FileDAO) GetByID(id string) (*model.File, error) {
|
||||
}
|
||||
|
||||
// GetByPfID gets files by parent folder ID with pagination and filtering
|
||||
func (dao *FileDAO) GetByPfID(tenantID, pfID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*model.File, int64, error) {
|
||||
var files []*model.File
|
||||
func (dao *FileDAO) GetByPfID(tenantID, pfID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*entity.File, int64, error) {
|
||||
var files []*entity.File
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&model.File{}).
|
||||
query := DB.Model(&entity.File{}).
|
||||
Where("tenant_id = ? AND parent_id = ? AND id != ?", tenantID, pfID, pfID)
|
||||
|
||||
// Apply keyword filter
|
||||
@ -83,8 +82,8 @@ func (dao *FileDAO) GetByPfID(tenantID, pfID string, page, pageSize int, orderby
|
||||
}
|
||||
|
||||
// GetRootFolder gets or creates root folder for tenant
|
||||
func (dao *FileDAO) GetRootFolder(tenantID string) (*model.File, error) {
|
||||
var file model.File
|
||||
func (dao *FileDAO) GetRootFolder(tenantID string) (*entity.File, error) {
|
||||
var file entity.File
|
||||
err := DB.Where("tenant_id = ? AND parent_id = id", tenantID).First(&file).Error
|
||||
if err == nil {
|
||||
return &file, nil
|
||||
@ -92,7 +91,7 @@ func (dao *FileDAO) GetRootFolder(tenantID string) (*model.File, error) {
|
||||
|
||||
// Create root folder if not exists
|
||||
fileID := generateUUID()
|
||||
file = model.File{
|
||||
file = entity.File{
|
||||
ID: fileID,
|
||||
ParentID: fileID,
|
||||
TenantID: tenantID,
|
||||
@ -110,14 +109,14 @@ func (dao *FileDAO) GetRootFolder(tenantID string) (*model.File, error) {
|
||||
}
|
||||
|
||||
// GetParentFolder gets parent folder of a file
|
||||
func (dao *FileDAO) GetParentFolder(fileID string) (*model.File, error) {
|
||||
var file model.File
|
||||
func (dao *FileDAO) GetParentFolder(fileID string) (*entity.File, error) {
|
||||
var file entity.File
|
||||
err := DB.Where("id = ?", fileID).First(&file).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var parentFile model.File
|
||||
var parentFile entity.File
|
||||
err = DB.Where("id = ?", file.ParentID).First(&parentFile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -126,8 +125,8 @@ func (dao *FileDAO) GetParentFolder(fileID string) (*model.File, error) {
|
||||
}
|
||||
|
||||
// ListByParentID lists all files by parent ID (including subfolders)
|
||||
func (dao *FileDAO) ListByParentID(parentID string) ([]*model.File, error) {
|
||||
var files []*model.File
|
||||
func (dao *FileDAO) ListByParentID(parentID string) ([]*entity.File, error) {
|
||||
var files []*entity.File
|
||||
err := DB.Where("parent_id = ? AND id != ?", parentID, parentID).Find(&files).Error
|
||||
return files, err
|
||||
}
|
||||
@ -138,7 +137,7 @@ func (dao *FileDAO) GetFolderSize(folderID string) (int64, error) {
|
||||
|
||||
var dfs func(parentID string) error
|
||||
dfs = func(parentID string) error {
|
||||
var files []*model.File
|
||||
var files []*entity.File
|
||||
if err := DB.Select("id", "size", "type").
|
||||
Where("parent_id = ? AND id != ?", parentID, parentID).
|
||||
Find(&files).Error; err != nil {
|
||||
@ -165,19 +164,19 @@ func (dao *FileDAO) GetFolderSize(folderID string) (int64, error) {
|
||||
// HasChildFolder checks if folder has child folders
|
||||
func (dao *FileDAO) HasChildFolder(folderID string) (bool, error) {
|
||||
var count int64
|
||||
err := DB.Model(&model.File{}).
|
||||
err := DB.Model(&entity.File{}).
|
||||
Where("parent_id = ? AND id != ? AND type = ?", folderID, folderID, "folder").
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GetAllParentFolders gets all parent folders in path (from current to root)
|
||||
func (dao *FileDAO) GetAllParentFolders(startID string) ([]*model.File, error) {
|
||||
var parentFolders []*model.File
|
||||
func (dao *FileDAO) GetAllParentFolders(startID string) ([]*entity.File, error) {
|
||||
var parentFolders []*entity.File
|
||||
currentID := startID
|
||||
|
||||
for currentID != "" {
|
||||
var file model.File
|
||||
var file entity.File
|
||||
err := DB.Where("id = ?", currentID).First(&file).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -196,13 +195,13 @@ func (dao *FileDAO) GetAllParentFolders(startID string) ([]*model.File, error) {
|
||||
}
|
||||
|
||||
// Create creates a new file
|
||||
func (dao *FileDAO) Create(file *model.File) error {
|
||||
func (dao *FileDAO) Create(file *entity.File) error {
|
||||
return DB.Create(file).Error
|
||||
}
|
||||
|
||||
// DeleteByTenantID deletes all files by tenant ID (hard delete)
|
||||
func (dao *FileDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.File{})
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.File{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@ -211,14 +210,14 @@ func (dao *FileDAO) DeleteByIDs(ids []string) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := DB.Unscoped().Where("id IN ?", ids).Delete(&model.File{})
|
||||
result := DB.Unscoped().Where("id IN ?", ids).Delete(&entity.File{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetAllIDsByTenantID gets all file IDs by tenant ID
|
||||
func (dao *FileDAO) GetAllIDsByTenantID(tenantID string) ([]string, error) {
|
||||
var ids []string
|
||||
err := DB.Model(&model.File{}).Where("tenant_id = ?", tenantID).Pluck("id", &ids).Error
|
||||
err := DB.Model(&entity.File{}).Where("tenant_id = ?", tenantID).Pluck("id", &ids).Error
|
||||
return ids, err
|
||||
}
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// File2DocumentDAO file to document mapping data access object
|
||||
@ -32,7 +32,7 @@ func NewFile2DocumentDAO() *File2DocumentDAO {
|
||||
func (dao *File2DocumentDAO) GetKBInfoByFileID(fileID string) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
rows, err := DB.Model(&model.File{}).
|
||||
rows, err := DB.Model(&entity.File{}).
|
||||
Select("knowledgebase.id, knowledgebase.name, file2document.document_id").
|
||||
Joins("JOIN file2document ON file2document.file_id = ?", fileID).
|
||||
Joins("JOIN document ON document.id = file2document.document_id").
|
||||
|
||||
@ -18,7 +18,8 @@ package dao
|
||||
|
||||
import (
|
||||
"path"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -33,29 +34,29 @@ func NewKnowledgebaseDAO() *KnowledgebaseDAO {
|
||||
}
|
||||
|
||||
// Create creates a new knowledge base record
|
||||
func (dao *KnowledgebaseDAO) Create(kb *model.Knowledgebase) error {
|
||||
func (dao *KnowledgebaseDAO) Create(kb *entity.Knowledgebase) error {
|
||||
return DB.Create(kb).Error
|
||||
}
|
||||
|
||||
// Update updates a knowledge base record
|
||||
func (dao *KnowledgebaseDAO) Update(kb *model.Knowledgebase) error {
|
||||
func (dao *KnowledgebaseDAO) Update(kb *entity.Knowledgebase) error {
|
||||
return DB.Save(kb).Error
|
||||
}
|
||||
|
||||
// UpdateByID updates a knowledge base by ID with the given fields
|
||||
func (dao *KnowledgebaseDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Updates(updates).Error
|
||||
return DB.Model(&entity.Knowledgebase{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Delete soft deletes a knowledge base by setting status to invalid
|
||||
func (dao *KnowledgebaseDAO) Delete(id string) error {
|
||||
return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Update("status", string(model.StatusInvalid)).Error
|
||||
return DB.Model(&entity.Knowledgebase{}).Where("id = ?", id).Update("status", string(entity.StatusInvalid)).Error
|
||||
}
|
||||
|
||||
// GetByID retrieves a knowledge base by ID
|
||||
func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error
|
||||
func (dao *KnowledgebaseDAO) GetByID(id string) (*entity.Knowledgebase, error) {
|
||||
var kb entity.Knowledgebase
|
||||
err := DB.Where("id = ? AND status = ?", id, string(entity.StatusValid)).First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -63,9 +64,9 @@ func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) {
|
||||
}
|
||||
|
||||
// GetByIDAndTenantID retrieves a knowledge base by ID and tenant ID
|
||||
func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, string(model.StatusValid)).First(&kb).Error
|
||||
func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*entity.Knowledgebase, error) {
|
||||
var kb entity.Knowledgebase
|
||||
err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, string(entity.StatusValid)).First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -73,16 +74,16 @@ func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Kno
|
||||
}
|
||||
|
||||
// GetByIDs retrieves multiple knowledge bases by IDs
|
||||
func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Where("id IN ? AND status = ?", ids, string(model.StatusValid)).Find(&kbs).Error
|
||||
func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*entity.Knowledgebase, error) {
|
||||
var kbs []*entity.Knowledgebase
|
||||
err := DB.Where("id IN ? AND status = ?", ids, string(entity.StatusValid)).Find(&kbs).Error
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// GetByName retrieves a knowledge base by name and tenant ID
|
||||
func (dao *KnowledgebaseDAO) GetByName(name, tenantID string) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
err := DB.Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, string(model.StatusValid)).First(&kb).Error
|
||||
func (dao *KnowledgebaseDAO) GetByName(name, tenantID string) (*entity.Knowledgebase, error) {
|
||||
var kb entity.Knowledgebase
|
||||
err := DB.Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, string(entity.StatusValid)).First(&kb).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -90,16 +91,16 @@ func (dao *KnowledgebaseDAO) GetByName(name, tenantID string) (*model.Knowledgeb
|
||||
}
|
||||
|
||||
// GetByCreatedBy retrieves knowledge bases created by a specific user
|
||||
func (dao *KnowledgebaseDAO) GetByCreatedBy(createdBy string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Where("created_by = ? AND status = ?", createdBy, string(model.StatusValid)).Find(&kbs).Error
|
||||
func (dao *KnowledgebaseDAO) GetByCreatedBy(createdBy string) ([]*entity.Knowledgebase, error) {
|
||||
var kbs []*entity.Knowledgebase
|
||||
err := DB.Where("created_by = ? AND status = ?", createdBy, string(entity.StatusValid)).Find(&kbs).Error
|
||||
return kbs, err
|
||||
}
|
||||
|
||||
// Query retrieves knowledge bases with filters
|
||||
func (dao *KnowledgebaseDAO) Query(filters map[string]interface{}) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
query := DB.Where("status = ?", string(model.StatusValid))
|
||||
func (dao *KnowledgebaseDAO) Query(filters map[string]interface{}) ([]*entity.Knowledgebase, error) {
|
||||
var kbs []*entity.Knowledgebase
|
||||
query := DB.Where("status = ?", string(entity.StatusValid))
|
||||
|
||||
for key, value := range filters {
|
||||
if value != nil && value != "" {
|
||||
@ -112,9 +113,9 @@ func (dao *KnowledgebaseDAO) Query(filters map[string]interface{}) ([]*model.Kno
|
||||
}
|
||||
|
||||
// QueryOne retrieves a single knowledge base with filters
|
||||
func (dao *KnowledgebaseDAO) QueryOne(filters map[string]interface{}) (*model.Knowledgebase, error) {
|
||||
var kb model.Knowledgebase
|
||||
query := DB.Where("status = ?", string(model.StatusValid))
|
||||
func (dao *KnowledgebaseDAO) QueryOne(filters map[string]interface{}) (*entity.Knowledgebase, error) {
|
||||
var kb entity.Knowledgebase
|
||||
query := DB.Where("status = ?", string(entity.StatusValid))
|
||||
|
||||
for key, value := range filters {
|
||||
if value != nil && value != "" {
|
||||
@ -132,7 +133,7 @@ func (dao *KnowledgebaseDAO) QueryOne(filters map[string]interface{}) (*model.Kn
|
||||
// Count returns the count of knowledge bases matching the filters
|
||||
func (dao *KnowledgebaseDAO) Count(filters map[string]interface{}) (int64, error) {
|
||||
var count int64
|
||||
query := DB.Model(&model.Knowledgebase{}).Where("status = ?", string(model.StatusValid))
|
||||
query := DB.Model(&entity.Knowledgebase{}).Where("status = ?", string(entity.StatusValid))
|
||||
|
||||
for key, value := range filters {
|
||||
if value != nil && value != "" {
|
||||
@ -146,11 +147,11 @@ func (dao *KnowledgebaseDAO) Count(filters map[string]interface{}) (int64, error
|
||||
|
||||
// GetByTenantIDs retrieves knowledge bases by tenant IDs with pagination
|
||||
// This matches the Python get_by_tenant_ids method
|
||||
func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, keywords, parserID string) ([]*model.KnowledgebaseListItem, int64, error) {
|
||||
var kbs []*model.KnowledgebaseListItem
|
||||
func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, keywords, parserID string) ([]*entity.KnowledgebaseListItem, int64, error) {
|
||||
var kbs []*entity.KnowledgebaseListItem
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&model.Knowledgebase{}).
|
||||
query := DB.Model(&entity.Knowledgebase{}).
|
||||
Select(`knowledgebase.id, knowledgebase.avatar, knowledgebase.name,
|
||||
knowledgebase.language, knowledgebase.description, knowledgebase.tenant_id,
|
||||
knowledgebase.permission, knowledgebase.doc_num, knowledgebase.token_num,
|
||||
@ -158,7 +159,7 @@ func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, p
|
||||
user.nickname, user.avatar as tenant_avatar, knowledgebase.update_time`).
|
||||
Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id").
|
||||
Where("((knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?) AND knowledgebase.status = ?",
|
||||
tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid))
|
||||
tenantIDs, string(entity.TenantPermissionTeam), userID, string(entity.StatusValid))
|
||||
|
||||
if keywords != "" {
|
||||
query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%")
|
||||
@ -194,12 +195,12 @@ func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, p
|
||||
|
||||
// GetAllByTenantIDs retrieves all permitted knowledge bases by tenant IDs
|
||||
// This matches the Python get_all_kb_by_tenant_ids method
|
||||
func (dao *KnowledgebaseDAO) GetAllByTenantIDs(tenantIDs []string, userID string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
func (dao *KnowledgebaseDAO) GetAllByTenantIDs(tenantIDs []string, userID string) ([]*entity.Knowledgebase, error) {
|
||||
var kbs []*entity.Knowledgebase
|
||||
|
||||
err := DB.Where(
|
||||
"(tenant_id IN ? AND permission = ?) OR tenant_id = ?",
|
||||
tenantIDs, string(model.TenantPermissionTeam), userID,
|
||||
tenantIDs, string(entity.TenantPermissionTeam), userID,
|
||||
).Order("create_time ASC").Find(&kbs).Error
|
||||
|
||||
return kbs, err
|
||||
@ -207,8 +208,8 @@ func (dao *KnowledgebaseDAO) GetAllByTenantIDs(tenantIDs []string, userID string
|
||||
|
||||
// GetDetail retrieves detailed knowledge base information with joined pipeline data
|
||||
// This matches the Python get_detail method
|
||||
func (dao *KnowledgebaseDAO) GetDetail(kbID string) (*model.KnowledgebaseDetail, error) {
|
||||
var detail model.KnowledgebaseDetail
|
||||
func (dao *KnowledgebaseDAO) GetDetail(kbID string) (*entity.KnowledgebaseDetail, error) {
|
||||
var detail entity.KnowledgebaseDetail
|
||||
|
||||
err := DB.Table("knowledgebase").
|
||||
Select(`knowledgebase.id, knowledgebase.embd_id, knowledgebase.avatar, knowledgebase.name,
|
||||
@ -222,7 +223,7 @@ func (dao *KnowledgebaseDAO) GetDetail(kbID string) (*model.KnowledgebaseDetail,
|
||||
knowledgebase.mindmap_task_id, knowledgebase.mindmap_task_finish_at,
|
||||
knowledgebase.create_time, knowledgebase.update_time`).
|
||||
Joins("LEFT JOIN user_canvas ON knowledgebase.pipeline_id = user_canvas.id").
|
||||
Where("knowledgebase.id = ? AND knowledgebase.status = ?", kbID, string(model.StatusValid)).
|
||||
Where("knowledgebase.id = ? AND knowledgebase.status = ?", kbID, string(entity.StatusValid)).
|
||||
Scan(&detail).Error
|
||||
|
||||
if err != nil {
|
||||
@ -239,7 +240,7 @@ func (dao *KnowledgebaseDAO) Accessible(kbID, userID string) bool {
|
||||
err := DB.Table("knowledgebase").
|
||||
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
|
||||
Where("knowledgebase.id = ? AND user_tenant.user_id = ? AND knowledgebase.status = ?",
|
||||
kbID, userID, string(model.StatusValid)).
|
||||
kbID, userID, string(entity.StatusValid)).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
@ -252,8 +253,8 @@ func (dao *KnowledgebaseDAO) Accessible(kbID, userID string) bool {
|
||||
// This matches the Python accessible4deletion method
|
||||
func (dao *KnowledgebaseDAO) Accessible4Deletion(kbID, userID string) bool {
|
||||
var count int64
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Where("id = ? AND created_by = ? AND status = ?", kbID, userID, string(model.StatusValid)).
|
||||
err := DB.Model(&entity.Knowledgebase{}).
|
||||
Where("id = ? AND created_by = ? AND status = ?", kbID, userID, string(entity.StatusValid)).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
@ -270,8 +271,8 @@ func (dao *KnowledgebaseDAO) DuplicateName(name, tenantID string) string {
|
||||
currentName := name
|
||||
for retries := 0; retries < maxRetries; retries++ {
|
||||
var count int64
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Where("LOWER(name) = ? AND tenant_id = ? AND status = ?", strings.ToLower(currentName), tenantID, string(model.StatusValid)).
|
||||
err := DB.Model(&entity.Knowledgebase{}).
|
||||
Where("LOWER(name) = ? AND tenant_id = ? AND status = ?", strings.ToLower(currentName), tenantID, string(entity.StatusValid)).
|
||||
Count(&count).Error
|
||||
if err != nil || count == 0 {
|
||||
return currentName
|
||||
@ -315,7 +316,7 @@ func splitNameCounter(name string) (string, int) {
|
||||
func (dao *KnowledgebaseDAO) AtomicIncreaseDocNumByID(kbID string) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
return DB.Model(&entity.Knowledgebase{}).
|
||||
Where("id = ?", kbID).
|
||||
Updates(map[string]interface{}{
|
||||
"doc_num": DB.Raw("doc_num + 1"),
|
||||
@ -329,7 +330,7 @@ func (dao *KnowledgebaseDAO) AtomicIncreaseDocNumByID(kbID string) error {
|
||||
func (dao *KnowledgebaseDAO) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
return DB.Model(&entity.Knowledgebase{}).
|
||||
Where("id = ?", kbID).
|
||||
Updates(map[string]interface{}{
|
||||
"doc_num": DB.Raw("doc_num - ?", docNum),
|
||||
@ -344,8 +345,8 @@ func (dao *KnowledgebaseDAO) DecreaseDocumentNum(kbID string, docNum, chunkNum,
|
||||
// This matches the Python get_kb_ids method
|
||||
func (dao *KnowledgebaseDAO) GetKBIDsByTenantID(tenantID string) ([]string, error) {
|
||||
var kbIDs []string
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Where("tenant_id = ? AND status = ?", tenantID, string(model.StatusValid)).
|
||||
err := DB.Model(&entity.Knowledgebase{}).
|
||||
Where("tenant_id = ? AND status = ?", tenantID, string(entity.StatusValid)).
|
||||
Pluck("id", &kbIDs).Error
|
||||
return kbIDs, err
|
||||
}
|
||||
@ -354,8 +355,8 @@ func (dao *KnowledgebaseDAO) GetKBIDsByTenantID(tenantID string) ([]string, erro
|
||||
// This matches the Python get_all_ids method
|
||||
func (dao *KnowledgebaseDAO) GetAllIDs() ([]string, error) {
|
||||
var kbIDs []string
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
Where("status = ?", string(model.StatusValid)).
|
||||
err := DB.Model(&entity.Knowledgebase{}).
|
||||
Where("status = ?", string(entity.StatusValid)).
|
||||
Pluck("id", &kbIDs).Error
|
||||
return kbIDs, err
|
||||
}
|
||||
@ -363,13 +364,13 @@ func (dao *KnowledgebaseDAO) GetAllIDs() ([]string, error) {
|
||||
// UpdateParserConfig updates the parser configuration with deep merge
|
||||
// This matches the Python update_parser_config method
|
||||
func (dao *KnowledgebaseDAO) UpdateParserConfig(id string, config map[string]interface{}) error {
|
||||
var kb model.Knowledgebase
|
||||
if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil {
|
||||
var kb entity.Knowledgebase
|
||||
if err := DB.Where("id = ? AND status = ?", id, string(entity.StatusValid)).First(&kb).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mergedConfig := mergeConfig(kb.ParserConfig, config)
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
return DB.Model(&entity.Knowledgebase{}).
|
||||
Where("id = ?", id).
|
||||
Update("parser_config", mergedConfig).Error
|
||||
}
|
||||
@ -377,14 +378,14 @@ func (dao *KnowledgebaseDAO) UpdateParserConfig(id string, config map[string]int
|
||||
// DeleteFieldMap removes the field_map from parser_config
|
||||
// This matches the Python delete_field_map method
|
||||
func (dao *KnowledgebaseDAO) DeleteFieldMap(id string) error {
|
||||
var kb model.Knowledgebase
|
||||
if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil {
|
||||
var kb entity.Knowledgebase
|
||||
if err := DB.Where("id = ? AND status = ?", id, string(entity.StatusValid)).First(&kb).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if kb.ParserConfig != nil {
|
||||
delete(kb.ParserConfig, "field_map")
|
||||
return DB.Model(&model.Knowledgebase{}).
|
||||
return DB.Model(&entity.Knowledgebase{}).
|
||||
Where("id = ?", id).
|
||||
Update("parser_config", kb.ParserConfig).Error
|
||||
}
|
||||
@ -416,9 +417,9 @@ func (dao *KnowledgebaseDAO) GetFieldMap(ids []string) (map[string]interface{},
|
||||
|
||||
// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID with tenant join
|
||||
// This matches the Python get_kb_by_id method
|
||||
func (dao *KnowledgebaseDAO) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
func (dao *KnowledgebaseDAO) GetKBByIDAndUserID(kbID, userID string) ([]*entity.Knowledgebase, error) {
|
||||
var kbs []*entity.Knowledgebase
|
||||
err := DB.Model(&entity.Knowledgebase{}).
|
||||
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
|
||||
Where("knowledgebase.id = ? AND user_tenant.user_id = ?", kbID, userID).
|
||||
Limit(1).
|
||||
@ -428,9 +429,9 @@ func (dao *KnowledgebaseDAO) GetKBByIDAndUserID(kbID, userID string) ([]*model.K
|
||||
|
||||
// GetKBByNameAndUserID retrieves a knowledge base by name and user ID with tenant join
|
||||
// This matches the Python get_kb_by_name method
|
||||
func (dao *KnowledgebaseDAO) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
func (dao *KnowledgebaseDAO) GetKBByNameAndUserID(kbName, userID string) ([]*entity.Knowledgebase, error) {
|
||||
var kbs []*entity.Knowledgebase
|
||||
err := DB.Model(&entity.Knowledgebase{}).
|
||||
Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id").
|
||||
Where("knowledgebase.name = ? AND user_tenant.user_id = ?", kbName, userID).
|
||||
Limit(1).
|
||||
@ -440,13 +441,13 @@ func (dao *KnowledgebaseDAO) GetKBByNameAndUserID(kbName, userID string) ([]*mod
|
||||
|
||||
// GetList retrieves knowledge bases with filtering by ID and name
|
||||
// This matches the Python get_list method
|
||||
func (dao *KnowledgebaseDAO) GetList(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, error) {
|
||||
var kbs []*model.Knowledgebase
|
||||
func (dao *KnowledgebaseDAO) GetList(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, name string) ([]*entity.Knowledgebase, int64, error) {
|
||||
var kbs []*entity.Knowledgebase
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&model.Knowledgebase{}).
|
||||
query := DB.Model(&entity.Knowledgebase{}).
|
||||
Where("((tenant_id IN ? AND permission = ?) OR tenant_id = ?) AND status = ?",
|
||||
tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid))
|
||||
tenantIDs, string(entity.TenantPermissionTeam), userID, string(entity.StatusValid))
|
||||
|
||||
if id != "" {
|
||||
query = query.Where("id = ?", id)
|
||||
@ -518,14 +519,14 @@ func mergeConfig(old, new map[string]interface{}) map[string]interface{} {
|
||||
|
||||
// DeleteByTenantID deletes all knowledge bases by tenant ID (hard delete)
|
||||
func (dao *KnowledgebaseDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.Knowledgebase{})
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.Knowledgebase{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetKBIDsByTenantID gets all knowledge base IDs by tenant ID
|
||||
func (dao *KnowledgebaseDAO) GetKBIDsByTenantIDSimple(tenantID string) ([]string, error) {
|
||||
var kbIDs []string
|
||||
err := DB.Model(&model.Knowledgebase{}).
|
||||
err := DB.Model(&entity.Knowledgebase{}).
|
||||
Where("tenant_id = ?", tenantID).
|
||||
Pluck("id", &kbIDs).Error
|
||||
return kbIDs, err
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -31,7 +31,7 @@ func NewLicenseDAO() *LicenseDAO {
|
||||
|
||||
// Create creates a new license record
|
||||
func (dao *LicenseDAO) Create(licenseID, licenseStr string) error {
|
||||
license := model.License{
|
||||
license := entity.License{
|
||||
ID: licenseID,
|
||||
License: licenseStr,
|
||||
CreatedAt: time.Now(),
|
||||
@ -40,8 +40,8 @@ func (dao *LicenseDAO) Create(licenseID, licenseStr string) error {
|
||||
}
|
||||
|
||||
// GetLatest gets the latest license record by creation time
|
||||
func (dao *LicenseDAO) GetLatest() (*model.License, error) {
|
||||
var license model.License
|
||||
func (dao *LicenseDAO) GetLatest() (*entity.License, error) {
|
||||
var license entity.License
|
||||
err := DB.Order("created_at DESC").First(&license).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// LLMDAO LLM data access object
|
||||
@ -29,8 +29,8 @@ func NewLLMDAO() *LLMDAO {
|
||||
}
|
||||
|
||||
// GetAll gets all LLMs
|
||||
func (dao *LLMDAO) GetAll() ([]*model.LLM, error) {
|
||||
var llms []*model.LLM
|
||||
func (dao *LLMDAO) GetAll() ([]*entity.LLM, error) {
|
||||
var llms []*entity.LLM
|
||||
err := DB.Find(&llms).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -39,8 +39,8 @@ func (dao *LLMDAO) GetAll() ([]*model.LLM, error) {
|
||||
}
|
||||
|
||||
// GetAllValid gets all valid LLMs
|
||||
func (dao *LLMDAO) GetAllValid() ([]*model.LLM, error) {
|
||||
var llms []*model.LLM
|
||||
func (dao *LLMDAO) GetAllValid() ([]*entity.LLM, error) {
|
||||
var llms []*entity.LLM
|
||||
err := DB.Where("status = ?", "1").Find(&llms).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -49,8 +49,8 @@ func (dao *LLMDAO) GetAllValid() ([]*model.LLM, error) {
|
||||
}
|
||||
|
||||
// GetByFactory gets LLMs by factory
|
||||
func (dao *LLMDAO) GetByFactory(factory string) ([]*model.LLM, error) {
|
||||
var llms []*model.LLM
|
||||
func (dao *LLMDAO) GetByFactory(factory string) ([]*entity.LLM, error) {
|
||||
var llms []*entity.LLM
|
||||
err := DB.Where("fid = ?", factory).Find(&llms).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -59,8 +59,8 @@ func (dao *LLMDAO) GetByFactory(factory string) ([]*model.LLM, error) {
|
||||
}
|
||||
|
||||
// GetByFactoryAndName gets LLM by factory and name
|
||||
func (dao *LLMDAO) GetByFactoryAndName(factory, name string) (*model.LLM, error) {
|
||||
var llm model.LLM
|
||||
func (dao *LLMDAO) GetByFactoryAndName(factory, name string) (*entity.LLM, error) {
|
||||
var llm entity.LLM
|
||||
err := DB.Where("fid = ? AND llm_name = ?", factory, name).First(&llm).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -77,8 +77,8 @@ func NewLLMFactoryDAO() *LLMFactoryDAO {
|
||||
}
|
||||
|
||||
// GetAllValid gets all valid LLM factories
|
||||
func (dao *LLMFactoryDAO) GetAllValid() ([]*model.LLMFactories, error) {
|
||||
var factories []*model.LLMFactories
|
||||
func (dao *LLMFactoryDAO) GetAllValid() ([]*entity.LLMFactories, error) {
|
||||
var factories []*entity.LLMFactories
|
||||
err := DB.Where("status = ?", "1").Find(&factories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -87,8 +87,8 @@ func (dao *LLMFactoryDAO) GetAllValid() ([]*model.LLMFactories, error) {
|
||||
}
|
||||
|
||||
// GetByName gets LLM factory by name
|
||||
func (dao *LLMFactoryDAO) GetByName(name string) (*model.LLMFactories, error) {
|
||||
var factory model.LLMFactories
|
||||
func (dao *LLMFactoryDAO) GetByName(name string) (*entity.LLMFactories, error) {
|
||||
var factory entity.LLMFactories
|
||||
err := DB.Where("name = ?", name).First(&factory).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -21,9 +21,8 @@ package dao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
// Memory type bit flag constants, consistent with Python MemoryType enum
|
||||
@ -111,7 +110,7 @@ func NewMemoryDAO() *MemoryDAO {
|
||||
//
|
||||
// Returns:
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) Create(memory *model.Memory) error {
|
||||
func (dao *MemoryDAO) Create(memory *entity.Memory) error {
|
||||
return DB.Create(memory).Error
|
||||
}
|
||||
|
||||
@ -123,8 +122,8 @@ func (dao *MemoryDAO) Create(memory *model.Memory) error {
|
||||
// Returns:
|
||||
// - *model.Memory: Memory model pointer
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByID(id string) (*model.Memory, error) {
|
||||
var memory model.Memory
|
||||
func (dao *MemoryDAO) GetByID(id string) (*entity.Memory, error) {
|
||||
var memory entity.Memory
|
||||
err := DB.Where("id = ?", id).First(&memory).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -140,8 +139,8 @@ func (dao *MemoryDAO) GetByID(id string) (*model.Memory, error) {
|
||||
// Returns:
|
||||
// - []*model.Memory: Memory model pointer array
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*model.Memory, error) {
|
||||
var memories []*model.Memory
|
||||
func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*entity.Memory, error) {
|
||||
var memories []*entity.Memory
|
||||
err := DB.Where("tenant_id = ?", tenantID).Find(&memories).Error
|
||||
return memories, err
|
||||
}
|
||||
@ -156,8 +155,8 @@ func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*model.Memory, error) {
|
||||
// Returns:
|
||||
// - []*model.Memory: Matching memory list (for existence check)
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*model.Memory, error) {
|
||||
var memories []*model.Memory
|
||||
func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*entity.Memory, error) {
|
||||
var memories []*entity.Memory
|
||||
err := DB.Where("name = ? AND tenant_id = ?", name, tenantID).Find(&memories).Error
|
||||
return memories, err
|
||||
}
|
||||
@ -170,8 +169,8 @@ func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*model
|
||||
// Returns:
|
||||
// - []*model.Memory: Memory model pointer array
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByIDs(ids []string) ([]*model.Memory, error) {
|
||||
var memories []*model.Memory
|
||||
func (dao *MemoryDAO) GetByIDs(ids []string) ([]*entity.Memory, error) {
|
||||
var memories []*entity.Memory
|
||||
err := DB.Where("id IN ?", ids).Find(&memories).Error
|
||||
return memories, err
|
||||
}
|
||||
@ -217,7 +216,7 @@ func (dao *MemoryDAO) UpdateByID(id string, updates map[string]interface{}) erro
|
||||
}
|
||||
}
|
||||
|
||||
return DB.Model(&model.Memory{}).Where("id = ?", id).Updates(updates).Error
|
||||
return DB.Model(&entity.Memory{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteByID deletes a memory by ID
|
||||
@ -232,7 +231,7 @@ func (dao *MemoryDAO) UpdateByID(id string, updates map[string]interface{}) erro
|
||||
//
|
||||
// err := dao.DeleteByID("memory123")
|
||||
func (dao *MemoryDAO) DeleteByID(id string) error {
|
||||
return DB.Where("id = ?", id).Delete(&model.Memory{}).Error
|
||||
return DB.Where("id = ?", id).Delete(&entity.Memory{}).Error
|
||||
}
|
||||
|
||||
// GetWithOwnerNameByID retrieves a memory with owner name by ID
|
||||
@ -248,7 +247,7 @@ func (dao *MemoryDAO) DeleteByID(id string) error {
|
||||
// Example:
|
||||
//
|
||||
// memory, err := dao.GetWithOwnerNameByID("memory123")
|
||||
func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, error) {
|
||||
func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*entity.MemoryListItem, error) {
|
||||
querySQL := `
|
||||
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
|
||||
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
|
||||
@ -262,7 +261,7 @@ func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, er
|
||||
`
|
||||
|
||||
var rawResult struct {
|
||||
model.Memory
|
||||
entity.Memory
|
||||
OwnerName *string `gorm:"column:owner_name"`
|
||||
}
|
||||
|
||||
@ -270,7 +269,7 @@ func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.MemoryListItem{
|
||||
return &entity.MemoryListItem{
|
||||
Memory: rawResult.Memory,
|
||||
OwnerName: rawResult.OwnerName,
|
||||
}, nil
|
||||
@ -296,7 +295,7 @@ func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, er
|
||||
// Example:
|
||||
//
|
||||
// memories, total, err := dao.GetByFilter([]string{"tenant1"}, []string{"semantic"}, "table", "test", 1, 10)
|
||||
func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*model.MemoryListItem, int64, error) {
|
||||
func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*entity.MemoryListItem, int64, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
@ -350,7 +349,7 @@ func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, stor
|
||||
queryArgs := append(args, pageSize, offset)
|
||||
|
||||
var rawResults []struct {
|
||||
model.Memory
|
||||
entity.Memory
|
||||
OwnerName *string `gorm:"column:owner_name"`
|
||||
}
|
||||
|
||||
@ -358,9 +357,9 @@ func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, stor
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
memories := make([]*model.MemoryListItem, len(rawResults))
|
||||
memories := make([]*entity.MemoryListItem, len(rawResults))
|
||||
for i, r := range rawResults {
|
||||
memories[i] = &model.MemoryListItem{
|
||||
memories[i] = &entity.MemoryListItem{
|
||||
Memory: r.Memory,
|
||||
OwnerName: r.OwnerName,
|
||||
}
|
||||
|
||||
@ -17,9 +17,8 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
// SearchDAO search data access object
|
||||
@ -31,12 +30,12 @@ func NewSearchDAO() *SearchDAO {
|
||||
}
|
||||
|
||||
// ListByTenantIDs list searches by tenant IDs with pagination and filtering
|
||||
func (dao *SearchDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*model.Search, int64, error) {
|
||||
var searches []*model.Search
|
||||
func (dao *SearchDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*entity.Search, int64, error) {
|
||||
var searches []*entity.Search
|
||||
var total int64
|
||||
|
||||
// Build query with join to user table for nickname and avatar
|
||||
query := DB.Model(&model.Search{}).
|
||||
query := DB.Model(&entity.Search{}).
|
||||
Select(`
|
||||
search.*,
|
||||
user.nickname,
|
||||
@ -78,11 +77,11 @@ func (dao *SearchDAO) ListByTenantIDs(tenantIDs []string, userID string, page, p
|
||||
}
|
||||
|
||||
// ListByOwnerIDs list searches by owner IDs with filtering (manual pagination)
|
||||
func (dao *SearchDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby string, desc bool, keywords string) ([]*model.Search, int64, error) {
|
||||
var searches []*model.Search
|
||||
func (dao *SearchDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby string, desc bool, keywords string) ([]*entity.Search, int64, error) {
|
||||
var searches []*entity.Search
|
||||
|
||||
// Build query with join to user table
|
||||
query := DB.Model(&model.Search{}).
|
||||
query := DB.Model(&entity.Search{}).
|
||||
Select(`
|
||||
search.*,
|
||||
user.nickname,
|
||||
@ -117,8 +116,8 @@ func (dao *SearchDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby s
|
||||
}
|
||||
|
||||
// GetByID gets search by ID
|
||||
func (dao *SearchDAO) GetByID(id string) (*model.Search, error) {
|
||||
var search model.Search
|
||||
func (dao *SearchDAO) GetByID(id string) (*entity.Search, error) {
|
||||
var search entity.Search
|
||||
err := DB.Where("id = ?", id).First(&search).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -18,10 +18,9 @@ package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"ragflow/internal/entity"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@ -35,8 +34,8 @@ func NewSystemSettingsDAO() *SystemSettingsDAO {
|
||||
|
||||
// GetAll get all system settings
|
||||
// Returns all system settings records from database
|
||||
func (d *SystemSettingsDAO) GetAll() ([]model.SystemSettings, error) {
|
||||
var settings []model.SystemSettings
|
||||
func (d *SystemSettingsDAO) GetAll() ([]entity.SystemSettings, error) {
|
||||
var settings []entity.SystemSettings
|
||||
err := DB.Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -46,8 +45,8 @@ func (d *SystemSettingsDAO) GetAll() ([]model.SystemSettings, error) {
|
||||
|
||||
// GetByName get system settings by name
|
||||
// Returns settings records that match the given name
|
||||
func (d *SystemSettingsDAO) GetByName(name string) ([]model.SystemSettings, error) {
|
||||
var settings []model.SystemSettings
|
||||
func (d *SystemSettingsDAO) GetByName(name string) ([]entity.SystemSettings, error) {
|
||||
var settings []entity.SystemSettings
|
||||
err := DB.Where("name = ?", name).Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -57,11 +56,11 @@ func (d *SystemSettingsDAO) GetByName(name string) ([]model.SystemSettings, erro
|
||||
|
||||
// UpdateByName update system settings by name
|
||||
// Updates the setting with the given name using the provided data
|
||||
func (d *SystemSettingsDAO) UpdateByName(name string, setting *model.SystemSettings) error {
|
||||
func (d *SystemSettingsDAO) UpdateByName(name string, setting *entity.SystemSettings) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
|
||||
return DB.Model(&model.SystemSettings{}).
|
||||
return DB.Model(&entity.SystemSettings{}).
|
||||
Where("name = ?", name).
|
||||
Updates(map[string]interface{}{
|
||||
"value": setting.Value,
|
||||
@ -74,7 +73,7 @@ func (d *SystemSettingsDAO) UpdateByName(name string, setting *model.SystemSetti
|
||||
|
||||
// Create create a new system setting
|
||||
// Inserts a new system setting record into database
|
||||
func (d *SystemSettingsDAO) Create(setting *model.SystemSettings) error {
|
||||
func (d *SystemSettingsDAO) Create(setting *entity.SystemSettings) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
|
||||
@ -102,7 +101,7 @@ func (d *SystemSettingsDAO) SaveOrCreate(name string, value string, source strin
|
||||
return errors.New("can't update more than 1 setting: " + name)
|
||||
}
|
||||
|
||||
newSetting := &model.SystemSettings{
|
||||
newSetting := &entity.SystemSettings{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Source: source,
|
||||
@ -114,19 +113,19 @@ func (d *SystemSettingsDAO) SaveOrCreate(name string, value string, source strin
|
||||
// Count get total count of system settings
|
||||
func (d *SystemSettingsDAO) Count() (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&model.SystemSettings{}).Count(&count).Error
|
||||
err := DB.Model(&entity.SystemSettings{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteByName delete system setting by name
|
||||
func (d *SystemSettingsDAO) DeleteByName(name string) error {
|
||||
return DB.Where("name = ?", name).Delete(&model.SystemSettings{}).Error
|
||||
return DB.Where("name = ?", name).Delete(&entity.SystemSettings{}).Error
|
||||
}
|
||||
|
||||
// Exists check if setting exists by name
|
||||
func (d *SystemSettingsDAO) Exists(name string) (bool, error) {
|
||||
var count int64
|
||||
err := DB.Model(&model.SystemSettings{}).Where("name = ?", name).Count(&count).Error
|
||||
err := DB.Model(&entity.SystemSettings{}).Where("name = ?", name).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -134,8 +133,8 @@ func (d *SystemSettingsDAO) Exists(name string) (bool, error) {
|
||||
}
|
||||
|
||||
// GetBySource get system settings by source
|
||||
func (d *SystemSettingsDAO) GetBySource(source string) ([]model.SystemSettings, error) {
|
||||
var settings []model.SystemSettings
|
||||
func (d *SystemSettingsDAO) GetBySource(source string) ([]entity.SystemSettings, error) {
|
||||
var settings []entity.SystemSettings
|
||||
err := DB.Where("source = ?", source).Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -144,8 +143,8 @@ func (d *SystemSettingsDAO) GetBySource(source string) ([]model.SystemSettings,
|
||||
}
|
||||
|
||||
// GetByDataType get system settings by data type
|
||||
func (d *SystemSettingsDAO) GetByDataType(dataType string) ([]model.SystemSettings, error) {
|
||||
var settings []model.SystemSettings
|
||||
func (d *SystemSettingsDAO) GetByDataType(dataType string) ([]entity.SystemSettings, error) {
|
||||
var settings []entity.SystemSettings
|
||||
err := DB.Where("data_type = ?", dataType).Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -159,7 +158,7 @@ func (d *SystemSettingsDAO) Transaction(fn func(tx *gorm.DB) error) error {
|
||||
}
|
||||
|
||||
// CreateWithTx create setting within transaction
|
||||
func (d *SystemSettingsDAO) CreateWithTx(tx *gorm.DB, setting *model.SystemSettings) error {
|
||||
func (d *SystemSettingsDAO) CreateWithTx(tx *gorm.DB, setting *entity.SystemSettings) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
|
||||
@ -172,11 +171,11 @@ func (d *SystemSettingsDAO) CreateWithTx(tx *gorm.DB, setting *model.SystemSetti
|
||||
}
|
||||
|
||||
// UpdateByNameWithTx update setting within transaction
|
||||
func (d *SystemSettingsDAO) UpdateByNameWithTx(tx *gorm.DB, name string, setting *model.SystemSettings) error {
|
||||
func (d *SystemSettingsDAO) UpdateByNameWithTx(tx *gorm.DB, name string, setting *entity.SystemSettings) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
|
||||
return tx.Model(&model.SystemSettings{}).
|
||||
return tx.Model(&entity.SystemSettings{}).
|
||||
Where("name = ?", name).
|
||||
Updates(map[string]interface{}{
|
||||
"value": setting.Value,
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// TaskDAO task data access object
|
||||
@ -29,13 +29,13 @@ func NewTaskDAO() *TaskDAO {
|
||||
}
|
||||
|
||||
// Create creates a new task
|
||||
func (dao *TaskDAO) Create(task *model.Task) error {
|
||||
func (dao *TaskDAO) Create(task *entity.Task) error {
|
||||
return DB.Create(task).Error
|
||||
}
|
||||
|
||||
// GetByID gets task by ID
|
||||
func (dao *TaskDAO) GetByID(id string) (*model.Task, error) {
|
||||
var task model.Task
|
||||
func (dao *TaskDAO) GetByID(id string) (*entity.Task, error) {
|
||||
var task entity.Task
|
||||
err := DB.Where("id = ?", id).First(&task).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -48,12 +48,12 @@ func (dao *TaskDAO) DeleteByDocIDs(docIDs []string) (int64, error) {
|
||||
if len(docIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := DB.Unscoped().Where("doc_id IN ?", docIDs).Delete(&model.Task{})
|
||||
result := DB.Unscoped().Where("doc_id IN ?", docIDs).Delete(&entity.Task{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteByTenantID deletes all tasks by tenant ID (hard delete via document join)
|
||||
func (dao *TaskDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("doc_id IN (SELECT id FROM document WHERE tenant_id = ?)", tenantID).Delete(&model.Task{})
|
||||
result := DB.Unscoped().Where("doc_id IN (SELECT id FROM document WHERE tenant_id = ?)", tenantID).Delete(&entity.Task{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// TenantDAO tenant data access object
|
||||
@ -32,7 +32,7 @@ func NewTenantDAO() *TenantDAO {
|
||||
func (dao *TenantDAO) GetJoinedTenantsByUserID(userID string) ([]*TenantWithRole, error) {
|
||||
var results []*TenantWithRole
|
||||
|
||||
err := DB.Model(&model.Tenant{}).
|
||||
err := DB.Model(&entity.Tenant{}).
|
||||
Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.asr_id, tenant.img2txt_id, user_tenant.role").
|
||||
Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id").
|
||||
Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "normal", "1").
|
||||
@ -70,18 +70,18 @@ type TenantInfo struct {
|
||||
func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) {
|
||||
var results []*TenantInfo
|
||||
|
||||
err := DB.Model(&model.Tenant{}).
|
||||
err := DB.Model(&entity.Tenant{}).
|
||||
Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.parser_ids, user_tenant.role").
|
||||
Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id").
|
||||
Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1").
|
||||
Scan(&results).Error
|
||||
|
||||
|
||||
return results, err
|
||||
}
|
||||
|
||||
// GetByID gets tenant by ID
|
||||
func (dao *TenantDAO) GetByID(id string) (*model.Tenant, error) {
|
||||
var tenant model.Tenant
|
||||
func (dao *TenantDAO) GetByID(id string) (*entity.Tenant, error) {
|
||||
var tenant entity.Tenant
|
||||
err := DB.Where("id = ? AND status = ?", id, "1").First(&tenant).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -90,21 +90,21 @@ func (dao *TenantDAO) GetByID(id string) (*model.Tenant, error) {
|
||||
}
|
||||
|
||||
// Create creates a new tenant
|
||||
func (dao *TenantDAO) Create(tenant *model.Tenant) error {
|
||||
func (dao *TenantDAO) Create(tenant *entity.Tenant) error {
|
||||
return DB.Create(tenant).Error
|
||||
}
|
||||
|
||||
// Delete deletes a tenant by ID (soft delete)
|
||||
func (dao *TenantDAO) Delete(id string) error {
|
||||
return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error
|
||||
return DB.Model(&entity.Tenant{}).Where("id = ?", id).Update("status", "0").Error
|
||||
}
|
||||
|
||||
// Update updates a tenant by ID
|
||||
func (dao *TenantDAO) Update(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&model.Tenant{}).Where("id = ?", id).Updates(updates).Error
|
||||
return DB.Model(&entity.Tenant{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// HardDelete hard deletes a tenant by ID
|
||||
func (dao *TenantDAO) HardDelete(id string) error {
|
||||
return DB.Unscoped().Where("id = ?", id).Delete(&model.Tenant{}).Error
|
||||
return DB.Unscoped().Where("id = ?", id).Delete(&entity.Tenant{}).Error
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// TenantLLMDAO tenant LLM data access object
|
||||
@ -29,8 +29,8 @@ func NewTenantLLMDAO() *TenantLLMDAO {
|
||||
}
|
||||
|
||||
// GetByTenantAndModelName get tenant LLM by tenant ID and model name
|
||||
func (dao *TenantLLMDAO) GetByTenantAndModelName(tenantID, providerName string, modelName string) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
func (dao *TenantLLMDAO) GetByTenantAndModelName(tenantID, providerName string, modelName string) (*entity.TenantLLM, error) {
|
||||
var tenantLLM entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, providerName, modelName).First(&tenantLLM).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -39,8 +39,8 @@ func (dao *TenantLLMDAO) GetByTenantAndModelName(tenantID, providerName string,
|
||||
}
|
||||
|
||||
// GetByTenantAndType get tenant LLM by tenant ID and model type
|
||||
func (dao *TenantLLMDAO) GetByTenantAndType(tenantID string, modelType model.ModelType) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
func (dao *TenantLLMDAO) GetByTenantAndType(tenantID string, modelType entity.ModelType) (*entity.TenantLLM, error) {
|
||||
var tenantLLM entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ? AND model_type = ?", tenantID, modelType).First(&tenantLLM).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -49,8 +49,8 @@ func (dao *TenantLLMDAO) GetByTenantAndType(tenantID string, modelType model.Mod
|
||||
}
|
||||
|
||||
// GetByTenantAndFactory get tenant LLM by tenant ID, model type and factory
|
||||
func (dao *TenantLLMDAO) GetByTenantAndFactory(tenantID string, modelType model.ModelType, factory string) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
func (dao *TenantLLMDAO) GetByTenantAndFactory(tenantID string, modelType entity.ModelType, factory string) (*entity.TenantLLM, error) {
|
||||
var tenantLLM entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ? AND model_type = ? AND llm_factory = ?", tenantID, modelType, factory).First(&tenantLLM).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -59,8 +59,8 @@ func (dao *TenantLLMDAO) GetByTenantAndFactory(tenantID string, modelType model.
|
||||
}
|
||||
|
||||
// ListByTenant list all tenant LLMs for a tenant
|
||||
func (dao *TenantLLMDAO) ListByTenant(tenantID string) ([]model.TenantLLM, error) {
|
||||
var tenantLLMs []model.TenantLLM
|
||||
func (dao *TenantLLMDAO) ListByTenant(tenantID string) ([]entity.TenantLLM, error) {
|
||||
var tenantLLMs []entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ?", tenantID).Find(&tenantLLMs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -69,8 +69,8 @@ func (dao *TenantLLMDAO) ListByTenant(tenantID string) ([]model.TenantLLM, error
|
||||
}
|
||||
|
||||
// GetByTenantFactoryAndModelName get tenant LLM by tenant ID, factory and model name
|
||||
func (dao *TenantLLMDAO) GetByTenantFactoryAndModelName(tenantID, factory, modelName string) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
func (dao *TenantLLMDAO) GetByTenantFactoryAndModelName(tenantID, factory, modelName string) (*entity.TenantLLM, error) {
|
||||
var tenantLLM entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, modelName).First(&tenantLLM).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -79,23 +79,23 @@ func (dao *TenantLLMDAO) GetByTenantFactoryAndModelName(tenantID, factory, model
|
||||
}
|
||||
|
||||
// Create create a new tenant LLM record
|
||||
func (dao *TenantLLMDAO) Create(tenantLLM *model.TenantLLM) error {
|
||||
func (dao *TenantLLMDAO) Create(tenantLLM *entity.TenantLLM) error {
|
||||
return DB.Create(tenantLLM).Error
|
||||
}
|
||||
|
||||
// Update update an existing tenant LLM record
|
||||
func (dao *TenantLLMDAO) Update(tenantLLM *model.TenantLLM) error {
|
||||
func (dao *TenantLLMDAO) Update(tenantLLM *entity.TenantLLM) error {
|
||||
return DB.Save(tenantLLM).Error
|
||||
}
|
||||
|
||||
// Delete delete a tenant LLM record by tenant ID, factory and model name
|
||||
func (dao *TenantLLMDAO) Delete(tenantID, factory, modelName string) error {
|
||||
return DB.Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, modelName).Delete(&model.TenantLLM{}).Error
|
||||
return DB.Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, modelName).Delete(&entity.TenantLLM{}).Error
|
||||
}
|
||||
|
||||
// GetMyLLMs get tenant LLMs with factory details
|
||||
func (dao *TenantLLMDAO) GetMyLLMs(tenantID string) ([]model.MyLLM, error) {
|
||||
var myLLMs []model.MyLLM
|
||||
func (dao *TenantLLMDAO) GetMyLLMs(tenantID string) ([]entity.MyLLM, error) {
|
||||
var myLLMs []entity.MyLLM
|
||||
|
||||
err := DB.Table("tenant_llm tl").
|
||||
Select("tl.id, tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status").
|
||||
@ -109,8 +109,8 @@ func (dao *TenantLLMDAO) GetMyLLMs(tenantID string) ([]model.MyLLM, error) {
|
||||
}
|
||||
|
||||
// ListValidByTenant lists valid tenant LLMs for a tenant
|
||||
func (dao *TenantLLMDAO) ListValidByTenant(tenantID string) ([]*model.TenantLLM, error) {
|
||||
var tenantLLMs []*model.TenantLLM
|
||||
func (dao *TenantLLMDAO) ListValidByTenant(tenantID string) ([]*entity.TenantLLM, error) {
|
||||
var tenantLLMs []*entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ? AND api_key IS NOT NULL AND api_key != ? AND status = ?", tenantID, "", "1").Find(&tenantLLMs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -119,8 +119,8 @@ func (dao *TenantLLMDAO) ListValidByTenant(tenantID string) ([]*model.TenantLLM,
|
||||
}
|
||||
|
||||
// ListAllByTenant lists all tenant LLMs for a tenant
|
||||
func (dao *TenantLLMDAO) ListAllByTenant(tenantID string) ([]*model.TenantLLM, error) {
|
||||
var tenantLLMs []*model.TenantLLM
|
||||
func (dao *TenantLLMDAO) ListAllByTenant(tenantID string) ([]*entity.TenantLLM, error) {
|
||||
var tenantLLMs []*entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ?", tenantID).Find(&tenantLLMs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -129,7 +129,7 @@ func (dao *TenantLLMDAO) ListAllByTenant(tenantID string) ([]*model.TenantLLM, e
|
||||
}
|
||||
|
||||
// InsertMany inserts multiple tenant LLM records
|
||||
func (dao *TenantLLMDAO) InsertMany(tenantLLMs []*model.TenantLLM) error {
|
||||
func (dao *TenantLLMDAO) InsertMany(tenantLLMs []*entity.TenantLLM) error {
|
||||
if len(tenantLLMs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -138,7 +138,7 @@ func (dao *TenantLLMDAO) InsertMany(tenantLLMs []*model.TenantLLM) error {
|
||||
|
||||
// DeleteByTenantID deletes all tenant LLM records by tenant ID (hard delete)
|
||||
func (dao *TenantLLMDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.TenantLLM{})
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.TenantLLM{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@ -182,7 +182,7 @@ func splitModelNameAndFactory(modelName string) (string, string) {
|
||||
// Validate if factory exists in llm_factories table
|
||||
// This matches Python's logic of checking against model providers
|
||||
var factoryCount int64
|
||||
DB.Model(&model.LLMFactories{}).Where("name = ?", factory).Count(&factoryCount)
|
||||
DB.Model(&entity.LLMFactories{}).Where("name = ?", factory).Count(&factoryCount)
|
||||
|
||||
// If factory doesn't exist in database, treat the whole string as model name
|
||||
if factoryCount == 0 {
|
||||
@ -211,8 +211,8 @@ func splitModelNameAndFactory(modelName string) (string, string) {
|
||||
//
|
||||
// // Model name with factory prefix
|
||||
// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4@OpenAI")
|
||||
func (dao *TenantLLMDAO) GetByTenantIDAndLLMName(tenantID string, llmName string) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
func (dao *TenantLLMDAO) GetByTenantIDAndLLMName(tenantID string, llmName string) (*entity.TenantLLM, error) {
|
||||
var tenantLLM entity.TenantLLM
|
||||
|
||||
// Split model name and factory from the combined format
|
||||
modelName, factory := splitModelNameAndFactory(llmName)
|
||||
@ -260,8 +260,8 @@ func (dao *TenantLLMDAO) GetByTenantIDAndLLMName(tenantID string, llmName string
|
||||
// Example:
|
||||
//
|
||||
// tenantLLM, err := dao.GetByTenantIDLLMNameAndFactory("tenant123", "gpt-4", "OpenAI")
|
||||
func (dao *TenantLLMDAO) GetByTenantIDLLMNameAndFactory(tenantID, llmName, factory string) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
func (dao *TenantLLMDAO) GetByTenantIDLLMNameAndFactory(tenantID, llmName, factory string) (*entity.TenantLLM, error) {
|
||||
var tenantLLM entity.TenantLLM
|
||||
err := DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, llmName, factory).First(&tenantLLM).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// TimeRecordDAO time record data access object
|
||||
@ -29,13 +29,13 @@ func NewTimeRecordDAO() *TimeRecordDAO {
|
||||
}
|
||||
|
||||
// Create inserts a new record
|
||||
func (dao *TimeRecordDAO) Create(record *model.TimeRecord) error {
|
||||
func (dao *TimeRecordDAO) Create(record *entity.TimeRecord) error {
|
||||
return DB.Create(record).Error
|
||||
}
|
||||
|
||||
// GetRecent retrieves the most recently inserted records (ordered by ID descending)
|
||||
func (dao *TimeRecordDAO) GetRecent(limit int) ([]*model.TimeRecord, error) {
|
||||
var records []*model.TimeRecord
|
||||
func (dao *TimeRecordDAO) GetRecent(limit int) ([]*entity.TimeRecord, error) {
|
||||
var records []*entity.TimeRecord
|
||||
err := DB.Order("id DESC").Limit(limit).Find(&records).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -46,7 +46,7 @@ func (dao *TimeRecordDAO) GetRecent(limit int) ([]*model.TimeRecord, error) {
|
||||
// GetCount returns the total number of records
|
||||
func (dao *TimeRecordDAO) GetCount() (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&model.TimeRecord{}).Count(&count).Error
|
||||
err := DB.Model(&entity.TimeRecord{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
@ -56,8 +56,8 @@ func (dao *TimeRecordDAO) DeleteOldest(limit int64) error {
|
||||
}
|
||||
|
||||
// GetByID retrieves a single record by its ID
|
||||
func (dao *TimeRecordDAO) GetByID(id int64) (*model.TimeRecord, error) {
|
||||
var record model.TimeRecord
|
||||
func (dao *TimeRecordDAO) GetByID(id int64) (*entity.TimeRecord, error) {
|
||||
var record entity.TimeRecord
|
||||
err := DB.First(&record, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -66,8 +66,8 @@ func (dao *TimeRecordDAO) GetByID(id int64) (*model.TimeRecord, error) {
|
||||
}
|
||||
|
||||
// GetAll retrieves all records
|
||||
func (dao *TimeRecordDAO) GetAll() ([]*model.TimeRecord, error) {
|
||||
var records []*model.TimeRecord
|
||||
func (dao *TimeRecordDAO) GetAll() ([]*entity.TimeRecord, error) {
|
||||
var records []*entity.TimeRecord
|
||||
err := DB.Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
@ -76,7 +76,7 @@ func (dao *TimeRecordDAO) GetAll() ([]*model.TimeRecord, error) {
|
||||
func (dao *TimeRecordDAO) KeepLatest(count int64) error {
|
||||
// Step 1: Get the maximum ID
|
||||
var maxID int64
|
||||
if err := DB.Model(&model.TimeRecord{}).Select("COALESCE(MAX(id), 0)").Scan(&maxID).Error; err != nil {
|
||||
if err := DB.Model(&entity.TimeRecord{}).Select("COALESCE(MAX(id), 0)").Scan(&maxID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -94,10 +94,10 @@ func (dao *TimeRecordDAO) KeepLatest(count int64) error {
|
||||
}
|
||||
|
||||
// Step 3: Delete records with ID <= threshold
|
||||
return DB.Where("id <= ?", thresholdID).Delete(&model.TimeRecord{}).Error
|
||||
return DB.Where("id <= ?", thresholdID).Delete(&entity.TimeRecord{}).Error
|
||||
}
|
||||
|
||||
// DeleteAll deletes all records
|
||||
func (dao *TimeRecordDAO) DeleteAll() error {
|
||||
return DB.Where("1=1").Delete(&model.TimeRecord{}).Error
|
||||
return DB.Where("1=1").Delete(&entity.TimeRecord{}).Error
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// UserDAO user data access object
|
||||
@ -29,13 +29,13 @@ func NewUserDAO() *UserDAO {
|
||||
}
|
||||
|
||||
// Create create user
|
||||
func (dao *UserDAO) Create(user *model.User) error {
|
||||
func (dao *UserDAO) Create(user *entity.User) error {
|
||||
return DB.Create(user).Error
|
||||
}
|
||||
|
||||
// GetByID get user by ID
|
||||
func (dao *UserDAO) GetByID(id uint) (*model.User, error) {
|
||||
var user model.User
|
||||
func (dao *UserDAO) GetByID(id uint) (*entity.User, error) {
|
||||
var user entity.User
|
||||
err := DB.First(&user, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -43,8 +43,8 @@ func (dao *UserDAO) GetByID(id uint) (*model.User, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (dao *UserDAO) GetByTenantID(tenantID string) (*model.User, error) {
|
||||
var user model.User
|
||||
func (dao *UserDAO) GetByTenantID(tenantID string) (*entity.User, error) {
|
||||
var user entity.User
|
||||
err := DB.Where("id = ?", tenantID).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -53,8 +53,8 @@ func (dao *UserDAO) GetByTenantID(tenantID string) (*model.User, error) {
|
||||
}
|
||||
|
||||
// GetByEmail get user by email
|
||||
func (dao *UserDAO) GetByEmail(email string) (*model.User, error) {
|
||||
var user model.User
|
||||
func (dao *UserDAO) GetByEmail(email string) (*entity.User, error) {
|
||||
var user entity.User
|
||||
query := DB.Where("email = ?", email)
|
||||
err := query.First(&user).Error
|
||||
if err != nil {
|
||||
@ -64,8 +64,8 @@ func (dao *UserDAO) GetByEmail(email string) (*model.User, error) {
|
||||
}
|
||||
|
||||
// GetByAccessToken get user by access token
|
||||
func (dao *UserDAO) GetByAccessToken(token string) (*model.User, error) {
|
||||
var user model.User
|
||||
func (dao *UserDAO) GetByAccessToken(token string) (*entity.User, error) {
|
||||
var user entity.User
|
||||
err := DB.Where("access_token = ?", token).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -74,26 +74,26 @@ func (dao *UserDAO) GetByAccessToken(token string) (*model.User, error) {
|
||||
}
|
||||
|
||||
// Update update user
|
||||
func (dao *UserDAO) Update(user *model.User) error {
|
||||
func (dao *UserDAO) Update(user *entity.User) error {
|
||||
return DB.Save(user).Error
|
||||
}
|
||||
|
||||
// UpdateAccessToken update user's access token
|
||||
func (dao *UserDAO) UpdateAccessToken(user *model.User, token string) error {
|
||||
func (dao *UserDAO) UpdateAccessToken(user *entity.User, token string) error {
|
||||
return DB.Model(user).Update("access_token", token).Error
|
||||
}
|
||||
|
||||
// List list users (only active users with status != "0")
|
||||
func (dao *UserDAO) List(offset, limit int) ([]*model.User, int64, error) {
|
||||
var users []*model.User
|
||||
func (dao *UserDAO) List(offset, limit int) ([]*entity.User, int64, error) {
|
||||
var users []*entity.User
|
||||
var total int64
|
||||
|
||||
// Only count users with status != "0" (not deleted)
|
||||
if err := DB.Model(&model.User{}).Count(&total).Error; err != nil {
|
||||
if err := DB.Model(&entity.User{}).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
query := DB.Model(&model.User{})
|
||||
query := DB.Model(&entity.User{})
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
@ -106,23 +106,23 @@ func (dao *UserDAO) List(offset, limit int) ([]*model.User, int64, error) {
|
||||
|
||||
// Delete delete user
|
||||
func (dao *UserDAO) Delete(id uint) error {
|
||||
return DB.Delete(&model.User{}, id).Error
|
||||
return DB.Delete(&entity.User{}, id).Error
|
||||
}
|
||||
|
||||
// DeleteByID delete user by string ID (soft delete - set status to 0)
|
||||
func (dao *UserDAO) DeleteByID(id string) error {
|
||||
return DB.Model(&model.User{}).Where("id = ?", id).Update("status", "0").Error
|
||||
return DB.Model(&entity.User{}).Where("id = ?", id).Update("status", "0").Error
|
||||
}
|
||||
|
||||
// HardDelete hard delete user by string ID
|
||||
func (dao *UserDAO) HardDelete(id string) error {
|
||||
return DB.Unscoped().Where("id = ?", id).Delete(&model.User{}).Error
|
||||
return DB.Unscoped().Where("id = ?", id).Delete(&entity.User{}).Error
|
||||
}
|
||||
|
||||
// ListByEmail list users by email (only active users with status != "0")
|
||||
// Returns all users matching the given email address
|
||||
func (dao *UserDAO) ListByEmail(email string) ([]*model.User, error) {
|
||||
var users []*model.User
|
||||
func (dao *UserDAO) ListByEmail(email string) ([]*entity.User, error) {
|
||||
var users []*entity.User
|
||||
err := DB.Where("email = ?", email).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// UserCanvasDAO user canvas data access object
|
||||
@ -29,13 +29,13 @@ func NewUserCanvasDAO() *UserCanvasDAO {
|
||||
}
|
||||
|
||||
// Create user canvas
|
||||
func (dao *UserCanvasDAO) Create(userCanvas *model.UserCanvas) error {
|
||||
func (dao *UserCanvasDAO) Create(userCanvas *entity.UserCanvas) error {
|
||||
return DB.Create(userCanvas).Error
|
||||
}
|
||||
|
||||
// GetByID get user canvas by ID
|
||||
func (dao *UserCanvasDAO) GetByID(id string) (*model.UserCanvas, error) {
|
||||
var canvas model.UserCanvas
|
||||
func (dao *UserCanvasDAO) GetByID(id string) (*entity.UserCanvas, error) {
|
||||
var canvas entity.UserCanvas
|
||||
err := DB.Where("id = ?", id).First(&canvas).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -44,13 +44,13 @@ func (dao *UserCanvasDAO) GetByID(id string) (*model.UserCanvas, error) {
|
||||
}
|
||||
|
||||
// Update update user canvas
|
||||
func (dao *UserCanvasDAO) Update(userCanvas *model.UserCanvas) error {
|
||||
func (dao *UserCanvasDAO) Update(userCanvas *entity.UserCanvas) error {
|
||||
return DB.Save(userCanvas).Error
|
||||
}
|
||||
|
||||
// Delete delete user canvas
|
||||
func (dao *UserCanvasDAO) Delete(id string) error {
|
||||
return DB.Delete(&model.UserCanvas{}, id).Error
|
||||
return DB.Delete(&entity.UserCanvas{}, id).Error
|
||||
}
|
||||
|
||||
// GetList get canvases list with pagination and filtering
|
||||
@ -62,9 +62,9 @@ func (dao *UserCanvasDAO) GetList(
|
||||
desc bool,
|
||||
id, title string,
|
||||
canvasCategory string,
|
||||
) ([]*model.UserCanvas, error) {
|
||||
) ([]*entity.UserCanvas, error) {
|
||||
|
||||
query := DB.Model(&model.UserCanvas{}).
|
||||
query := DB.Model(&entity.UserCanvas{}).
|
||||
Where("user_id = ?", tenantID)
|
||||
|
||||
if id != "" {
|
||||
@ -93,7 +93,7 @@ func (dao *UserCanvasDAO) GetList(
|
||||
query = query.Offset(offset).Limit(itemsPerPage)
|
||||
}
|
||||
|
||||
var canvases []*model.UserCanvas
|
||||
var canvases []*entity.UserCanvas
|
||||
err := query.Find(&canvases).Error
|
||||
return canvases, err
|
||||
}
|
||||
@ -102,7 +102,7 @@ func (dao *UserCanvasDAO) GetList(
|
||||
// Similar to Python UserCanvasService.get_all_agents_by_tenant_ids
|
||||
func (dao *UserCanvasDAO) GetAllCanvasesByTenantIDs(tenantIDs []string, userID string) ([]*CanvasBasicInfo, error) {
|
||||
|
||||
query := DB.Model(&model.UserCanvas{}).
|
||||
query := DB.Model(&entity.UserCanvas{}).
|
||||
Select("id, avatar, title, permission, canvas_type, canvas_category").
|
||||
Where("user_id IN (?) AND permission = ?", tenantIDs, "team").
|
||||
Or("user_id = ?", userID).
|
||||
@ -114,7 +114,7 @@ func (dao *UserCanvasDAO) GetAllCanvasesByTenantIDs(tenantIDs []string, userID s
|
||||
}
|
||||
|
||||
// GetByCanvasID get user canvas by canvas ID (alias for GetByID)
|
||||
func (dao *UserCanvasDAO) GetByCanvasID(canvasID string) (*model.UserCanvas, error) {
|
||||
func (dao *UserCanvasDAO) GetByCanvasID(canvasID string) (*entity.UserCanvas, error) {
|
||||
return dao.GetByID(canvasID)
|
||||
}
|
||||
|
||||
@ -130,14 +130,14 @@ type CanvasBasicInfo struct {
|
||||
|
||||
// DeleteByUserID deletes all canvases by user ID (hard delete)
|
||||
func (dao *UserCanvasDAO) DeleteByUserID(userID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("user_id = ?", userID).Delete(&model.UserCanvas{})
|
||||
result := DB.Unscoped().Where("user_id = ?", userID).Delete(&entity.UserCanvas{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetAllCanvasIDsByUserID gets all canvas IDs by user ID
|
||||
func (dao *UserCanvasDAO) GetAllCanvasIDsByUserID(userID string) ([]string, error) {
|
||||
var canvasIDs []string
|
||||
err := DB.Model(&model.UserCanvas{}).
|
||||
err := DB.Model(&entity.UserCanvas{}).
|
||||
Where("user_id = ?", userID).
|
||||
Pluck("id", &canvasIDs).Error
|
||||
return canvasIDs, err
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// UserTenantDAO user tenant data access object
|
||||
@ -29,13 +29,13 @@ func NewUserTenantDAO() *UserTenantDAO {
|
||||
}
|
||||
|
||||
// Create create user tenant relationship
|
||||
func (dao *UserTenantDAO) Create(userTenant *model.UserTenant) error {
|
||||
func (dao *UserTenantDAO) Create(userTenant *entity.UserTenant) error {
|
||||
return DB.Create(userTenant).Error
|
||||
}
|
||||
|
||||
// GetByID get user tenant relationship by ID
|
||||
func (dao *UserTenantDAO) GetByID(id string) (*model.UserTenant, error) {
|
||||
var userTenant model.UserTenant
|
||||
func (dao *UserTenantDAO) GetByID(id string) (*entity.UserTenant, error) {
|
||||
var userTenant entity.UserTenant
|
||||
err := DB.Where("id = ? AND status = ?", id, "1").First(&userTenant).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -44,25 +44,25 @@ func (dao *UserTenantDAO) GetByID(id string) (*model.UserTenant, error) {
|
||||
}
|
||||
|
||||
// Update update user tenant relationship
|
||||
func (dao *UserTenantDAO) Update(userTenant *model.UserTenant) error {
|
||||
func (dao *UserTenantDAO) Update(userTenant *entity.UserTenant) error {
|
||||
return DB.Save(userTenant).Error
|
||||
}
|
||||
|
||||
// Delete delete user tenant relationship (soft delete by setting status to "0")
|
||||
func (dao *UserTenantDAO) Delete(id string) error {
|
||||
return DB.Model(&model.UserTenant{}).Where("id = ?", id).Update("status", "0").Error
|
||||
return DB.Model(&entity.UserTenant{}).Where("id = ?", id).Update("status", "0").Error
|
||||
}
|
||||
|
||||
// GetByUserID get user tenant relationships by user ID
|
||||
func (dao *UserTenantDAO) GetByUserID(userID string) ([]*model.UserTenant, error) {
|
||||
var relations []*model.UserTenant
|
||||
func (dao *UserTenantDAO) GetByUserID(userID string) ([]*entity.UserTenant, error) {
|
||||
var relations []*entity.UserTenant
|
||||
err := DB.Where("user_id = ? AND status = ?", userID, "1").Find(&relations).Error
|
||||
return relations, err
|
||||
}
|
||||
|
||||
// GetByTenantID get user tenant relationships by tenant ID
|
||||
func (dao *UserTenantDAO) GetByTenantID(tenantID string) ([]*model.UserTenant, error) {
|
||||
var relations []*model.UserTenant
|
||||
func (dao *UserTenantDAO) GetByTenantID(tenantID string) ([]*entity.UserTenant, error) {
|
||||
var relations []*entity.UserTenant
|
||||
err := DB.Where("tenant_id = ? AND status = ?", tenantID, "1").Find(&relations).Error
|
||||
return relations, err
|
||||
}
|
||||
@ -70,7 +70,7 @@ func (dao *UserTenantDAO) GetByTenantID(tenantID string) ([]*model.UserTenant, e
|
||||
// GetTenantIDsByUserID get tenant ID list by user ID
|
||||
func (dao *UserTenantDAO) GetTenantIDsByUserID(userID string) ([]string, error) {
|
||||
var tenantIDs []string
|
||||
err := DB.Model(&model.UserTenant{}).
|
||||
err := DB.Model(&entity.UserTenant{}).
|
||||
Select("tenant_id").
|
||||
Where("user_id = ? AND status = ?", userID, "1").
|
||||
Pluck("tenant_id", &tenantIDs).Error
|
||||
@ -78,8 +78,8 @@ func (dao *UserTenantDAO) GetTenantIDsByUserID(userID string) ([]string, error)
|
||||
}
|
||||
|
||||
// FilterByUserIDAndTenantID filter user tenant relationship by user ID and tenant ID
|
||||
func (dao *UserTenantDAO) FilterByUserIDAndTenantID(userID, tenantID string) (*model.UserTenant, error) {
|
||||
var userTenant model.UserTenant
|
||||
func (dao *UserTenantDAO) FilterByUserIDAndTenantID(userID, tenantID string) (*entity.UserTenant, error) {
|
||||
var userTenant entity.UserTenant
|
||||
err := DB.Where("user_id = ? AND tenant_id = ? AND status = ?", userID, tenantID, "1").
|
||||
First(&userTenant).Error
|
||||
if err != nil {
|
||||
@ -89,8 +89,8 @@ func (dao *UserTenantDAO) FilterByUserIDAndTenantID(userID, tenantID string) (*m
|
||||
}
|
||||
|
||||
// GetByUserIDAndRole get user tenant relationships by user ID and role
|
||||
func (dao *UserTenantDAO) GetByUserIDAndRole(userID, role string) ([]*model.UserTenant, error) {
|
||||
var relations []*model.UserTenant
|
||||
func (dao *UserTenantDAO) GetByUserIDAndRole(userID, role string) ([]*entity.UserTenant, error) {
|
||||
var relations []*entity.UserTenant
|
||||
err := DB.Where("user_id = ? AND role = ? AND status = ?", userID, role, "1").Find(&relations).Error
|
||||
return relations, err
|
||||
}
|
||||
@ -98,7 +98,7 @@ func (dao *UserTenantDAO) GetByUserIDAndRole(userID, role string) ([]*model.User
|
||||
// GetNumMembers get number of members in a tenant (excluding owner)
|
||||
func (dao *UserTenantDAO) GetNumMembers(tenantID string) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&model.UserTenant{}).
|
||||
err := DB.Model(&entity.UserTenant{}).
|
||||
Where("tenant_id = ? AND status = ? AND role != ?", tenantID, "1", "owner").
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
@ -127,19 +127,19 @@ func (dao *UserTenantDAO) GetTenantsByUserID(userID string) ([]*TenantInfoByUser
|
||||
|
||||
// DeleteByUserID delete user tenant relationships by user ID (hard delete)
|
||||
func (dao *UserTenantDAO) DeleteByUserID(userID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("user_id = ?", userID).Delete(&model.UserTenant{})
|
||||
result := DB.Unscoped().Where("user_id = ?", userID).Delete(&entity.UserTenant{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteByTenantID delete user tenant relationships by tenant ID (hard delete)
|
||||
func (dao *UserTenantDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.UserTenant{})
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.UserTenant{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetByUserIDAll get all user tenant relationships by user ID (including deleted)
|
||||
func (dao *UserTenantDAO) GetByUserIDAll(userID string) ([]*model.UserTenant, error) {
|
||||
var relations []*model.UserTenant
|
||||
func (dao *UserTenantDAO) GetByUserIDAll(userID string) ([]*entity.UserTenant, error) {
|
||||
var relations []*entity.UserTenant
|
||||
err := DB.Where("user_id = ?", userID).Find(&relations).Error
|
||||
return relations, err
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// APIToken API token model
|
||||
type APIToken struct {
|
||||
@ -33,20 +33,20 @@ func (APIToken) TableName() string {
|
||||
|
||||
// API4Conversation API for conversation model
|
||||
type API4Conversation struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Name *string `gorm:"column:name;size:255" json:"name,omitempty"`
|
||||
DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"`
|
||||
UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"`
|
||||
ExpUserID *string `gorm:"column:exp_user_id;size:255;index" json:"exp_user_id,omitempty"`
|
||||
Message JSONMap `gorm:"column:message;type:longtext" json:"message,omitempty"`
|
||||
Reference JSONMap `gorm:"column:reference;type:longtext" json:"reference"`
|
||||
Tokens int64 `gorm:"column:tokens;default:0" json:"tokens"`
|
||||
Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"`
|
||||
DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"`
|
||||
Duration float64 `gorm:"column:duration;default:0;index" json:"duration"`
|
||||
Round int64 `gorm:"column:round;default:0;index" json:"round"`
|
||||
ThumbUp int64 `gorm:"column:thumb_up;default:0;index" json:"thumb_up"`
|
||||
Errors *string `gorm:"column:errors;type:longtext" json:"errors,omitempty"`
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Name *string `gorm:"column:name;size:255" json:"name,omitempty"`
|
||||
DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"`
|
||||
UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"`
|
||||
ExpUserID *string `gorm:"column:exp_user_id;size:255;index" json:"exp_user_id,omitempty"`
|
||||
Message JSONMap `gorm:"column:message;type:longtext" json:"message,omitempty"`
|
||||
Reference JSONMap `gorm:"column:reference;type:longtext" json:"reference"`
|
||||
Tokens int64 `gorm:"column:tokens;default:0" json:"tokens"`
|
||||
Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"`
|
||||
DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"`
|
||||
Duration float64 `gorm:"column:duration;default:0;index" json:"duration"`
|
||||
Round int64 `gorm:"column:round;default:0;index" json:"round"`
|
||||
ThumbUp int64 `gorm:"column:thumb_up;default:0;index" json:"thumb_up"`
|
||||
Errors *string `gorm:"column:errors;type:longtext" json:"errors,omitempty"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// UserCanvas user canvas model
|
||||
type UserCanvas struct {
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "time"
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "time"
|
||||
|
||||
@ -14,17 +14,17 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// EvaluationDataset evaluation dataset model
|
||||
// Note: Python defines custom create_time/update_time (not null) instead of using BaseModel's
|
||||
type EvaluationDataset struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"`
|
||||
Name string `gorm:"column:name;size:255;not null;index" json:"name"`
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"`
|
||||
Name string `gorm:"column:name;size:255;not null;index" json:"name"`
|
||||
Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"`
|
||||
KbIDs JSONMap `gorm:"column:kb_ids;type:longtext;not null" json:"kb_ids"`
|
||||
CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"`
|
||||
CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"`
|
||||
// Custom time fields (not null) to match Python
|
||||
CreateTime int64 `gorm:"column:create_time;not null;index" json:"create_time"`
|
||||
UpdateTime int64 `gorm:"column:update_time;not null" json:"update_time"`
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// File file model
|
||||
type File struct {
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "time"
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// LLMFactories LLM factory model
|
||||
type LLMFactories struct {
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// MCPServer MCP server model
|
||||
type MCPServer struct {
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// Memory memory model
|
||||
type Memory struct {
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "time"
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// Search search model
|
||||
type Search struct {
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "time"
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "time"
|
||||
|
||||
45
internal/entity/tenant.go
Normal file
45
internal/entity/tenant.go
Normal file
@ -0,0 +1,45 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package entity
|
||||
|
||||
// Tenant tenant model
|
||||
type Tenant struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Name *string `gorm:"column:name;size:100;index" json:"name,omitempty"`
|
||||
PublicKey *string `gorm:"column:public_key;size:255;index" json:"public_key,omitempty"`
|
||||
LLMID string `gorm:"column:llm_id;size:128;not null;index" json:"llm_id"`
|
||||
TenantLLMID *int64 `gorm:"column:tenant_llm_id;index" json:"tenant_llm_id,omitempty"`
|
||||
EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"`
|
||||
TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"`
|
||||
ASRID string `gorm:"column:asr_id;size:128;not null;index" json:"asr_id"`
|
||||
TenantASRID *int64 `gorm:"column:tenant_asr_id;index" json:"tenant_asr_id,omitempty"`
|
||||
Img2TxtID string `gorm:"column:img2txt_id;size:128;not null;index" json:"img2txt_id"`
|
||||
TenantImg2TxtID *int64 `gorm:"column:tenant_img2txt_id;index" json:"tenant_img2txt_id,omitempty"`
|
||||
RerankID string `gorm:"column:rerank_id;size:128;not null;index" json:"rerank_id"`
|
||||
TenantRerankID *int64 `gorm:"column:tenant_rerank_id;index" json:"tenant_rerank_id,omitempty"`
|
||||
TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"`
|
||||
TenantTTSID *int64 `gorm:"column:tenant_tts_id;index" json:"tenant_tts_id,omitempty"`
|
||||
ParserIDs string `gorm:"column:parser_ids;size:256;not null;index" json:"parser_ids"`
|
||||
Credit int64 `gorm:"column:credit;default:512;index" json:"credit"`
|
||||
Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
// TableName specify table name
|
||||
func (Tenant) TableName() string {
|
||||
return "tenant"
|
||||
}
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// TenantLLM tenant LLM model
|
||||
// Python uses PrimaryKeyField (auto-increment ID) with unique index on (tenant_id, llm_factory, llm_name)
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// ModelType represents the type of model
|
||||
type ModelType string
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
import "time"
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
package entity
|
||||
|
||||
// UserTenant user tenant relationship model
|
||||
type UserTenant struct {
|
||||
@ -19,7 +19,8 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -45,7 +46,7 @@ func (h *SystemHandler) ListTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
userModel, ok := user.(*model.User)
|
||||
userModel, ok := user.(*entity.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
@ -105,7 +106,7 @@ func (h *SystemHandler) CreateToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
userModel, ok := user.(*model.User)
|
||||
userModel, ok := user.(*entity.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
@ -175,7 +176,7 @@ func (h *SystemHandler) DeleteToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
userModel, ok := user.(*model.User)
|
||||
userModel, ok := user.(*entity.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
|
||||
@ -18,18 +18,18 @@ package handler
|
||||
|
||||
import (
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetUser(c *gin.Context) (*model.User, common.ErrorCode, string) {
|
||||
func GetUser(c *gin.Context) (*entity.User, common.ErrorCode, string) {
|
||||
userAny, exist := c.Get("user")
|
||||
if !exist {
|
||||
return nil, common.CodeUnauthorized, "User not found"
|
||||
}
|
||||
|
||||
user, ok := userAny.(*model.User)
|
||||
user, ok := userAny.(*entity.User)
|
||||
if !ok {
|
||||
return nil, common.CodeUnauthorized, "User not found"
|
||||
}
|
||||
|
||||
@ -1,45 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package model
|
||||
|
||||
// Tenant tenant model
|
||||
type Tenant struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Name *string `gorm:"column:name;size:100;index" json:"name,omitempty"`
|
||||
PublicKey *string `gorm:"column:public_key;size:255;index" json:"public_key,omitempty"`
|
||||
LLMID string `gorm:"column:llm_id;size:128;not null;index" json:"llm_id"`
|
||||
TenantLLMID *int64 `gorm:"column:tenant_llm_id;index" json:"tenant_llm_id,omitempty"`
|
||||
EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"`
|
||||
TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"`
|
||||
ASRID string `gorm:"column:asr_id;size:128;not null;index" json:"asr_id"`
|
||||
TenantASRID *int64 `gorm:"column:tenant_asr_id;index" json:"tenant_asr_id,omitempty"`
|
||||
Img2TxtID string `gorm:"column:img2txt_id;size:128;not null;index" json:"img2txt_id"`
|
||||
TenantImg2TxtID *int64 `gorm:"column:tenant_img2txt_id;index" json:"tenant_img2txt_id,omitempty"`
|
||||
RerankID string `gorm:"column:rerank_id;size:128;not null;index" json:"rerank_id"`
|
||||
TenantRerankID *int64 `gorm:"column:tenant_rerank_id;index" json:"tenant_rerank_id,omitempty"`
|
||||
TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"`
|
||||
TenantTTSID *int64 `gorm:"column:tenant_tts_id;index" json:"tenant_tts_id,omitempty"`
|
||||
ParserIDs string `gorm:"column:parser_ids;size:256;not null;index" json:"parser_ids"`
|
||||
Credit int64 `gorm:"column:credit;default:512;index" json:"credit"`
|
||||
Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
// TableName specify table name
|
||||
func (Tenant) TableName() string {
|
||||
return "tenant"
|
||||
}
|
||||
@ -456,7 +456,7 @@ func FromEnvironments() error {
|
||||
}
|
||||
|
||||
minioPort := strings.ToLower(os.Getenv("MINIO_PORT"))
|
||||
println(fmt.Sprintf("MINIO ip and port from env: %s:%s", minioIP, minioPort))
|
||||
// println(fmt.Sprintf("MINIO ip and port from env: %s:%s", minioIP, minioPort))
|
||||
if minioPort != "" {
|
||||
ip, _, err := net.SplitHostPort(globalConfig.StorageEngine.Minio.Host)
|
||||
if err != nil {
|
||||
|
||||
@ -18,7 +18,7 @@ package service
|
||||
|
||||
import (
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/utility"
|
||||
"time"
|
||||
)
|
||||
@ -76,7 +76,7 @@ func (s *SystemService) CreateAPIToken(tenantID string, req *CreateAPITokenReque
|
||||
// beta: generate_confirmation_token().replace("ragflow-", "")[:32]
|
||||
betaAPIKey := utility.GenerateBetaAPIToken(APIToken)
|
||||
|
||||
APITokenData := &model.APIToken{
|
||||
APITokenData := &entity.APIToken{
|
||||
TenantID: tenantID,
|
||||
Token: APIToken,
|
||||
Beta: &betaAPIKey,
|
||||
|
||||
@ -19,6 +19,7 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@ -26,7 +27,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
// ChatService chat service
|
||||
@ -49,7 +49,7 @@ func NewChatService() *ChatService {
|
||||
|
||||
// ChatWithKBNames chat with knowledge base names
|
||||
type ChatWithKBNames struct {
|
||||
*model.Chat
|
||||
*entity.Chat
|
||||
KBNames []string `json:"kb_names"`
|
||||
}
|
||||
|
||||
@ -109,7 +109,7 @@ type ListChatsNextResponse struct {
|
||||
|
||||
// ListChatsNext list chats with advanced filtering (equivalent to list_dialogs_next)
|
||||
func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListChatsNextResponse, error) {
|
||||
var chats []*model.Chat
|
||||
var chats []*entity.Chat
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
@ -142,7 +142,7 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
|
||||
}
|
||||
chats = chats[start:end]
|
||||
} else {
|
||||
chats = []*model.Chat{}
|
||||
chats = []*entity.Chat{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -164,7 +164,7 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
|
||||
}
|
||||
|
||||
// getKBNames gets knowledge base names by IDs
|
||||
func (s *ChatService) getKBNames(kbIDs model.JSONSlice) []string {
|
||||
func (s *ChatService) getKBNames(kbIDs entity.JSONSlice) []string {
|
||||
var names []string
|
||||
for _, kbID := range kbIDs {
|
||||
kbIDStr, ok := kbID.(string)
|
||||
@ -225,7 +225,7 @@ type SetDialogRequest struct {
|
||||
|
||||
// SetDialogResponse set chat response
|
||||
type SetDialogResponse struct {
|
||||
*model.Chat
|
||||
*entity.Chat
|
||||
KBNames []string `json:"kb_names"`
|
||||
}
|
||||
|
||||
@ -393,7 +393,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
|
||||
}
|
||||
|
||||
// Convert prompt config to JSONMap with all fields
|
||||
promptConfigMap := model.JSONMap{
|
||||
promptConfigMap := entity.JSONMap{
|
||||
"system": promptConfig.System,
|
||||
"prologue": promptConfig.Prologue,
|
||||
"empty_response": promptConfig.EmptyResponse,
|
||||
@ -420,7 +420,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
|
||||
}
|
||||
|
||||
// Convert kbIDs to JSONSlice
|
||||
kbIDsJSON := make(model.JSONSlice, len(kbIDs))
|
||||
kbIDsJSON := make(entity.JSONSlice, len(kbIDs))
|
||||
for i, id := range kbIDs {
|
||||
kbIDsJSON[i] = id
|
||||
}
|
||||
@ -441,7 +441,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
|
||||
language := "English"
|
||||
|
||||
// Create new chat
|
||||
chat := &model.Chat{
|
||||
chat := &entity.Chat{
|
||||
ID: newID,
|
||||
TenantID: tenantID,
|
||||
Name: &name,
|
||||
@ -451,7 +451,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
|
||||
LLMID: llmID,
|
||||
LLMSetting: llmSetting,
|
||||
PromptConfig: promptConfigMap,
|
||||
MetaDataFilter: (*model.JSONMap)(&metaDataFilter),
|
||||
MetaDataFilter: (*entity.JSONMap)(&metaDataFilter),
|
||||
TopN: topN,
|
||||
TopK: topK,
|
||||
RerankID: rerankID,
|
||||
@ -558,7 +558,7 @@ func (s *ChatService) splitModelNameAndFactory(embdID string) string {
|
||||
}
|
||||
|
||||
// getEmbdIDs extracts embedding IDs from knowledge bases
|
||||
func getEmbdIDs(kbs []*model.Knowledgebase) []string {
|
||||
func getEmbdIDs(kbs []*entity.Knowledgebase) []string {
|
||||
ids := make([]string, len(kbs))
|
||||
for i, kb := range kbs {
|
||||
ids[i] = kb.EmbdID
|
||||
|
||||
@ -26,7 +26,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// ChatSessionService chat session (conversation) service
|
||||
@ -55,7 +55,7 @@ type SetChatSessionRequest struct {
|
||||
|
||||
// SetChatSessionResponse set chat session response
|
||||
type SetChatSessionResponse struct {
|
||||
*model.ChatSession
|
||||
*entity.ChatSession
|
||||
}
|
||||
|
||||
// SetChatSession create or update a chat session
|
||||
@ -131,7 +131,7 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe
|
||||
referenceJSON, _ := json.Marshal([]interface{}{})
|
||||
|
||||
// Create chat session
|
||||
session := &model.ChatSession{
|
||||
session := &entity.ChatSession{
|
||||
ID: newID,
|
||||
DialogID: req.DialogID,
|
||||
Name: &name,
|
||||
@ -212,7 +212,7 @@ type ListChatSessionsRequest struct {
|
||||
|
||||
// ListChatSessionsResponse list chat sessions response
|
||||
type ListChatSessionsResponse struct {
|
||||
Sessions []*model.ChatSession
|
||||
Sessions []*entity.ChatSession
|
||||
}
|
||||
|
||||
// ListChatSessions lists chat sessions for a dialog
|
||||
@ -397,7 +397,7 @@ func (s *ChatSessionService) CompletionStream(userID string, conversationID stri
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (s *ChatSessionService) buildSessionMessages(session *model.ChatSession, messages []map[string]interface{}) []map[string]interface{} {
|
||||
func (s *ChatSessionService) buildSessionMessages(session *entity.ChatSession, messages []map[string]interface{}) []map[string]interface{} {
|
||||
// Deep copy messages to session
|
||||
sessionMessages := make([]map[string]interface{}, len(messages))
|
||||
for i, msg := range messages {
|
||||
@ -409,7 +409,7 @@ func (s *ChatSessionService) buildSessionMessages(session *model.ChatSession, me
|
||||
return sessionMessages
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) initializeReference(session *model.ChatSession) []interface{} {
|
||||
func (s *ChatSessionService) initializeReference(session *entity.ChatSession) []interface{} {
|
||||
var reference []interface{}
|
||||
if len(session.Reference) > 0 {
|
||||
json.Unmarshal(session.Reference, &reference)
|
||||
@ -433,7 +433,7 @@ func (s *ChatSessionService) checkTenantLLMAPIKey(tenantID, modelName string) (b
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) performChat(dialog *model.Chat, messages []map[string]interface{}, config map[string]interface{}) (string, error) {
|
||||
func (s *ChatSessionService) performChat(dialog *entity.Chat, messages []map[string]interface{}, config map[string]interface{}) (string, error) {
|
||||
// Get system prompt from dialog
|
||||
systemPrompt := ""
|
||||
if dialog.PromptConfig != nil {
|
||||
@ -456,7 +456,7 @@ func (s *ChatSessionService) performChat(dialog *model.Chat, messages []map[stri
|
||||
}
|
||||
|
||||
// Use ModelBundle to perform chat
|
||||
bundle, err := NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID)
|
||||
bundle, err := NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -476,7 +476,7 @@ func (s *ChatSessionService) performChat(dialog *model.Chat, messages []map[stri
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) performChatStream(dialog *model.Chat, messages []map[string]interface{}, config map[string]interface{}) (<-chan string, error) {
|
||||
func (s *ChatSessionService) performChatStream(dialog *entity.Chat, messages []map[string]interface{}, config map[string]interface{}) (<-chan string, error) {
|
||||
// Get system prompt from dialog
|
||||
systemPrompt := ""
|
||||
if dialog.PromptConfig != nil {
|
||||
@ -499,7 +499,7 @@ func (s *ChatSessionService) performChatStream(dialog *model.Chat, messages []ma
|
||||
}
|
||||
|
||||
// Use ModelBundle to perform streaming chat
|
||||
bundle, err := NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID)
|
||||
bundle, err := NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -516,7 +516,7 @@ func (s *ChatSessionService) performChatStream(dialog *model.Chat, messages []ma
|
||||
}
|
||||
|
||||
// Get chat model and call ChatStreamly
|
||||
chatModel, ok := bundle.GetModel().(model.ChatModel)
|
||||
chatModel, ok := bundle.GetModel().(entity.ChatModel)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model is not a chat model")
|
||||
}
|
||||
@ -524,7 +524,7 @@ func (s *ChatSessionService) performChatStream(dialog *model.Chat, messages []ma
|
||||
return chatModel.ChatStreamly(systemPrompt, history, genConf)
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) structureAnswer(session *model.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} {
|
||||
func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"answer": answer,
|
||||
"reference": reference,
|
||||
@ -533,7 +533,7 @@ func (s *ChatSessionService) structureAnswer(session *model.ChatSession, answer
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) updateSessionMessages(session *model.ChatSession, messages []map[string]interface{}, reference []interface{}) {
|
||||
func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, messages []map[string]interface{}, reference []interface{}) {
|
||||
// Update session with new messages and reference
|
||||
messagesJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": messages,
|
||||
@ -550,7 +550,7 @@ func (s *ChatSessionService) updateSessionMessages(session *model.ChatSession, m
|
||||
}
|
||||
|
||||
// asyncChat performs chat with RAG support (non-streaming)
|
||||
func (s *ChatSessionService) asyncChat(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
func (s *ChatSessionService) asyncChat(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
// Check if we need RAG (knowledge base or tavily)
|
||||
hasKB := len(dialog.KBIDs) > 0
|
||||
hasTavily := false
|
||||
@ -579,7 +579,7 @@ func (s *ChatSessionService) asyncChat(dialog *model.Chat, session *model.ChatSe
|
||||
}
|
||||
|
||||
// asyncChatStream performs streaming chat with RAG support
|
||||
func (s *ChatSessionService) asyncChatStream(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) {
|
||||
func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) {
|
||||
resultChan := make(chan map[string]interface{})
|
||||
|
||||
go func() {
|
||||
@ -609,7 +609,7 @@ func (s *ChatSessionService) asyncChatStream(dialog *model.Chat, session *model.
|
||||
}
|
||||
|
||||
// asyncChatSolo performs simple chat without RAG (non-streaming)
|
||||
func (s *ChatSessionService) asyncChatSolo(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
// Get system prompt
|
||||
systemPrompt := s.buildSystemPrompt(dialog)
|
||||
|
||||
@ -626,9 +626,9 @@ func (s *ChatSessionService) asyncChatSolo(dialog *model.Chat, session *model.Ch
|
||||
var bundle *ModelBundle
|
||||
var err error
|
||||
if llmType == "image2text" {
|
||||
bundle, err = NewModelBundle(dialog.TenantID, model.ModelTypeImage2Text, dialog.LLMID)
|
||||
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeImage2Text, dialog.LLMID)
|
||||
} else {
|
||||
bundle, err = NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID)
|
||||
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -654,7 +654,7 @@ func (s *ChatSessionService) asyncChatSolo(dialog *model.Chat, session *model.Ch
|
||||
}
|
||||
|
||||
// asyncChatSoloStream performs simple streaming chat without RAG
|
||||
func (s *ChatSessionService) asyncChatSoloStream(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, resultChan chan<- map[string]interface{}) {
|
||||
func (s *ChatSessionService) asyncChatSoloStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, resultChan chan<- map[string]interface{}) {
|
||||
// Get system prompt
|
||||
systemPrompt := s.buildSystemPrompt(dialog)
|
||||
|
||||
@ -671,9 +671,9 @@ func (s *ChatSessionService) asyncChatSoloStream(dialog *model.Chat, session *mo
|
||||
var bundle *ModelBundle
|
||||
var err error
|
||||
if llmType == "image2text" {
|
||||
bundle, err = NewModelBundle(dialog.TenantID, model.ModelTypeImage2Text, dialog.LLMID)
|
||||
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeImage2Text, dialog.LLMID)
|
||||
} else {
|
||||
bundle, err = NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID)
|
||||
bundle, err = NewModelBundle(dialog.TenantID, entity.ModelTypeChat, dialog.LLMID)
|
||||
}
|
||||
if err != nil {
|
||||
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
|
||||
@ -684,7 +684,7 @@ func (s *ChatSessionService) asyncChatSoloStream(dialog *model.Chat, session *mo
|
||||
history := s.convertToHistory(processedMessages)
|
||||
|
||||
// Get chat model
|
||||
chatModel, ok := bundle.GetModel().(model.ChatModel)
|
||||
chatModel, ok := bundle.GetModel().(entity.ChatModel)
|
||||
if !ok {
|
||||
resultChan <- s.structureAnswer(session, "**ERROR**: model is not a chat model", messageID, session.ID, reference)
|
||||
return
|
||||
@ -709,7 +709,7 @@ func (s *ChatSessionService) asyncChatSoloStream(dialog *model.Chat, session *mo
|
||||
}
|
||||
|
||||
// buildSystemPrompt builds the system prompt from dialog configuration
|
||||
func (s *ChatSessionService) buildSystemPrompt(dialog *model.Chat) string {
|
||||
func (s *ChatSessionService) buildSystemPrompt(dialog *entity.Chat) string {
|
||||
if dialog.PromptConfig == nil {
|
||||
return ""
|
||||
}
|
||||
@ -719,7 +719,7 @@ func (s *ChatSessionService) buildSystemPrompt(dialog *model.Chat) string {
|
||||
}
|
||||
|
||||
// processMessages processes messages and handles attachments
|
||||
func (s *ChatSessionService) processMessages(messages []map[string]interface{}, dialog *model.Chat) []map[string]interface{} {
|
||||
func (s *ChatSessionService) processMessages(messages []map[string]interface{}, dialog *entity.Chat) []map[string]interface{} {
|
||||
// Process each message
|
||||
processed := make([]map[string]interface{}, len(messages))
|
||||
for i, msg := range messages {
|
||||
@ -762,7 +762,7 @@ func (s *ChatSessionService) convertToHistory(messages []map[string]interface{})
|
||||
}
|
||||
|
||||
// buildGenConf builds generation config from dialog and request
|
||||
func (s *ChatSessionService) buildGenConf(dialog *model.Chat, config map[string]interface{}) map[string]interface{} {
|
||||
func (s *ChatSessionService) buildGenConf(dialog *entity.Chat, config map[string]interface{}) map[string]interface{} {
|
||||
genConf := make(map[string]interface{})
|
||||
|
||||
// Start with dialog's LLM setting
|
||||
@ -799,7 +799,7 @@ func (s *ChatSessionService) removeReasoningContent(answer string) string {
|
||||
}
|
||||
|
||||
// structureAnswerWithConv structures the answer with conversation update (like Python's structure_answer)
|
||||
func (s *ChatSessionService) structureAnswerWithConv(session *model.ChatSession, ans map[string]interface{}, messageID, conversationID string, reference []interface{}) map[string]interface{} {
|
||||
func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession, ans map[string]interface{}, messageID, conversationID string, reference []interface{}) map[string]interface{} {
|
||||
// Extract reference from answer
|
||||
ref, _ := ans["reference"].(map[string]interface{})
|
||||
if ref == nil {
|
||||
|
||||
@ -19,6 +19,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/server"
|
||||
"strings"
|
||||
|
||||
@ -27,7 +28,7 @@ import (
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/logger"
|
||||
"ragflow/internal/model"
|
||||
|
||||
"ragflow/internal/service/nlp"
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
@ -76,10 +77,10 @@ type RetrievalTestRequest struct {
|
||||
|
||||
// RetrievalTestResponse retrieval test response
|
||||
type RetrievalTestResponse struct {
|
||||
Chunks []map[string]interface{} `json:"chunks"`
|
||||
DocAggs []map[string]interface{} `json:"doc_aggs"`
|
||||
Chunks []map[string]interface{} `json:"chunks"`
|
||||
DocAggs []map[string]interface{} `json:"doc_aggs"`
|
||||
Labels *[]map[string]interface{} `json:"labels"`
|
||||
Total int64 `json:"total,omitempty"`
|
||||
Total int64 `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// RetrievalTest performs retrieval test
|
||||
@ -130,7 +131,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (
|
||||
|
||||
// Check permission for each kb_id
|
||||
var tenantIDs []string
|
||||
var kbRecords []*model.Knowledgebase
|
||||
var kbRecords []*entity.Knowledgebase
|
||||
|
||||
for _, kbID := range kbIDs {
|
||||
found := false
|
||||
@ -651,11 +652,11 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
|
||||
|
||||
// ListChunksRequest request for listing chunks
|
||||
type ListChunksRequest struct {
|
||||
DocID string `json:"doc_id" binding:"required"`
|
||||
Page *int `json:"page,omitempty"`
|
||||
Size *int `json:"size,omitempty"`
|
||||
Keywords string `json:"keywords,omitempty"`
|
||||
AvailableInt *int `json:"available_int,omitempty"`
|
||||
DocID string `json:"doc_id" binding:"required"`
|
||||
Page *int `json:"page,omitempty"`
|
||||
Size *int `json:"size,omitempty"`
|
||||
Keywords string `json:"keywords,omitempty"`
|
||||
AvailableInt *int `json:"available_int,omitempty"`
|
||||
}
|
||||
|
||||
// ListChunksResponse response for listing chunks
|
||||
@ -772,7 +773,7 @@ func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksR
|
||||
result["image_id"] = ""
|
||||
}
|
||||
case "position_int":
|
||||
result["positions"] = v
|
||||
result["positions"] = v
|
||||
case "id":
|
||||
result["chunk_id"] = v
|
||||
case "content":
|
||||
@ -817,32 +818,32 @@ func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksR
|
||||
// Build document info (matching Python doc.to_dict())
|
||||
timeFormat := "2006-01-02T15:04:05"
|
||||
docInfo := map[string]interface{}{
|
||||
"id": doc.ID,
|
||||
"thumbnail": doc.Thumbnail,
|
||||
"kb_id": doc.KbID,
|
||||
"parser_id": doc.ParserID,
|
||||
"pipeline_id": doc.PipelineID,
|
||||
"parser_config": doc.ParserConfig,
|
||||
"source_type": doc.SourceType,
|
||||
"type": doc.Type,
|
||||
"created_by": doc.CreatedBy,
|
||||
"name": doc.Name,
|
||||
"location": doc.Location,
|
||||
"size": doc.Size,
|
||||
"token_num": doc.TokenNum,
|
||||
"chunk_num": doc.ChunkNum,
|
||||
"progress": utility.JSONFloat64(doc.Progress),
|
||||
"progress_msg": doc.ProgressMsg,
|
||||
"process_begin_at": utility.FormatTimeToString(doc.ProcessBeginAt, timeFormat),
|
||||
"process_duration": doc.ProcessDuration,
|
||||
"content_hash": doc.ContentHash,
|
||||
"suffix": doc.Suffix,
|
||||
"run": doc.Run,
|
||||
"status": doc.Status,
|
||||
"create_time": doc.CreateTime,
|
||||
"create_date": utility.FormatTimeToString(doc.CreateDate, timeFormat),
|
||||
"update_time": doc.UpdateTime,
|
||||
"update_date": utility.FormatTimeToString(doc.UpdateDate, timeFormat),
|
||||
"id": doc.ID,
|
||||
"thumbnail": doc.Thumbnail,
|
||||
"kb_id": doc.KbID,
|
||||
"parser_id": doc.ParserID,
|
||||
"pipeline_id": doc.PipelineID,
|
||||
"parser_config": doc.ParserConfig,
|
||||
"source_type": doc.SourceType,
|
||||
"type": doc.Type,
|
||||
"created_by": doc.CreatedBy,
|
||||
"name": doc.Name,
|
||||
"location": doc.Location,
|
||||
"size": doc.Size,
|
||||
"token_num": doc.TokenNum,
|
||||
"chunk_num": doc.ChunkNum,
|
||||
"progress": utility.JSONFloat64(doc.Progress),
|
||||
"progress_msg": doc.ProgressMsg,
|
||||
"process_begin_at": utility.FormatTimeToString(doc.ProcessBeginAt, timeFormat),
|
||||
"process_duration": doc.ProcessDuration,
|
||||
"content_hash": doc.ContentHash,
|
||||
"suffix": doc.Suffix,
|
||||
"run": doc.Run,
|
||||
"status": doc.Status,
|
||||
"create_time": doc.CreateTime,
|
||||
"create_date": utility.FormatTimeToString(doc.CreateDate, timeFormat),
|
||||
"update_time": doc.UpdateTime,
|
||||
"update_date": utility.FormatTimeToString(doc.UpdateDate, timeFormat),
|
||||
}
|
||||
|
||||
return &ListChunksResponse{
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -28,7 +29,6 @@ import (
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -196,8 +196,8 @@ func (s *DatasetsService) CreateDataset(req *CreateDatasetRequest, tenantID stri
|
||||
if name == "" {
|
||||
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
|
||||
}
|
||||
if len(name) > model.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit)
|
||||
if len(name) > entity.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), entity.DatasetNameLimit)
|
||||
}
|
||||
|
||||
tenant, err := s.tenantDAO.GetByID(tenantID)
|
||||
@ -270,8 +270,8 @@ func (s *DatasetsService) CreateDataset(req *CreateDatasetRequest, tenantID stri
|
||||
if nameValue == "" {
|
||||
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
|
||||
}
|
||||
if len(nameValue) > model.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(nameValue), model.DatasetNameLimit)
|
||||
if len(nameValue) > entity.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(nameValue), entity.DatasetNameLimit)
|
||||
}
|
||||
name = nameValue
|
||||
case "description":
|
||||
@ -394,8 +394,8 @@ func (s *DatasetsService) CreateDataset(req *CreateDatasetRequest, tenantID stri
|
||||
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
status := string(model.StatusValid)
|
||||
kb := &model.Knowledgebase{
|
||||
status := string(entity.StatusValid)
|
||||
kb := &entity.Knowledgebase{
|
||||
ID: kbID,
|
||||
Name: s.kbDAO.DuplicateName(name, tenantID),
|
||||
TenantID: tenantID,
|
||||
@ -466,7 +466,7 @@ func (s *DatasetsService) DeleteDatasets(ids []string, deleteAll bool, tenantID
|
||||
}
|
||||
}
|
||||
|
||||
kbs := make([]*model.Knowledgebase, 0, len(normalizedIDs))
|
||||
kbs := make([]*entity.Knowledgebase, 0, len(normalizedIDs))
|
||||
unauthorizedIDs := make([]string, 0)
|
||||
for _, id := range normalizedIDs {
|
||||
kb, err := s.kbDAO.GetByIDAndTenantID(id, tenantID)
|
||||
@ -514,9 +514,9 @@ func (s *DatasetsService) DeleteDatasets(ids []string, deleteAll bool, tenantID
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *DatasetsService) deleteDataset(tenantID string, kb *model.Knowledgebase) error {
|
||||
func (s *DatasetsService) deleteDataset(tenantID string, kb *entity.Knowledgebase) error {
|
||||
return dao.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var documents []model.Document
|
||||
var documents []entity.Document
|
||||
if err := tx.Where("kb_id = ?", kb.ID).Find(&documents).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
@ -527,7 +527,7 @@ func (s *DatasetsService) deleteDataset(tenantID string, kb *model.Knowledgebase
|
||||
}
|
||||
|
||||
if len(docIDs) > 0 {
|
||||
var mappings []model.File2Document
|
||||
var mappings []entity.File2Document
|
||||
if err := tx.Where("document_id IN ?", docIDs).Find(&mappings).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
@ -545,29 +545,29 @@ func (s *DatasetsService) deleteDataset(tenantID string, kb *model.Knowledgebase
|
||||
fileIDs = append(fileIDs, *mapping.FileID)
|
||||
}
|
||||
|
||||
if err := tx.Where("doc_id IN ?", docIDs).Delete(&model.Task{}).Error; err != nil {
|
||||
if err := tx.Where("doc_id IN ?", docIDs).Delete(&entity.Task{}).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
if err := tx.Where("document_id IN ?", docIDs).Delete(&model.File2Document{}).Error; err != nil {
|
||||
if err := tx.Where("document_id IN ?", docIDs).Delete(&entity.File2Document{}).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
if len(fileIDs) > 0 {
|
||||
if err := tx.Unscoped().Where("id IN ?", fileIDs).Delete(&model.File{}).Error; err != nil {
|
||||
if err := tx.Unscoped().Where("id IN ?", fileIDs).Delete(&entity.File{}).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
}
|
||||
if err := tx.Where("id IN ?", docIDs).Delete(&model.Document{}).Error; err != nil {
|
||||
if err := tx.Where("id IN ?", docIDs).Delete(&entity.Document{}).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Unscoped().
|
||||
Where("source_type = ? AND type = ? AND name = ? AND tenant_id = ?", string(model.FileSourceKnowledgebase), "folder", kb.Name, tenantID).
|
||||
Delete(&model.File{}).Error; err != nil {
|
||||
Where("source_type = ? AND type = ? AND name = ? AND tenant_id = ?", string(entity.FileSourceKnowledgebase), "folder", kb.Name, tenantID).
|
||||
Delete(&entity.File{}).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
|
||||
if err := tx.Where("id = ?", kb.ID).Delete(&model.Knowledgebase{}).Error; err != nil {
|
||||
if err := tx.Where("id = ?", kb.ID).Delete(&entity.Knowledgebase{}).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
|
||||
@ -682,7 +682,7 @@ func (s *DatasetsService) verifyEmbeddingAvailability(embdID string, tenantID st
|
||||
}
|
||||
if *tenantLLM.LLMName == modelName &&
|
||||
tenantLLM.LLMFactory == provider &&
|
||||
*tenantLLM.ModelType == string(model.ModelTypeEmbedding) {
|
||||
*tenantLLM.ModelType == string(entity.ModelTypeEmbedding) {
|
||||
return true, ""
|
||||
}
|
||||
}
|
||||
@ -722,7 +722,7 @@ func applyAutoMetadataConfig(parserConfig map[string]interface{}, config *AutoMe
|
||||
return parserConfig
|
||||
}
|
||||
|
||||
func datasetListItemToMap(kb *model.KnowledgebaseListItem) map[string]interface{} {
|
||||
func datasetListItemToMap(kb *entity.KnowledgebaseListItem) map[string]interface{} {
|
||||
item := map[string]interface{}{
|
||||
"id": kb.ID,
|
||||
"name": kb.Name,
|
||||
@ -755,7 +755,7 @@ func datasetListItemToMap(kb *model.KnowledgebaseListItem) map[string]interface{
|
||||
return item
|
||||
}
|
||||
|
||||
func datasetToMap(kb *model.Knowledgebase) map[string]interface{} {
|
||||
func datasetToMap(kb *entity.Knowledgebase) map[string]interface{} {
|
||||
item := map[string]interface{}{
|
||||
"id": kb.ID,
|
||||
"tenant_id": kb.TenantID,
|
||||
|
||||
@ -19,13 +19,14 @@ package service
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"regexp"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/model"
|
||||
|
||||
"ragflow/internal/server"
|
||||
)
|
||||
|
||||
@ -95,8 +96,8 @@ type DocumentResponse struct {
|
||||
}
|
||||
|
||||
// CreateDocument create document
|
||||
func (s *DocumentService) CreateDocument(req *CreateDocumentRequest) (*model.Document, error) {
|
||||
document := &model.Document{
|
||||
func (s *DocumentService) CreateDocument(req *CreateDocumentRequest) (*entity.Document, error) {
|
||||
document := &entity.Document{
|
||||
Name: &req.Name,
|
||||
KbID: req.KbID,
|
||||
ParserID: req.ParserID,
|
||||
@ -207,7 +208,7 @@ func (s *DocumentService) GetDocumentsByAuthorID(authorID, page, pageSize int) (
|
||||
}
|
||||
|
||||
// toResponse convert model.Document to DocumentResponse
|
||||
func (s *DocumentService) toResponse(doc *model.Document) *DocumentResponse {
|
||||
func (s *DocumentService) toResponse(doc *entity.Document) *DocumentResponse {
|
||||
createdAt := ""
|
||||
if doc.CreateTime != nil {
|
||||
// Check if timestamp is in milliseconds (13 digits) or seconds (10 digits)
|
||||
@ -405,7 +406,7 @@ func (s *DocumentService) GetMetadataByKBs(kbIDs []string) (map[string]interface
|
||||
|
||||
// valueInfo holds count and order of first appearance
|
||||
type valueInfo struct {
|
||||
count int
|
||||
count int
|
||||
firstOrder int
|
||||
}
|
||||
|
||||
@ -617,7 +618,7 @@ func aggregateMetadata(chunks []map[string]interface{}) map[string]interface{} {
|
||||
}
|
||||
|
||||
result[k] = map[string]interface{}{
|
||||
"type": valueType,
|
||||
"type": valueType,
|
||||
"values": outputValues,
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,12 +18,12 @@ package service
|
||||
|
||||
import (
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// FileService file service
|
||||
type FileService struct {
|
||||
fileDAO *dao.FileDAO
|
||||
fileDAO *dao.FileDAO
|
||||
file2DocumentDAO *dao.File2DocumentDAO
|
||||
}
|
||||
|
||||
@ -37,7 +37,7 @@ func NewFileService() *FileService {
|
||||
|
||||
// FileInfo file info with additional fields
|
||||
type FileInfo struct {
|
||||
*model.File
|
||||
*entity.File
|
||||
Size int64 `json:"size"`
|
||||
KbsInfo []map[string]interface{} `json:"kbs_info"`
|
||||
HasChildFolder bool `json:"has_child_folder,omitempty"`
|
||||
@ -45,9 +45,9 @@ type FileInfo struct {
|
||||
|
||||
// ListFilesResponse list files response
|
||||
type ListFilesResponse struct {
|
||||
Total int64 `json:"total"`
|
||||
Files []map[string]interface{} `json:"files"`
|
||||
ParentFolder map[string]interface{} `json:"parent_folder"`
|
||||
Total int64 `json:"total"`
|
||||
Files []map[string]interface{} `json:"files"`
|
||||
ParentFolder map[string]interface{} `json:"parent_folder"`
|
||||
}
|
||||
|
||||
// GetRootFolder gets or creates root folder for tenant
|
||||
@ -91,7 +91,7 @@ func (s *FileService) ListFiles(tenantID, pfID string, page, pageSize int, order
|
||||
fileResponses := make([]map[string]interface{}, len(files))
|
||||
for i, file := range files {
|
||||
fileInfo := s.toFileInfo(file)
|
||||
|
||||
|
||||
// If folder, calculate size and check for child folders
|
||||
if file.Type == "folder" {
|
||||
folderSize, err := s.fileDAO.GetFolderSize(file.ID)
|
||||
@ -111,7 +111,7 @@ func (s *FileService) ListFiles(tenantID, pfID string, page, pageSize int, order
|
||||
}
|
||||
fileInfo.KbsInfo = kbsInfo
|
||||
}
|
||||
|
||||
|
||||
fileResponses[i] = s.fileInfoToResponse(fileInfo)
|
||||
}
|
||||
|
||||
@ -123,29 +123,29 @@ func (s *FileService) ListFiles(tenantID, pfID string, page, pageSize int, order
|
||||
}
|
||||
|
||||
// toFileResponse converts file model to response format
|
||||
func (s *FileService) toFileResponse(file *model.File) map[string]interface{} {
|
||||
func (s *FileService) toFileResponse(file *entity.File) map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
"id": file.ID,
|
||||
"parent_id": file.ParentID,
|
||||
"tenant_id": file.TenantID,
|
||||
"created_by": file.CreatedBy,
|
||||
"name": file.Name,
|
||||
"size": file.Size,
|
||||
"type": file.Type,
|
||||
"id": file.ID,
|
||||
"parent_id": file.ParentID,
|
||||
"tenant_id": file.TenantID,
|
||||
"created_by": file.CreatedBy,
|
||||
"name": file.Name,
|
||||
"size": file.Size,
|
||||
"type": file.Type,
|
||||
"create_time": file.CreateTime,
|
||||
"update_time": file.UpdateTime,
|
||||
}
|
||||
|
||||
|
||||
if file.Location != nil {
|
||||
result["location"] = *file.Location
|
||||
}
|
||||
result["source_type"] = file.SourceType
|
||||
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// toFileInfo converts file model to FileInfo
|
||||
func (s *FileService) toFileInfo(file *model.File) *FileInfo {
|
||||
func (s *FileService) toFileInfo(file *entity.File) *FileInfo {
|
||||
return &FileInfo{
|
||||
File: file,
|
||||
Size: file.Size,
|
||||
|
||||
@ -23,7 +23,8 @@ import (
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"ragflow/internal/utility"
|
||||
"strings"
|
||||
"time"
|
||||
@ -49,7 +50,7 @@ func NewKnowledgebaseService() *KnowledgebaseService {
|
||||
userDAO: dao.NewUserDAO(),
|
||||
tenantDAO: dao.NewTenantDAO(),
|
||||
connectorDAO: dao.NewConnectorDAO(),
|
||||
docEngine: engine.Get(),
|
||||
docEngine: engine.Get(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -122,8 +123,8 @@ func (s *KnowledgebaseService) CreateKB(req *CreateKBRequest, tenantID string) (
|
||||
}
|
||||
|
||||
// Check name length (using UTF-8 byte length like Python)
|
||||
if len(name) > model.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit)
|
||||
if len(name) > entity.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), entity.DatasetNameLimit)
|
||||
}
|
||||
|
||||
// Verify tenant exists
|
||||
@ -151,7 +152,7 @@ func (s *KnowledgebaseService) CreateKB(req *CreateKBRequest, tenantID string) (
|
||||
// Create knowledge base model
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
kb := &model.Knowledgebase{
|
||||
kb := &entity.Knowledgebase{
|
||||
ID: kbID,
|
||||
Name: duplicateName,
|
||||
TenantID: tenantID,
|
||||
@ -165,7 +166,7 @@ func (s *KnowledgebaseService) CreateKB(req *CreateKBRequest, tenantID string) (
|
||||
kb.UpdateTime = &now
|
||||
kb.CreateDate = &nowDate
|
||||
kb.UpdateDate = &nowDate
|
||||
status := string(model.StatusValid)
|
||||
status := string(entity.StatusValid)
|
||||
kb.Status = &status
|
||||
|
||||
// Set optional fields
|
||||
@ -270,8 +271,8 @@ func (s *KnowledgebaseService) UpdateKB(req *UpdateKBRequest, userID string) (ma
|
||||
}
|
||||
|
||||
// Check name length
|
||||
if len(name) > model.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), model.DatasetNameLimit)
|
||||
if len(name) > entity.DatasetNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), entity.DatasetNameLimit)
|
||||
}
|
||||
|
||||
// Check authorization
|
||||
@ -377,7 +378,7 @@ func (s *KnowledgebaseService) UpdateMetadataSetting(req *UpdateMetadataSettingR
|
||||
|
||||
// GetDetail retrieves detailed information about a knowledge base
|
||||
// This matches the Python kb_detail endpoint in kb_app.py
|
||||
func (s *KnowledgebaseService) GetDetail(kbID, userID string) (*model.KnowledgebaseDetail, common.ErrorCode, error) {
|
||||
func (s *KnowledgebaseService) GetDetail(kbID, userID string) (*entity.KnowledgebaseDetail, common.ErrorCode, error) {
|
||||
// Check authorization
|
||||
if !s.kbDAO.Accessible(kbID, userID) {
|
||||
return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
|
||||
@ -398,7 +399,7 @@ func (s *KnowledgebaseService) GetDetail(kbID, userID string) (*model.Knowledgeb
|
||||
// ListKbs lists knowledge bases with pagination and filtering
|
||||
// This matches the Python list endpoint in kb_app.py
|
||||
func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, common.ErrorCode, error) {
|
||||
var kbs []*model.KnowledgebaseListItem
|
||||
var kbs []*entity.KnowledgebaseListItem
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
@ -475,7 +476,7 @@ func (s *KnowledgebaseService) Accessible(kbID, userID string) bool {
|
||||
}
|
||||
|
||||
// GetByID retrieves a knowledge base by ID
|
||||
func (s *KnowledgebaseService) GetByID(kbID string) (*model.Knowledgebase, error) {
|
||||
func (s *KnowledgebaseService) GetByID(kbID string) (*entity.Knowledgebase, error) {
|
||||
return s.kbDAO.GetByID(kbID)
|
||||
}
|
||||
|
||||
@ -551,13 +552,13 @@ func GenerateUUID() string {
|
||||
}
|
||||
|
||||
// GetUserByToken gets user by authorization token
|
||||
func (s *KnowledgebaseService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) {
|
||||
func (s *KnowledgebaseService) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) {
|
||||
userService := NewUserService()
|
||||
return userService.GetUserByToken(authorization)
|
||||
}
|
||||
|
||||
// GetUserByID gets user by ID
|
||||
func (s *KnowledgebaseService) GetUserByID(id string) (*model.User, error) {
|
||||
func (s *KnowledgebaseService) GetUserByID(id string) (*entity.User, error) {
|
||||
return s.userDAO.GetByAccessToken(id)
|
||||
}
|
||||
|
||||
@ -572,7 +573,7 @@ func (s *KnowledgebaseService) GetConnectorsByTenantID(tenantID string) ([]*dao.
|
||||
}
|
||||
|
||||
// GetKBList retrieves knowledge bases with ID and name filtering
|
||||
func (s *KnowledgebaseService) GetKBList(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, common.ErrorCode, error) {
|
||||
func (s *KnowledgebaseService) GetKBList(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, id, name string) ([]*entity.Knowledgebase, int64, common.ErrorCode, error) {
|
||||
kbs, total, err := s.kbDAO.GetList(tenantIDs, userID, page, pageSize, orderby, desc, id, name)
|
||||
if err != nil {
|
||||
return nil, 0, common.CodeServerError, err
|
||||
@ -581,12 +582,12 @@ func (s *KnowledgebaseService) GetKBList(tenantIDs []string, userID string, page
|
||||
}
|
||||
|
||||
// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID
|
||||
func (s *KnowledgebaseService) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) {
|
||||
func (s *KnowledgebaseService) GetKBByIDAndUserID(kbID, userID string) ([]*entity.Knowledgebase, error) {
|
||||
return s.kbDAO.GetKBByIDAndUserID(kbID, userID)
|
||||
}
|
||||
|
||||
// GetKBByNameAndUserID retrieves a knowledge base by name and user ID
|
||||
func (s *KnowledgebaseService) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) {
|
||||
func (s *KnowledgebaseService) GetKBByNameAndUserID(kbName, userID string) ([]*entity.Knowledgebase, error) {
|
||||
return s.kbDAO.GetKBByNameAndUserID(kbName, userID)
|
||||
}
|
||||
|
||||
|
||||
@ -18,11 +18,11 @@ package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
var DB = dao.DB
|
||||
@ -381,13 +381,13 @@ func (s *LLMService) SetAPIKey(tenantID string, req *SetAPIKeyRequest) (*SetAPIK
|
||||
"api_base": baseURL,
|
||||
"max_tokens": maxTokens,
|
||||
}
|
||||
DB.Model(&model.TenantLLM{}).
|
||||
DB.Model(&entity.TenantLLM{}).
|
||||
Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, llm.LLMName).
|
||||
Updates(updates)
|
||||
} else {
|
||||
modelType := llm.ModelType
|
||||
llmName := llm.LLMName
|
||||
tenantLLM := &model.TenantLLM{
|
||||
tenantLLM := &entity.TenantLLM{
|
||||
TenantID: tenantID,
|
||||
LLMFactory: factory,
|
||||
ModelType: &modelType,
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"ragflow/internal/entity"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -28,7 +29,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -362,7 +362,7 @@ type UpdateMemoryRequest struct {
|
||||
// CreateMemoryResponse defines the response structure for memory operations
|
||||
// Uses struct embedding to extend Memory struct with API-specific fields
|
||||
type CreateMemoryResponse struct {
|
||||
model.Memory
|
||||
entity.Memory
|
||||
OwnerName *string `json:"owner_name,omitempty"`
|
||||
MemoryType []string `json:"memory_type"`
|
||||
}
|
||||
@ -454,7 +454,7 @@ func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest)
|
||||
newID = newID[:32]
|
||||
}
|
||||
|
||||
memory := &model.Memory{
|
||||
memory := &entity.Memory{
|
||||
ID: newID,
|
||||
Name: memoryName,
|
||||
TenantID: tenantID,
|
||||
@ -845,7 +845,7 @@ func isList(v interface{}) bool {
|
||||
// Example:
|
||||
//
|
||||
// resp := formatRetDataFromMemory(memoryModel)
|
||||
func formatRetDataFromMemory(memory *model.Memory) *CreateMemoryResponse {
|
||||
func formatRetDataFromMemory(memory *entity.Memory) *CreateMemoryResponse {
|
||||
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
||||
|
||||
resp := &CreateMemoryResponse{
|
||||
@ -881,7 +881,7 @@ func formatDateToString(t int64) *string {
|
||||
// Example:
|
||||
//
|
||||
// resp := formatRetDataFromMemoryListItem(memoryItem)
|
||||
func formatRetDataFromMemoryListItem(memory *model.MemoryListItem) *CreateMemoryResponse {
|
||||
func formatRetDataFromMemoryListItem(memory *entity.MemoryListItem) *CreateMemoryResponse {
|
||||
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
||||
resp := &CreateMemoryResponse{
|
||||
Memory: memory.Memory,
|
||||
|
||||
@ -19,22 +19,21 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// ModelBundle provides a unified interface for various model operations
|
||||
// Similar to Python's LLMBundle but with a more generic name
|
||||
type ModelBundle struct {
|
||||
tenantID string
|
||||
modelType model.ModelType
|
||||
modelType entity.ModelType
|
||||
modelName string
|
||||
model interface{} // underlying model instance
|
||||
}
|
||||
|
||||
// NewModelBundle creates a new ModelBundle for the given tenant and model type
|
||||
// If modelName is empty, uses the default model for the tenant and type
|
||||
func NewModelBundle(tenantID string, modelType model.ModelType, modelName ...string) (*ModelBundle, error) {
|
||||
func NewModelBundle(tenantID string, modelType entity.ModelType, modelName ...string) (*ModelBundle, error) {
|
||||
bundle := &ModelBundle{
|
||||
tenantID: tenantID,
|
||||
modelType: modelType,
|
||||
@ -48,19 +47,19 @@ func NewModelBundle(tenantID string, modelType model.ModelType, modelName ...str
|
||||
// Get model instance based on type
|
||||
provider := NewModelProvider()
|
||||
switch modelType {
|
||||
case model.ModelTypeEmbedding:
|
||||
case entity.ModelTypeEmbedding:
|
||||
embeddingModel, err := provider.GetEmbeddingModel(context.Background(), tenantID, bundle.modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
}
|
||||
bundle.model = embeddingModel
|
||||
case model.ModelTypeChat:
|
||||
case entity.ModelTypeChat:
|
||||
chatModel, err := provider.GetChatModel(context.Background(), tenantID, bundle.modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get chat model: %w", err)
|
||||
}
|
||||
bundle.model = chatModel
|
||||
case model.ModelTypeRerank:
|
||||
case entity.ModelTypeRerank:
|
||||
rerankModel, err := provider.GetRerankModel(context.Background(), tenantID, bundle.modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model: %w", err)
|
||||
@ -76,11 +75,11 @@ func NewModelBundle(tenantID string, modelType model.ModelType, modelName ...str
|
||||
// Encode encodes a list of texts into embeddings
|
||||
// Returns embeddings and token count (for compatibility with Python interface)
|
||||
func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
|
||||
if b.modelType != model.ModelTypeEmbedding {
|
||||
if b.modelType != entity.ModelTypeEmbedding {
|
||||
return nil, 0, fmt.Errorf("model type %s does not support encode", b.modelType)
|
||||
}
|
||||
|
||||
embeddingModel, ok := b.model.(model.EmbeddingModel)
|
||||
embeddingModel, ok := b.model.(entity.EmbeddingModel)
|
||||
if !ok {
|
||||
return nil, 0, fmt.Errorf("model is not an embedding model")
|
||||
}
|
||||
@ -103,11 +102,11 @@ func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) {
|
||||
// EncodeQuery encodes a single query string into embedding
|
||||
// Returns embedding and token count
|
||||
func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
|
||||
if b.modelType != model.ModelTypeEmbedding {
|
||||
if b.modelType != entity.ModelTypeEmbedding {
|
||||
return nil, 0, fmt.Errorf("model type %s does not support encode query", b.modelType)
|
||||
}
|
||||
|
||||
embeddingModel, ok := b.model.(model.EmbeddingModel)
|
||||
embeddingModel, ok := b.model.(entity.EmbeddingModel)
|
||||
if !ok {
|
||||
return nil, 0, fmt.Errorf("model is not an embedding model")
|
||||
}
|
||||
@ -125,11 +124,11 @@ func (b *ModelBundle) EncodeQuery(query string) ([]float64, int64, error) {
|
||||
|
||||
// Chat sends a chat message and returns response
|
||||
func (b *ModelBundle) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, int64, error) {
|
||||
if b.modelType != model.ModelTypeChat {
|
||||
if b.modelType != entity.ModelTypeChat {
|
||||
return "", 0, fmt.Errorf("model type %s does not support chat", b.modelType)
|
||||
}
|
||||
|
||||
chatModel, ok := b.model.(model.ChatModel)
|
||||
chatModel, ok := b.model.(entity.ChatModel)
|
||||
if !ok {
|
||||
return "", 0, fmt.Errorf("model is not a chat model")
|
||||
}
|
||||
@ -147,11 +146,11 @@ func (b *ModelBundle) Chat(system string, history []map[string]string, genConf m
|
||||
|
||||
// Similarity calculates similarity between query and texts
|
||||
func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64, error) {
|
||||
if b.modelType != model.ModelTypeRerank {
|
||||
if b.modelType != entity.ModelTypeRerank {
|
||||
return nil, 0, fmt.Errorf("model type %s does not support similarity", b.modelType)
|
||||
}
|
||||
|
||||
rerankModel, ok := b.model.(model.RerankModel)
|
||||
rerankModel, ok := b.model.(entity.RerankModel)
|
||||
if !ok {
|
||||
return nil, 0, fmt.Errorf("model is not a rerank model")
|
||||
}
|
||||
|
||||
@ -21,21 +21,21 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/service/models"
|
||||
)
|
||||
|
||||
// ModelProvider provides model instances based on tenant and model type
|
||||
type ModelProvider interface {
|
||||
// GetEmbeddingModel returns an embedding model for the given tenant
|
||||
GetEmbeddingModel(ctx context.Context, tenantID string, modelName string) (model.EmbeddingModel, error)
|
||||
GetEmbeddingModel(ctx context.Context, tenantID string, modelName string) (entity.EmbeddingModel, error)
|
||||
// GetChatModel returns a chat model for the given tenant
|
||||
GetChatModel(ctx context.Context, tenantID string, modelName string) (model.ChatModel, error)
|
||||
GetChatModel(ctx context.Context, tenantID string, modelName string) (entity.ChatModel, error)
|
||||
// GetRerankModel returns a rerank model for the given tenant
|
||||
GetRerankModel(ctx context.Context, tenantID string, modelName string) (model.RerankModel, error)
|
||||
GetRerankModel(ctx context.Context, tenantID string, modelName string) (entity.RerankModel, error)
|
||||
}
|
||||
|
||||
// ModelProviderImpl implements ModelProvider
|
||||
@ -66,7 +66,7 @@ func parseModelName(compositeName string) (modelName, provider string, err error
|
||||
}
|
||||
|
||||
// GetEmbeddingModel returns an embedding model for the given tenant
|
||||
func (p *ModelProviderImpl) GetEmbeddingModel(ctx context.Context, tenantID string, compositeModelName string) (model.EmbeddingModel, error) {
|
||||
func (p *ModelProviderImpl) GetEmbeddingModel(ctx context.Context, tenantID string, compositeModelName string) (entity.EmbeddingModel, error) {
|
||||
// Parse composite model name to extract model name and provider
|
||||
modelName, provider, err := parseModelName(compositeModelName)
|
||||
if err != nil {
|
||||
@ -95,7 +95,7 @@ func (p *ModelProviderImpl) GetEmbeddingModel(ctx context.Context, tenantID stri
|
||||
}
|
||||
|
||||
// GetChatModel returns a chat model for the given tenant
|
||||
func (p *ModelProviderImpl) GetChatModel(ctx context.Context, tenantID string, compositeModelName string) (model.ChatModel, error) {
|
||||
func (p *ModelProviderImpl) GetChatModel(ctx context.Context, tenantID string, compositeModelName string) (entity.ChatModel, error) {
|
||||
// Parse composite model name to extract model name and provider
|
||||
_, _, err := parseModelName(compositeModelName)
|
||||
if err != nil {
|
||||
@ -106,7 +106,7 @@ func (p *ModelProviderImpl) GetChatModel(ctx context.Context, tenantID string, c
|
||||
}
|
||||
|
||||
// GetRerankModel returns a rerank model for the given tenant
|
||||
func (p *ModelProviderImpl) GetRerankModel(ctx context.Context, tenantID string, compositeModelName string) (model.RerankModel, error) {
|
||||
func (p *ModelProviderImpl) GetRerankModel(ctx context.Context, tenantID string, compositeModelName string) (entity.RerankModel, error) {
|
||||
// Parse composite model name to extract model name and provider
|
||||
_, _, err := parseModelName(compositeModelName)
|
||||
if err != nil {
|
||||
|
||||
@ -18,11 +18,11 @@ package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("DeepSeek", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel {
|
||||
RegisterEmbeddingModelFactory("DeepSeek", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
|
||||
@ -19,12 +19,13 @@ package models
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"sync"
|
||||
)
|
||||
|
||||
// EmbeddingModelFactory creates an EmbeddingModel instance
|
||||
type EmbeddingModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel
|
||||
type EmbeddingModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel
|
||||
|
||||
var (
|
||||
embeddingModelFactories = make(map[string]EmbeddingModelFactory)
|
||||
@ -49,7 +50,7 @@ func GetEmbeddingModelFactory(providerName string) EmbeddingModelFactory {
|
||||
|
||||
// CreateEmbeddingModel creates an EmbeddingModel instance for the given provider.
|
||||
// Returns error if provider not registered.
|
||||
func CreateEmbeddingModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (model.EmbeddingModel, error) {
|
||||
func CreateEmbeddingModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (entity.EmbeddingModel, error) {
|
||||
factory := GetEmbeddingModelFactory(providerName)
|
||||
if factory == nil {
|
||||
return nil, fmt.Errorf("no embedding model factory registered for provider %s", providerName)
|
||||
|
||||
@ -21,7 +21,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -115,7 +116,7 @@ func (m *giteeEmbeddingModel) EncodeQuery(query string) ([]float64, error) {
|
||||
|
||||
// init registers the GiteeAI embedding model factory
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("GiteeAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel {
|
||||
RegisterEmbeddingModelFactory("GiteeAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &giteeEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
|
||||
@ -18,11 +18,11 @@ package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("Moonshot", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel {
|
||||
RegisterEmbeddingModelFactory("Moonshot", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
|
||||
@ -18,11 +18,11 @@ package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("OpenAI-API-Compatible", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel {
|
||||
RegisterEmbeddingModelFactory("OpenAI-API-Compatible", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
|
||||
@ -21,7 +21,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -112,7 +113,7 @@ func (m *openAIEmbeddingModel) EncodeQuery(query string) ([]float64, error) {
|
||||
|
||||
// init registers the OpenAI embedding model factory
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("OpenAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel {
|
||||
RegisterEmbeddingModelFactory("OpenAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
|
||||
@ -21,7 +21,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -112,7 +113,7 @@ func (m *siliconflowEmbeddingModel) EncodeQuery(query string) ([]float64, error)
|
||||
|
||||
// init registers the SILICONFLOW embedding model factory
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel {
|
||||
RegisterEmbeddingModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &siliconflowEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
|
||||
@ -18,11 +18,11 @@ package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterEmbeddingModelFactory("ZHIPU-AI", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel {
|
||||
RegisterEmbeddingModelFactory("ZHIPU-AI", func(apiKey, apiBase, modelName string, httpClient *http.Client) entity.EmbeddingModel {
|
||||
return &openAIEmbeddingModel{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
|
||||
@ -18,7 +18,7 @@ package service
|
||||
|
||||
import (
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// SearchService search service
|
||||
@ -37,7 +37,7 @@ func NewSearchService() *SearchService {
|
||||
|
||||
// SearchWithTenantInfo search with tenant info
|
||||
type SearchWithTenantInfo struct {
|
||||
*model.Search
|
||||
*entity.Search
|
||||
Nickname string `json:"nickname"`
|
||||
TenantAvatar string `json:"tenant_avatar,omitempty"`
|
||||
}
|
||||
@ -55,7 +55,7 @@ type ListSearchAppsResponse struct {
|
||||
|
||||
// ListSearchApps list search apps with advanced filtering (equivalent to list_search_app)
|
||||
func (s *SearchService) ListSearchApps(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListSearchAppsResponse, error) {
|
||||
var searches []*model.Search
|
||||
var searches []*entity.Search
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
@ -88,7 +88,7 @@ func (s *SearchService) ListSearchApps(userID string, keywords string, page, pag
|
||||
}
|
||||
searches = searches[start:end]
|
||||
} else {
|
||||
searches = []*model.Search{}
|
||||
searches = []*entity.Search{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -106,7 +106,7 @@ func (s *SearchService) ListSearchApps(userID string, keywords string, page, pag
|
||||
}
|
||||
|
||||
// toSearchAppResponse converts search model to response format
|
||||
func (s *SearchService) toSearchAppResponse(search *model.Search) map[string]interface{} {
|
||||
func (s *SearchService) toSearchAppResponse(search *entity.Search) map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
"id": search.ID,
|
||||
"tenant_id": search.TenantID,
|
||||
|
||||
@ -17,14 +17,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"context"
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/engine"
|
||||
)
|
||||
|
||||
@ -133,10 +133,10 @@ func NewTenantLLMService() *TenantLLMService {
|
||||
* // Get API key for model without factory
|
||||
* tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4")
|
||||
*/
|
||||
func (s *TenantLLMService) GetAPIKey(tenantID, modelName string) (*model.TenantLLM, error) {
|
||||
func (s *TenantLLMService) GetAPIKey(tenantID, modelName string) (*entity.TenantLLM, error) {
|
||||
modelName, factory := s.SplitModelNameAndFactory(modelName)
|
||||
|
||||
var tenantLLM *model.TenantLLM
|
||||
var tenantLLM *entity.TenantLLM
|
||||
var err error
|
||||
|
||||
if factory == "" {
|
||||
|
||||
@ -30,6 +30,7 @@ import (
|
||||
"hash"
|
||||
"os"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/server"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@ -40,7 +41,7 @@ import (
|
||||
"golang.org/x/crypto/scrypt"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
@ -101,7 +102,7 @@ type UserResponse struct {
|
||||
}
|
||||
|
||||
// Register user registration
|
||||
func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorCode, error) {
|
||||
func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) {
|
||||
cfg := server.GetConfig()
|
||||
if cfg.RegisterEnabled == 0 {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
|
||||
@ -134,7 +135,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
loginChannel := "password"
|
||||
isSuperuser := false
|
||||
|
||||
user := &model.User{
|
||||
user := &entity.User{
|
||||
ID: userID,
|
||||
AccessToken: &accessToken,
|
||||
Email: req.Email,
|
||||
@ -179,7 +180,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
rerankID = ""
|
||||
}
|
||||
|
||||
tenant := &model.Tenant{
|
||||
tenant := &entity.Tenant{
|
||||
ID: userID,
|
||||
Name: &tenantName,
|
||||
LLMID: llmID,
|
||||
@ -196,7 +197,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
tenant.UpdateDate = &now_date
|
||||
|
||||
userTenantID := utility.GenerateToken()
|
||||
userTenant := &model.UserTenant{
|
||||
userTenant := &entity.UserTenant{
|
||||
ID: userTenantID,
|
||||
UserID: userID,
|
||||
TenantID: userID,
|
||||
@ -210,7 +211,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
userTenant.UpdateDate = &now_date
|
||||
|
||||
fileID := utility.GenerateToken()
|
||||
rootFile := &model.File{
|
||||
rootFile := &entity.File{
|
||||
ID: fileID,
|
||||
ParentID: fileID,
|
||||
TenantID: userID,
|
||||
@ -272,7 +273,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
}
|
||||
|
||||
// Login user login
|
||||
func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, error) {
|
||||
func (s *UserService) Login(req *LoginRequest) (*entity.User, common.ErrorCode, error) {
|
||||
// Get user by email (using username field as email)
|
||||
user, err := s.userDAO.GetByEmail(req.Username)
|
||||
if err != nil {
|
||||
@ -315,7 +316,7 @@ func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, e
|
||||
// - CodeAuthenticationError (109): Email not registered or password mismatch
|
||||
// - CodeServerError (500): Password decryption failure
|
||||
// - CodeForbidden (403): Account disabled
|
||||
func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common.ErrorCode, error) {
|
||||
func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*entity.User, common.ErrorCode, error) {
|
||||
user, err := s.userDAO.GetByEmail(req.Email)
|
||||
if err != nil {
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("Email: %s is not registered!", req.Email)
|
||||
@ -639,7 +640,7 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error)
|
||||
// GetUserByToken gets user by authorization header
|
||||
// The token parameter is the authorization header value, which needs to be decrypted
|
||||
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
|
||||
func (s *UserService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) {
|
||||
func (s *UserService) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) {
|
||||
// Get secret key from config
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
@ -666,12 +667,12 @@ func (s *UserService) GetUserByToken(authorization string) (*model.User, common.
|
||||
}
|
||||
|
||||
// UpdateUserAccessToken updates user's access token
|
||||
func (s *UserService) UpdateUserAccessToken(user *model.User, token string) error {
|
||||
func (s *UserService) UpdateUserAccessToken(user *entity.User, token string) error {
|
||||
return s.userDAO.UpdateAccessToken(user, token)
|
||||
}
|
||||
|
||||
// Logout invalidates user's access token
|
||||
func (s *UserService) Logout(user *model.User) (common.ErrorCode, error) {
|
||||
func (s *UserService) Logout(user *entity.User) (common.ErrorCode, error) {
|
||||
// Invalidate token by setting it to an invalid value
|
||||
// Similar to Python implementation: "INVALID_" + secrets.token_hex(16)
|
||||
invalidToken := "INVALID_" + utility.GenerateToken()
|
||||
@ -683,7 +684,7 @@ func (s *UserService) Logout(user *model.User) (common.ErrorCode, error) {
|
||||
}
|
||||
|
||||
// GetUserProfile returns user profile information
|
||||
func (s *UserService) GetUserProfile(user *model.User) map[string]interface{} {
|
||||
func (s *UserService) GetUserProfile(user *entity.User) map[string]interface{} {
|
||||
// Format create time and date (from database fields)
|
||||
createTime := user.CreateTime
|
||||
createDate := ""
|
||||
@ -788,7 +789,7 @@ func (s *UserService) GetUserProfile(user *model.User) map[string]interface{} {
|
||||
}
|
||||
|
||||
// UpdateUserSettings updates user settings
|
||||
func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) (common.ErrorCode, error) {
|
||||
func (s *UserService) UpdateUserSettings(user *entity.User, req *UpdateSettingsRequest) (common.ErrorCode, error) {
|
||||
// Update fields if provided
|
||||
if req.Nickname != nil {
|
||||
user.Nickname = *req.Nickname
|
||||
@ -818,7 +819,7 @@ func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRe
|
||||
}
|
||||
|
||||
// ChangePassword changes user password
|
||||
func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) (common.ErrorCode, error) {
|
||||
func (s *UserService) ChangePassword(user *entity.User, req *ChangePasswordRequest) (common.ErrorCode, error) {
|
||||
// If password is provided, verify current password
|
||||
if req.Password != nil {
|
||||
if user.Password == nil || !s.VerifyPassword(*user.Password, *req.Password) {
|
||||
@ -1004,7 +1005,7 @@ func (s *UserTenantService) GetUserTenantRelationByUserID(userID string) ([]*Use
|
||||
* Returns:
|
||||
* - *UserTenantRelation: the converted UserTenantRelation
|
||||
*/
|
||||
func convertToUserTenantRelation(userTenant *model.UserTenant) *UserTenantRelation {
|
||||
func convertToUserTenantRelation(userTenant *entity.UserTenant) *UserTenantRelation {
|
||||
return &UserTenantRelation{
|
||||
ID: userTenant.ID,
|
||||
UserID: userTenant.UserID,
|
||||
@ -1016,7 +1017,7 @@ func convertToUserTenantRelation(userTenant *model.UserTenant) *UserTenantRelati
|
||||
// GetUserByAPIToken gets user by access key from Authorization header
|
||||
// This is used for API token authentication
|
||||
// The authorization parameter should be in format: "Bearer <token>" or just "<token>"
|
||||
func (s *UserService) GetUserByAPIToken(authorization string) (*model.User, common.ErrorCode, error) {
|
||||
func (s *UserService) GetUserByAPIToken(authorization string) (*entity.User, common.ErrorCode, error) {
|
||||
if authorization == "" {
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("authorization header is empty")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user