diff --git a/api/apps/restful_apis/system_api.py b/api/apps/restful_apis/system_api.py index bae1f0eee..55c34c25a 100644 --- a/api/apps/restful_apis/system_api.py +++ b/api/apps/restful_apis/system_api.py @@ -39,7 +39,6 @@ async def ping(): return "pong", 200 @manager.route("/system/version", methods=["GET"]) # noqa: F821 -@login_required def version(): """ Get the current version of the application. diff --git a/common/settings.py b/common/settings.py index 43135fa00..49693b937 100644 --- a/common/settings.py +++ b/common/settings.py @@ -174,7 +174,8 @@ def _get_or_create_secret_key(): generated_key = secrets.token_hex(32) secret_key = REDIS_CONN.get_or_create_secret_key("ragflow:system:secret_key", generated_key) - logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.") + if generated_key == secret_key: + logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.") return secret_key class StorageFactory: diff --git a/internal/admin/handler.go b/internal/admin/handler.go index e083c825b..ee823d5df 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "net/http" + "ragflow/internal/cache" "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/server" @@ -153,8 +154,15 @@ func (h *Handler) Login(c *gin.Context) { return } - variables := server.GetVariables() - secretKey := variables.SecretKey + secretKey, err := server.GetSecretKey(cache.Get()) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": fmt.Sprintf("Failed to get secret key: %s", err.Error()), + }) + return + } + authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/internal/dao/chat_session.go b/internal/dao/chat_session.go index 758a9c596..86aee8766 100644 --- a/internal/dao/chat_session.go +++ b/internal/dao/chat_session.go @@ -53,20 +53,20 @@ func (dao *ChatSessionDAO) DeleteByID(id string) error { return DB.Where("id = ?", id).Delete(&entity.ChatSession{}).Error } -// ListByDialogID lists chat sessions by dialog ID -func (dao *ChatSessionDAO) ListByDialogID(dialogID string) ([]*entity.ChatSession, error) { +// ListByChatID lists chat sessions by chat ID +func (dao *ChatSessionDAO) ListByChatID(chatID string) ([]*entity.ChatSession, error) { var convs []*entity.ChatSession - err := DB.Where("dialog_id = ?", dialogID). + err := DB.Where("dialog_id = ?", chatID). Order("create_time DESC"). Find(&convs).Error return convs, err } // CheckDialogExists checks if a dialog exists with given tenant_id and dialog_id -func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, error) { +func (dao *ChatSessionDAO) CheckDialogExists(tenantID, chatID string) (bool, error) { var count int64 err := DB.Model(&entity.Chat{}). - Where("tenant_id = ? AND id = ? AND status = ?", tenantID, dialogID, "1"). + Where("tenant_id = ? AND id = ? AND status = ?", tenantID, chatID, "1"). Count(&count).Error if err != nil { return false, err @@ -75,9 +75,9 @@ func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, e } // GetDialogByID gets dialog by ID -func (dao *ChatSessionDAO) GetDialogByID(dialogID string) (*entity.Chat, error) { +func (dao *ChatSessionDAO) GetDialogByID(chatID string) (*entity.Chat, error) { var dialog entity.Chat - err := DB.Where("id = ? AND status = ?", dialogID, "1").First(&dialog).Error + err := DB.Where("id = ? AND status = ?", chatID, "1").First(&dialog).Error if err != nil { return nil, err } diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index 897e62f18..c3489d70e 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -148,9 +148,9 @@ func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) { // @Tags chat_session // @Accept json // @Produce json -// @Param dialog_id query string true "dialog ID" +// @Param chat_id query string true "chat ID" // @Success 200 {object} service.ListChatSessionsResponse -// @Router /v1/conversation/list [get] +// @Router /api/v1/chats//sessions [get] func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) if errorCode != common.CodeSuccess { @@ -159,18 +159,18 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) { } userID := user.ID - // Get dialog_id from query parameter - dialogID := c.Query("dialog_id") - if dialogID == "" { + // Get chat_id from query parameter + chatID := c.Param("chat_id") + if chatID == "" { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, - "message": "dialog_id is required", + "message": "chat_id is required", }) return } // Call service to list chat sessions - result, err := h.chatSessionService.ListChatSessions(userID, dialogID) + result, err := h.chatSessionService.ListChatSessions(userID, chatID) if err != nil { // Check if it's an authorization error if err.Error() == "Only owner of dialog authorized for this operation" { diff --git a/internal/handler/user.go b/internal/handler/user.go index 645683cc2..aecb359f8 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -19,6 +19,7 @@ package handler import ( "fmt" "net/http" + "ragflow/internal/cache" "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/server/local" @@ -72,8 +73,15 @@ func (h *UserHandler) Register(c *gin.Context) { return } - variables := server.GetVariables() - secretKey := variables.SecretKey + secretKey, err := server.GetSecretKey(cache.Get()) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": fmt.Sprintf("Failed to get secret key: %s", err.Error()), + "data": false, + }) + return + } authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -129,8 +137,15 @@ func (h *UserHandler) Login(c *gin.Context) { } // Sign the access_token using itsdangerous (compatible with Python) - variables := server.GetVariables() - secretKey := variables.SecretKey + secretKey, err := server.GetSecretKey(cache.Get()) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": fmt.Sprintf("Failed to get secret key: %s", err.Error()), + "data": false, + }) + return + } authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -197,8 +212,15 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) { return } - variables := server.GetVariables() - secretKey := variables.SecretKey + secretKey, err := server.GetSecretKey(cache.Get()) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": fmt.Sprintf("Failed to get secret key: %s", err.Error()), + "data": false, + }) + return + } authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/internal/router/router.go b/internal/router/router.go index 46369ac09..231634536 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -90,15 +90,15 @@ func (r *Router) Setup(engine *gin.Engine) { // System endpoints engine.GET("/v1/system/ping", r.systemHandler.Ping) - engine.GET("/v1/system/config", r.systemHandler.GetConfig) + engine.GET("/api/v1/system/config", r.systemHandler.GetConfig) engine.GET("/v1/system/configs", r.systemHandler.GetConfigs) - engine.GET("/v1/system/version", r.systemHandler.GetVersion) + engine.GET("/api/v1/system/version", r.systemHandler.GetVersion) engine.POST("/v1/user/register", r.userHandler.Register) // User login channels endpoint - engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels) + engine.GET("/api/v1/auth/login/channels", r.userHandler.GetLoginChannels) // User login by email endpoint - engine.POST("/v1/user/login", r.userHandler.LoginByEmail) + engine.POST("/api/v1/auth/login", r.userHandler.LoginByEmail) // User logout endpoint engine.GET("/v1/user/logout", r.userHandler.Logout) @@ -123,14 +123,25 @@ func (r *Router) Setup(engine *gin.Engine) { // API v1 route group v1 := authorized.Group("/api/v1") { - // User routes - //users := v1.Group("/users") - //{ - // users.POST("/register", r.userHandler.Register) - // users.POST("/login", r.userHandler.Login) - // users.GET("", r.userHandler.ListUsers) - // users.GET("/:id", r.userHandler.GetUserByID) - //} + // Auth routes + auth := v1.Group("/auth") + { + // User logout endpoint + auth.GET("/logout", r.userHandler.Logout) + } + + // Users routes + users := v1.Group("/users") + { + users.GET("/me", r.userHandler.Info) + // User settings endpoint + users.PATCH("/me", r.userHandler.Setting) + } + + tenants := v1.Group("/tenants") + { + tenants.GET("", r.tenantHandler.TenantList) + } // Document routes documents := v1.Group("/documents") @@ -142,7 +153,15 @@ func (r *Router) Setup(engine *gin.Engine) { documents.DELETE("/:id", r.documentHandler.DeleteDocument) } - // RESTful dataset routes + // Chat routes + chats := v1.Group("/chats") + { + chats.GET("", r.chatHandler.ListChats) + chats.GET("/:chat_id", r.chatHandler.GetChat) + chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions) + } + + // Dataset routes datasets := v1.Group("/datasets") { datasets.GET("", r.datasetsHandler.ListDatasets) @@ -150,6 +169,26 @@ func (r *Router) Setup(engine *gin.Engine) { datasets.DELETE("", r.datasetsHandler.DeleteDatasets) } + // Search routes + searches := v1.Group("/searches") + { + searches.GET("", r.searchHandler.ListSearches) + searches.POST("", r.searchHandler.CreateSearch) + searches.GET("/:search_id", r.searchHandler.GetSearch) + searches.PUT("/:search_id", r.searchHandler.UpdateSearch) + searches.DELETE("/:search_id", r.searchHandler.DeleteSearch) + } + + file := v1.Group("/files") + { + file.POST("", r.fileHandler.UploadFile) + file.GET("", r.fileHandler.ListFiles) + file.DELETE("", r.fileHandler.DeleteFiles) + file.POST("/move", r.fileHandler.MoveFiles) + file.GET("/:id/ancestors", r.fileHandler.GetFileAncestors) + file.GET("/:id", r.fileHandler.Download) + } + // Author routes authors := v1.Group("/authors") { @@ -167,62 +206,37 @@ func (r *Router) Setup(engine *gin.Engine) { memory.GET("/:memory_id", r.memoryHandler.GetMemoryMessages) } - // TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine - // message := v1.Group("/messages") - // { - // message.POST("", r.memoryHandler.AddMessage) - // message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage) - // message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage) - // message.GET("/search", r.memoryHandler.SearchMessage) - // message.GET("", r.memoryHandler.GetMessages) - // message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent) - // } + // TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine + // message := v1.Group("/messages") + // { + // message.POST("", r.memoryHandler.AddMessage) + // message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage) + // message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage) + // message.GET("/search", r.memoryHandler.SearchMessage) + // message.GET("", r.memoryHandler.GetMessages) + // message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent) + // } - // Skill search routes - skills := v1.Group("/skills") - { - // Skill Space management - skills.GET("/spaces", r.skillSearchHandler.ListSpaces) - skills.POST("/spaces", r.skillSearchHandler.CreateSpace) - skills.GET("/spaces/:space_id", r.skillSearchHandler.GetSpace) - skills.PUT("/spaces/:space_id", r.skillSearchHandler.UpdateSpace) - skills.DELETE("/spaces/:space_id", r.skillSearchHandler.DeleteSpace) - skills.GET("/space/by-folder", r.skillSearchHandler.GetSpaceByFolder) - - // Skill search config - skills.GET("/config", r.skillSearchHandler.GetConfig) - skills.POST("/config", r.skillSearchHandler.UpdateConfig) - - // Skill search and indexing - skills.POST("/search", r.skillSearchHandler.Search) - skills.POST("/index", r.skillSearchHandler.IndexSkills) - skills.DELETE("/index", r.skillSearchHandler.DeleteSkillIndex) - skills.POST("/reindex", r.skillSearchHandler.Reindex) - } - - chats := v1.Group("/chats") + // Skill search routes + skills := v1.Group("/skills") { - chats.GET("", r.chatHandler.ListChats) - chats.GET("/:chat_id", r.chatHandler.GetChat) - } + // Skill Space management + skills.GET("/spaces", r.skillSearchHandler.ListSpaces) + skills.POST("/spaces", r.skillSearchHandler.CreateSpace) + skills.GET("/spaces/:space_id", r.skillSearchHandler.GetSpace) + skills.PUT("/spaces/:space_id", r.skillSearchHandler.UpdateSpace) + skills.DELETE("/spaces/:space_id", r.skillSearchHandler.DeleteSpace) + skills.GET("/space/by-folder", r.skillSearchHandler.GetSpaceByFolder) - searches := v1.Group("/searches") - { - searches.GET("", r.searchHandler.ListSearches) - searches.POST("", r.searchHandler.CreateSearch) - searches.GET("/:search_id", r.searchHandler.GetSearch) - searches.PUT("/:search_id", r.searchHandler.UpdateSearch) - searches.DELETE("/:search_id", r.searchHandler.DeleteSearch) - } + // Skill search config + skills.GET("/config", r.skillSearchHandler.GetConfig) + skills.POST("/config", r.skillSearchHandler.UpdateConfig) - file := v1.Group("/files") - { - file.POST("", r.fileHandler.UploadFile) - file.GET("", r.fileHandler.ListFiles) - file.DELETE("", r.fileHandler.DeleteFiles) - file.POST("/move", r.fileHandler.MoveFiles) - file.GET("/:id/ancestors", r.fileHandler.GetFileAncestors) - file.GET("/:id", r.fileHandler.Download) + // Skill search and indexing + skills.POST("/search", r.skillSearchHandler.Search) + skills.POST("/index", r.skillSearchHandler.IndexSkills) + skills.DELETE("/index", r.skillSearchHandler.DeleteSkillIndex) + skills.POST("/reindex", r.skillSearchHandler.Reindex) } // provider pool route group @@ -256,7 +270,6 @@ func (r *Router) Setup(engine *gin.Engine) { system := v1.Group("/system") { - system.GET("/version", r.systemHandler.GetVersion) system.GET("/configs", r.systemHandler.GetConfigs) log := system.Group("/log") { diff --git a/internal/server/config.go b/internal/server/config.go index d0a6ef03d..25f1b4187 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -36,6 +36,7 @@ const DefaultConnectTimeout = 5 * time.Second // Config application configuration type Config struct { Server ServerConfig `mapstructure:"server"` + Authentication AuthenticationConfig `mapstructure:"authentication"` Database DatabaseConfig `mapstructure:"database"` Redis RedisConfig `mapstructure:"redis"` Log LogConfig `mapstructure:"log"` @@ -55,6 +56,11 @@ type AdminConfig struct { Port int `mapstructure:"http_port"` } +type AuthenticationConfig struct { + DisablePasswordLogin bool `mapstructure:"disable_password_login"` + RegisterEnabled bool `mapstructure:"register_enabled"` +} + type DefaultSuperUser struct { Email string `mapstructure:"email"` Password string `mapstructure:"password"` @@ -91,8 +97,9 @@ type OAuthConfig struct { // ServerConfig server configuration type ServerConfig struct { - Mode string `mapstructure:"mode"` // debug, release - Port int `mapstructure:"port"` + Mode string `mapstructure:"mode"` // debug, release + Port int `mapstructure:"port"` + SecretKey *string `mapstructure:"secret_key"` } // DatabaseConfig database configuration @@ -372,6 +379,31 @@ func Init(configPath string) error { } func FromEnvironments() error { + // Secret key + if envVal := os.Getenv("RAGFLOW_SECRET_KEY"); envVal != "" { + globalConfig.Server.SecretKey = &envVal + } + + // Load REGISTER_ENABLED from environment variable (default: true) + if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" { + str := strings.ToLower(envVal) + if str == "true" || str == "1" || str == "yes" { + globalConfig.Authentication.RegisterEnabled = true + } else { + globalConfig.Authentication.RegisterEnabled = false + } + } + + // Load DISABLE_PASSWORD_LOGIN from environment variable (default: false) + if envVal := os.Getenv("DISABLE_PASSWORD_LOGIN"); envVal != "" { + str := strings.ToLower(envVal) + if str == "true" || str == "1" || str == "yes" { + globalConfig.Authentication.DisablePasswordLogin = true + } else { + globalConfig.Authentication.DisablePasswordLogin = false + } + } + // Doc engine docEngine := strings.ToLower(os.Getenv("DOC_ENGINE")) switch docEngine { @@ -535,14 +567,23 @@ func FromConfigFile(configPath string) error { globalConfig.Admin.Port += 2 } - // Load REGISTER_ENABLED from environment variable (default: 1) - registerEnabled := 1 - if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" { - if parsed, err := strconv.Atoi(envVal); err == nil { - registerEnabled = parsed + // authentication section + if globalConfig != nil { + // Try to map from mysql section + globalConfig.Authentication.DisablePasswordLogin = false + globalConfig.Authentication.RegisterEnabled = true + if v.IsSet("authentication") { + authenticationConfig := v.Sub("authentication") + if authenticationConfig != nil { + if authenticationConfig.IsSet("disable_password_login") { + globalConfig.Authentication.DisablePasswordLogin = authenticationConfig.GetBool("disable_password_login") + } + if authenticationConfig.IsSet("enable_register") { + globalConfig.Authentication.RegisterEnabled = authenticationConfig.GetBool("enable_register") + } + } } } - globalConfig.RegisterEnabled = registerEnabled // If we loaded service_conf.yaml, map mysql fields to DatabaseConfig if globalConfig != nil && globalConfig.Database.Host == "" { @@ -573,6 +614,10 @@ func FromConfigFile(configPath string) error { if globalConfig.Server.Mode == "" { globalConfig.Server.Mode = "release" } + secretKey := ragflowConfig.GetString("secret_key") + if secretKey != "" { + globalConfig.Server.SecretKey = &secretKey + } } } } diff --git a/internal/server/variable.go b/internal/server/variable.go index 14a6399e1..1a6ee7da2 100644 --- a/internal/server/variable.go +++ b/internal/server/variable.go @@ -30,7 +30,7 @@ import ( // Variables holds all runtime variables that can be changed during system operation // Unlike Config, these can be modified at runtime type Variables struct { - SecretKey string `json:"secret_key"` + //SecretKey string `json:"secret_key"` } // VariableStore interface for persistent storage (e.g., Redis) @@ -62,19 +62,20 @@ func InitVariables(store VariableStore) error { variablesOnce.Do(func() { globalVariables = &Variables{} - generatedKey, err := utility.GenerateSecretKey() - if err != nil { - initErr = fmt.Errorf("failed to generate secret key: %w", err) - } - - // Initialize SecretKey - secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey) - if err != nil { - initErr = fmt.Errorf("failed to initialize secret key: %w", err) - } else { - globalVariables.SecretKey = secretKey - common.Info("Secret key initialized from store") - } + //// secret key + //generatedKey, err := utility.GenerateSecretKey() + //if err != nil { + // initErr = fmt.Errorf("failed to generate secret key: %w", err) + //} + // + //// Initialize SecretKey + //secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey) + //if err != nil { + // initErr = fmt.Errorf("failed to initialize secret key: %w", err) + //} else { + // globalVariables.SecretKey = secretKey + // common.Info("Secret key initialized from store") + //} common.Info("Server variables initialized successfully") }) @@ -82,31 +83,39 @@ func InitVariables(store VariableStore) error { } // GetVariables returns the global variables instance -func GetVariables() *Variables { - variablesMu.RLock() - defer variablesMu.RUnlock() - return globalVariables -} +//func GetVariables() *Variables { +// variablesMu.RLock() +// defer variablesMu.RUnlock() +// return globalVariables +//} // GetSecretKey returns the current secret key -func GetSecretKey() string { - variablesMu.RLock() - defer variablesMu.RUnlock() - if globalVariables == nil { - return DefaultSecretKey +func GetSecretKey(store VariableStore) (string, error) { + if globalConfig.Server.SecretKey != nil { + return *globalConfig.Server.SecretKey, nil } - return globalVariables.SecretKey + + generatedKey, err := utility.GenerateSecretKey() + if err != nil { + return "", fmt.Errorf("failed to generate secret key: %w", err) + } + + secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey) + if err != nil { + return "", fmt.Errorf("failed to get secret key: %w", err) + } + return secretKey, nil } // SetSecretKey updates the secret key at runtime -func SetSecretKey(key string) { - variablesMu.Lock() - defer variablesMu.Unlock() - if globalVariables != nil { - globalVariables.SecretKey = key - common.Info("Secret key updated at runtime") - } -} +//func SetSecretKey(key string) { +// variablesMu.Lock() +// defer variablesMu.Unlock() +// if globalVariables != nil { +// globalVariables.SecretKey = key +// common.Info("Secret key updated at runtime") +// } +//} // GetOrCreateKey gets a key from store, or creates it if not exists // - If key exists in store, returns the stored value @@ -178,7 +187,7 @@ func RefreshVariables(store VariableStore) error { return err } if secretKey != "" { - globalVariables.SecretKey = secretKey + //globalVariables.SecretKey = secretKey common.Info("Secret key refreshed from store") } @@ -244,9 +253,9 @@ func SaveToStorage(store VariableStore) error { } // Save SecretKey - if !store.Set(SecretKeyRedisKey, globalVariables.SecretKey, SecretKeyTTL) { - return fmt.Errorf("failed to save secret key to store") - } + //if !store.Set(SecretKeyRedisKey, globalVariables.SecretKey, SecretKeyTTL) { + // return fmt.Errorf("failed to save secret key to store") + //} common.Info("Variables saved to storage") return nil diff --git a/internal/service/chat.go b/internal/service/chat.go index 832154ffd..f386d7279 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -50,7 +50,8 @@ func NewChatService() *ChatService { // ChatWithKBNames chat with knowledge base names type ChatWithKBNames struct { *entity.Chat - KBNames []string `json:"kb_names"` + KBNames []string `json:"kb_names"` + DatasetIDs []string `json:"dataset_ids"` } // ListChatsResponse list chats response @@ -99,10 +100,11 @@ func (s *ChatService) ListChats(userID, status, keywords string, page, pageSize // Enrich with knowledge base names chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { - kbNames := s.getKBNames(chat.KBIDs) + kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ - Chat: chat, - KBNames: kbNames, + Chat: chat, + KBNames: kbNames, + DatasetIDs: datasetIDs, }) } @@ -165,10 +167,11 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi // Enrich with knowledge base names chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { - kbNames := s.getKBNames(chat.KBIDs) + kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ - Chat: chat, - KBNames: kbNames, + Chat: chat, + KBNames: kbNames, + DatasetIDs: datasetIDs, }) } @@ -178,9 +181,10 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi }, nil } -// getKBNames gets knowledge base names by IDs -func (s *ChatService) getKBNames(kbIDs entity.JSONSlice) []string { - var names []string +// getDatasetNamesAndIDs gets knowledge base names by IDs +func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, []string) { + var names = make([]string, 0, 0) + var ids = make([]string, 0, 0) for _, kbID := range kbIDs { kbIDStr, ok := kbID.(string) if !ok { @@ -193,9 +197,10 @@ func (s *ChatService) getKBNames(kbIDs entity.JSONSlice) []string { // Only include valid KBs if kb.Status != nil && *kb.Status == "1" { names = append(names, kb.Name) + ids = append(ids, kbIDStr) } } - return names + return names, ids } // ParameterConfig parameter configuration in prompt_config @@ -485,7 +490,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo } // Get KB names - kbNames := s.getKBNames(chat.KBIDs) + kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs) return &SetDialogResponse{ Chat: chat, @@ -525,7 +530,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo } // Get KB names - kbNames := s.getKBNames(chat.KBIDs) + kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs) return &SetDialogResponse{ Chat: chat, @@ -679,10 +684,9 @@ func (s *ChatService) GetChat(userID string, chatID string) (*GetChatResponse, e // Step 4: Build response with kb_names (same as Python _build_chat_response) // Resolve kb_ids to kb_names - kbNames := s.getKBNames(chat.KBIDs) + kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) // Build dataset_ids from kb_ids (same as Python _resolve_kb_names returns ids) - var datasetIDs []string for _, kbID := range chat.KBIDs { datasetID, ok := kbID.(string) if !ok { diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index dc28e9ed6..206b6e76b 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -221,7 +221,7 @@ type ListChatSessionsResponse struct { } // ListChatSessions lists chat sessions for a dialog -func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*ListChatSessionsResponse, error) { +func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) { // Get user's tenants tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { @@ -231,7 +231,8 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (* // Check if user is the owner of the dialog isOwner := false for _, tenantID := range tenantIDs { - exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, dialogID) + var exists bool + exists, err = s.chatSessionDAO.CheckDialogExists(tenantID, chatID) if err != nil { return nil, err } @@ -243,7 +244,8 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (* // Also check with userID as tenant if !isOwner { - exists, err := s.chatSessionDAO.CheckDialogExists(userID, dialogID) + var exists bool + exists, err = s.chatSessionDAO.CheckDialogExists(userID, chatID) if err != nil { return nil, err } @@ -251,11 +253,11 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (* } if !isOwner { - return nil, errors.New("Only owner of dialog authorized for this operation") + return nil, errors.New("only owner of dialog authorized for this operation") } // List chat sessions - sessions, err := s.chatSessionDAO.ListByDialogID(dialogID) + sessions, err := s.chatSessionDAO.ListByChatID(chatID) if err != nil { return nil, err } diff --git a/internal/service/system.go b/internal/service/system.go index 191487633..bd0e1790f 100644 --- a/internal/service/system.go +++ b/internal/service/system.go @@ -31,14 +31,20 @@ func NewSystemService() *SystemService { // ConfigResponse system configuration response type ConfigResponse struct { - RegisterEnabled int `json:"registerEnabled"` + RegisterEnabled int `json:"registerEnabled"` + DisablePasswordLogin bool `json:"disablePasswordLogin"` } // GetConfig get system configuration func (s *SystemService) GetConfig() (*ConfigResponse, error) { cfg := server.GetConfig() + registerEnabled := 1 + if !cfg.Authentication.RegisterEnabled { + registerEnabled = 0 + } return &ConfigResponse{ - RegisterEnabled: cfg.RegisterEnabled, + RegisterEnabled: registerEnabled, + DisablePasswordLogin: cfg.Authentication.DisablePasswordLogin, }, nil } diff --git a/internal/service/user.go b/internal/service/user.go index 1e550fb88..0d12d11a7 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -29,6 +29,7 @@ import ( "fmt" "hash" "os" + "ragflow/internal/cache" "ragflow/internal/common" "ragflow/internal/entity" "ragflow/internal/server" @@ -104,23 +105,23 @@ type UserResponse struct { // Register user registration func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) { cfg := server.GetConfig() - if cfg.RegisterEnabled == 0 { - return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!") + if !cfg.Authentication.RegisterEnabled { + return nil, common.CodeOperatingError, fmt.Errorf("user registration is disabled") } emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`) if !emailRegex.MatchString(req.Email) { - return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email) + return nil, common.CodeOperatingError, fmt.Errorf("invalid email address: %s", req.Email) } existUser, _ := s.userDAO.GetByEmail(req.Email) if existUser != nil { - return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email) + return nil, common.CodeOperatingError, fmt.Errorf("email: %s has already registered", req.Email) } decryptedPassword, err := s.decryptPassword(req.Password) if err != nil { - return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password") + return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password") } var hashedPassword string @@ -642,8 +643,10 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error) // using itsdangerous URLSafeTimedSerializer to get the actual access_token func (s *UserService) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) { // Get secret key from config - variables := server.GetVariables() - secretKey := variables.SecretKey + secretKey, err := server.GetSecretKey(cache.Get()) + if err != nil { + return nil, common.CodeUnauthorized, err + } // Extract access token from authorization header // Equivalent to: access_token = str(jwt.loads(authorization)) in Python diff --git a/test/testcases/test_web_api/test_system_app/test_system_basic.py b/test/testcases/test_web_api/test_system_app/test_system_basic.py index 81b9de4e2..f9443ec23 100644 --- a/test/testcases/test_web_api/test_system_app/test_system_basic.py +++ b/test/testcases/test_web_api/test_system_app/test_system_basic.py @@ -40,13 +40,6 @@ class TestAuthorization: assert res["code"] == expected_code, res assert expected_fragment in res["message"], res - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_auth_invalid_version(self, invalid_auth, expected_code, expected_fragment): - res = system_version(invalid_auth) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - @pytest.mark.p2 @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) def test_auth_invalid_token_list(self, invalid_auth, expected_code, expected_fragment): diff --git a/web/vite.config.ts b/web/vite.config.ts index 59598ded9..b96f425fa 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -92,6 +92,17 @@ export default defineConfig(({ mode }) => { changeOrigin: true, ws: true, }, + '/api/v1/users/me/models': { + target: 'http://127.0.0.1:9380/', + changeOrigin: true, + ws: true, + }, + '^(/api/v1/auth/login)|^(/api/v1/users/me)|^(/api/v1/system/config)|^(/api/v1/system/version)|^(/api/v1/tenants)|^(/api/v1/chats)|^(/api/v1/searches)|^(/api/v1/files)': + { + target: 'http://127.0.0.1:9384/', + changeOrigin: true, + ws: true, + }, '/api': { target: 'http://127.0.0.1:9380/', changeOrigin: true,