Go: fix auth issue in hybrid mode (#14611)

### What problem does this PR solve?

Since secret key get and set logic is updated, the go server also need
to update.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2026-05-07 17:14:22 +08:00
committed by GitHub
parent 5c9124c3ef
commit 94324afee9
15 changed files with 287 additions and 171 deletions

View File

@ -39,7 +39,6 @@ async def ping():
return "pong", 200
@manager.route("/system/version", methods=["GET"]) # noqa: F821
@login_required
def version():
"""
Get the current version of the application.

View File

@ -174,7 +174,8 @@ def _get_or_create_secret_key():
generated_key = secrets.token_hex(32)
secret_key = REDIS_CONN.get_or_create_secret_key("ragflow:system:secret_key", generated_key)
logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
if generated_key == secret_key:
logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
return secret_key
class StorageFactory:

View File

@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"net/http"
"ragflow/internal/cache"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/server"
@ -153,8 +154,15 @@ func (h *Handler) Login(c *gin.Context) {
return
}
variables := server.GetVariables()
secretKey := variables.SecretKey
secretKey, err := server.GetSecretKey(cache.Get())
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
})
return
}
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@ -53,20 +53,20 @@ func (dao *ChatSessionDAO) DeleteByID(id string) error {
return DB.Where("id = ?", id).Delete(&entity.ChatSession{}).Error
}
// ListByDialogID lists chat sessions by dialog ID
func (dao *ChatSessionDAO) ListByDialogID(dialogID string) ([]*entity.ChatSession, error) {
// ListByChatID lists chat sessions by chat ID
func (dao *ChatSessionDAO) ListByChatID(chatID string) ([]*entity.ChatSession, error) {
var convs []*entity.ChatSession
err := DB.Where("dialog_id = ?", dialogID).
err := DB.Where("dialog_id = ?", chatID).
Order("create_time DESC").
Find(&convs).Error
return convs, err
}
// CheckDialogExists checks if a dialog exists with given tenant_id and dialog_id
func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, error) {
func (dao *ChatSessionDAO) CheckDialogExists(tenantID, chatID string) (bool, error) {
var count int64
err := DB.Model(&entity.Chat{}).
Where("tenant_id = ? AND id = ? AND status = ?", tenantID, dialogID, "1").
Where("tenant_id = ? AND id = ? AND status = ?", tenantID, chatID, "1").
Count(&count).Error
if err != nil {
return false, err
@ -75,9 +75,9 @@ func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, e
}
// GetDialogByID gets dialog by ID
func (dao *ChatSessionDAO) GetDialogByID(dialogID string) (*entity.Chat, error) {
func (dao *ChatSessionDAO) GetDialogByID(chatID string) (*entity.Chat, error) {
var dialog entity.Chat
err := DB.Where("id = ? AND status = ?", dialogID, "1").First(&dialog).Error
err := DB.Where("id = ? AND status = ?", chatID, "1").First(&dialog).Error
if err != nil {
return nil, err
}

View File

@ -148,9 +148,9 @@ func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) {
// @Tags chat_session
// @Accept json
// @Produce json
// @Param dialog_id query string true "dialog ID"
// @Param chat_id query string true "chat ID"
// @Success 200 {object} service.ListChatSessionsResponse
// @Router /v1/conversation/list [get]
// @Router /api/v1/chats/<chat_id>/sessions [get]
func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
@ -159,18 +159,18 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) {
}
userID := user.ID
// Get dialog_id from query parameter
dialogID := c.Query("dialog_id")
if dialogID == "" {
// Get chat_id from query parameter
chatID := c.Param("chat_id")
if chatID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "dialog_id is required",
"message": "chat_id is required",
})
return
}
// Call service to list chat sessions
result, err := h.chatSessionService.ListChatSessions(userID, dialogID)
result, err := h.chatSessionService.ListChatSessions(userID, chatID)
if err != nil {
// Check if it's an authorization error
if err.Error() == "Only owner of dialog authorized for this operation" {

View File

@ -19,6 +19,7 @@ package handler
import (
"fmt"
"net/http"
"ragflow/internal/cache"
"ragflow/internal/common"
"ragflow/internal/server"
"ragflow/internal/server/local"
@ -72,8 +73,15 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
variables := server.GetVariables()
secretKey := variables.SecretKey
secretKey, err := server.GetSecretKey(cache.Get())
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
"data": false,
})
return
}
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@ -129,8 +137,15 @@ func (h *UserHandler) Login(c *gin.Context) {
}
// Sign the access_token using itsdangerous (compatible with Python)
variables := server.GetVariables()
secretKey := variables.SecretKey
secretKey, err := server.GetSecretKey(cache.Get())
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
"data": false,
})
return
}
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@ -197,8 +212,15 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) {
return
}
variables := server.GetVariables()
secretKey := variables.SecretKey
secretKey, err := server.GetSecretKey(cache.Get())
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
"data": false,
})
return
}
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@ -90,15 +90,15 @@ func (r *Router) Setup(engine *gin.Engine) {
// System endpoints
engine.GET("/v1/system/ping", r.systemHandler.Ping)
engine.GET("/v1/system/config", r.systemHandler.GetConfig)
engine.GET("/api/v1/system/config", r.systemHandler.GetConfig)
engine.GET("/v1/system/configs", r.systemHandler.GetConfigs)
engine.GET("/v1/system/version", r.systemHandler.GetVersion)
engine.GET("/api/v1/system/version", r.systemHandler.GetVersion)
engine.POST("/v1/user/register", r.userHandler.Register)
// User login channels endpoint
engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels)
engine.GET("/api/v1/auth/login/channels", r.userHandler.GetLoginChannels)
// User login by email endpoint
engine.POST("/v1/user/login", r.userHandler.LoginByEmail)
engine.POST("/api/v1/auth/login", r.userHandler.LoginByEmail)
// User logout endpoint
engine.GET("/v1/user/logout", r.userHandler.Logout)
@ -123,14 +123,25 @@ func (r *Router) Setup(engine *gin.Engine) {
// API v1 route group
v1 := authorized.Group("/api/v1")
{
// User routes
//users := v1.Group("/users")
//{
// users.POST("/register", r.userHandler.Register)
// users.POST("/login", r.userHandler.Login)
// users.GET("", r.userHandler.ListUsers)
// users.GET("/:id", r.userHandler.GetUserByID)
//}
// Auth routes
auth := v1.Group("/auth")
{
// User logout endpoint
auth.GET("/logout", r.userHandler.Logout)
}
// Users routes
users := v1.Group("/users")
{
users.GET("/me", r.userHandler.Info)
// User settings endpoint
users.PATCH("/me", r.userHandler.Setting)
}
tenants := v1.Group("/tenants")
{
tenants.GET("", r.tenantHandler.TenantList)
}
// Document routes
documents := v1.Group("/documents")
@ -142,7 +153,15 @@ func (r *Router) Setup(engine *gin.Engine) {
documents.DELETE("/:id", r.documentHandler.DeleteDocument)
}
// RESTful dataset routes
// Chat routes
chats := v1.Group("/chats")
{
chats.GET("", r.chatHandler.ListChats)
chats.GET("/:chat_id", r.chatHandler.GetChat)
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
}
// Dataset routes
datasets := v1.Group("/datasets")
{
datasets.GET("", r.datasetsHandler.ListDatasets)
@ -150,6 +169,26 @@ func (r *Router) Setup(engine *gin.Engine) {
datasets.DELETE("", r.datasetsHandler.DeleteDatasets)
}
// Search routes
searches := v1.Group("/searches")
{
searches.GET("", r.searchHandler.ListSearches)
searches.POST("", r.searchHandler.CreateSearch)
searches.GET("/:search_id", r.searchHandler.GetSearch)
searches.PUT("/:search_id", r.searchHandler.UpdateSearch)
searches.DELETE("/:search_id", r.searchHandler.DeleteSearch)
}
file := v1.Group("/files")
{
file.POST("", r.fileHandler.UploadFile)
file.GET("", r.fileHandler.ListFiles)
file.DELETE("", r.fileHandler.DeleteFiles)
file.POST("/move", r.fileHandler.MoveFiles)
file.GET("/:id/ancestors", r.fileHandler.GetFileAncestors)
file.GET("/:id", r.fileHandler.Download)
}
// Author routes
authors := v1.Group("/authors")
{
@ -167,62 +206,37 @@ func (r *Router) Setup(engine *gin.Engine) {
memory.GET("/:memory_id", r.memoryHandler.GetMemoryMessages)
}
// TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine
// message := v1.Group("/messages")
// {
// message.POST("", r.memoryHandler.AddMessage)
// message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage)
// message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage)
// message.GET("/search", r.memoryHandler.SearchMessage)
// message.GET("", r.memoryHandler.GetMessages)
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
// }
// TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine
// message := v1.Group("/messages")
// {
// message.POST("", r.memoryHandler.AddMessage)
// message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage)
// message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage)
// message.GET("/search", r.memoryHandler.SearchMessage)
// message.GET("", r.memoryHandler.GetMessages)
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
// }
// Skill search routes
skills := v1.Group("/skills")
{
// Skill Space management
skills.GET("/spaces", r.skillSearchHandler.ListSpaces)
skills.POST("/spaces", r.skillSearchHandler.CreateSpace)
skills.GET("/spaces/:space_id", r.skillSearchHandler.GetSpace)
skills.PUT("/spaces/:space_id", r.skillSearchHandler.UpdateSpace)
skills.DELETE("/spaces/:space_id", r.skillSearchHandler.DeleteSpace)
skills.GET("/space/by-folder", r.skillSearchHandler.GetSpaceByFolder)
// Skill search config
skills.GET("/config", r.skillSearchHandler.GetConfig)
skills.POST("/config", r.skillSearchHandler.UpdateConfig)
// Skill search and indexing
skills.POST("/search", r.skillSearchHandler.Search)
skills.POST("/index", r.skillSearchHandler.IndexSkills)
skills.DELETE("/index", r.skillSearchHandler.DeleteSkillIndex)
skills.POST("/reindex", r.skillSearchHandler.Reindex)
}
chats := v1.Group("/chats")
// Skill search routes
skills := v1.Group("/skills")
{
chats.GET("", r.chatHandler.ListChats)
chats.GET("/:chat_id", r.chatHandler.GetChat)
}
// Skill Space management
skills.GET("/spaces", r.skillSearchHandler.ListSpaces)
skills.POST("/spaces", r.skillSearchHandler.CreateSpace)
skills.GET("/spaces/:space_id", r.skillSearchHandler.GetSpace)
skills.PUT("/spaces/:space_id", r.skillSearchHandler.UpdateSpace)
skills.DELETE("/spaces/:space_id", r.skillSearchHandler.DeleteSpace)
skills.GET("/space/by-folder", r.skillSearchHandler.GetSpaceByFolder)
searches := v1.Group("/searches")
{
searches.GET("", r.searchHandler.ListSearches)
searches.POST("", r.searchHandler.CreateSearch)
searches.GET("/:search_id", r.searchHandler.GetSearch)
searches.PUT("/:search_id", r.searchHandler.UpdateSearch)
searches.DELETE("/:search_id", r.searchHandler.DeleteSearch)
}
// Skill search config
skills.GET("/config", r.skillSearchHandler.GetConfig)
skills.POST("/config", r.skillSearchHandler.UpdateConfig)
file := v1.Group("/files")
{
file.POST("", r.fileHandler.UploadFile)
file.GET("", r.fileHandler.ListFiles)
file.DELETE("", r.fileHandler.DeleteFiles)
file.POST("/move", r.fileHandler.MoveFiles)
file.GET("/:id/ancestors", r.fileHandler.GetFileAncestors)
file.GET("/:id", r.fileHandler.Download)
// Skill search and indexing
skills.POST("/search", r.skillSearchHandler.Search)
skills.POST("/index", r.skillSearchHandler.IndexSkills)
skills.DELETE("/index", r.skillSearchHandler.DeleteSkillIndex)
skills.POST("/reindex", r.skillSearchHandler.Reindex)
}
// provider pool route group
@ -256,7 +270,6 @@ func (r *Router) Setup(engine *gin.Engine) {
system := v1.Group("/system")
{
system.GET("/version", r.systemHandler.GetVersion)
system.GET("/configs", r.systemHandler.GetConfigs)
log := system.Group("/log")
{

View File

@ -36,6 +36,7 @@ const DefaultConnectTimeout = 5 * time.Second
// Config application configuration
type Config struct {
Server ServerConfig `mapstructure:"server"`
Authentication AuthenticationConfig `mapstructure:"authentication"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Log LogConfig `mapstructure:"log"`
@ -55,6 +56,11 @@ type AdminConfig struct {
Port int `mapstructure:"http_port"`
}
type AuthenticationConfig struct {
DisablePasswordLogin bool `mapstructure:"disable_password_login"`
RegisterEnabled bool `mapstructure:"register_enabled"`
}
type DefaultSuperUser struct {
Email string `mapstructure:"email"`
Password string `mapstructure:"password"`
@ -91,8 +97,9 @@ type OAuthConfig struct {
// ServerConfig server configuration
type ServerConfig struct {
Mode string `mapstructure:"mode"` // debug, release
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug, release
Port int `mapstructure:"port"`
SecretKey *string `mapstructure:"secret_key"`
}
// DatabaseConfig database configuration
@ -372,6 +379,31 @@ func Init(configPath string) error {
}
func FromEnvironments() error {
// Secret key
if envVal := os.Getenv("RAGFLOW_SECRET_KEY"); envVal != "" {
globalConfig.Server.SecretKey = &envVal
}
// Load REGISTER_ENABLED from environment variable (default: true)
if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" {
str := strings.ToLower(envVal)
if str == "true" || str == "1" || str == "yes" {
globalConfig.Authentication.RegisterEnabled = true
} else {
globalConfig.Authentication.RegisterEnabled = false
}
}
// Load DISABLE_PASSWORD_LOGIN from environment variable (default: false)
if envVal := os.Getenv("DISABLE_PASSWORD_LOGIN"); envVal != "" {
str := strings.ToLower(envVal)
if str == "true" || str == "1" || str == "yes" {
globalConfig.Authentication.DisablePasswordLogin = true
} else {
globalConfig.Authentication.DisablePasswordLogin = false
}
}
// Doc engine
docEngine := strings.ToLower(os.Getenv("DOC_ENGINE"))
switch docEngine {
@ -535,14 +567,23 @@ func FromConfigFile(configPath string) error {
globalConfig.Admin.Port += 2
}
// Load REGISTER_ENABLED from environment variable (default: 1)
registerEnabled := 1
if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" {
if parsed, err := strconv.Atoi(envVal); err == nil {
registerEnabled = parsed
// authentication section
if globalConfig != nil {
// Try to map from mysql section
globalConfig.Authentication.DisablePasswordLogin = false
globalConfig.Authentication.RegisterEnabled = true
if v.IsSet("authentication") {
authenticationConfig := v.Sub("authentication")
if authenticationConfig != nil {
if authenticationConfig.IsSet("disable_password_login") {
globalConfig.Authentication.DisablePasswordLogin = authenticationConfig.GetBool("disable_password_login")
}
if authenticationConfig.IsSet("enable_register") {
globalConfig.Authentication.RegisterEnabled = authenticationConfig.GetBool("enable_register")
}
}
}
}
globalConfig.RegisterEnabled = registerEnabled
// If we loaded service_conf.yaml, map mysql fields to DatabaseConfig
if globalConfig != nil && globalConfig.Database.Host == "" {
@ -573,6 +614,10 @@ func FromConfigFile(configPath string) error {
if globalConfig.Server.Mode == "" {
globalConfig.Server.Mode = "release"
}
secretKey := ragflowConfig.GetString("secret_key")
if secretKey != "" {
globalConfig.Server.SecretKey = &secretKey
}
}
}
}

View File

@ -30,7 +30,7 @@ import (
// Variables holds all runtime variables that can be changed during system operation
// Unlike Config, these can be modified at runtime
type Variables struct {
SecretKey string `json:"secret_key"`
//SecretKey string `json:"secret_key"`
}
// VariableStore interface for persistent storage (e.g., Redis)
@ -62,19 +62,20 @@ func InitVariables(store VariableStore) error {
variablesOnce.Do(func() {
globalVariables = &Variables{}
generatedKey, err := utility.GenerateSecretKey()
if err != nil {
initErr = fmt.Errorf("failed to generate secret key: %w", err)
}
// Initialize SecretKey
secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey)
if err != nil {
initErr = fmt.Errorf("failed to initialize secret key: %w", err)
} else {
globalVariables.SecretKey = secretKey
common.Info("Secret key initialized from store")
}
//// secret key
//generatedKey, err := utility.GenerateSecretKey()
//if err != nil {
// initErr = fmt.Errorf("failed to generate secret key: %w", err)
//}
//
//// Initialize SecretKey
//secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey)
//if err != nil {
// initErr = fmt.Errorf("failed to initialize secret key: %w", err)
//} else {
// globalVariables.SecretKey = secretKey
// common.Info("Secret key initialized from store")
//}
common.Info("Server variables initialized successfully")
})
@ -82,31 +83,39 @@ func InitVariables(store VariableStore) error {
}
// GetVariables returns the global variables instance
func GetVariables() *Variables {
variablesMu.RLock()
defer variablesMu.RUnlock()
return globalVariables
}
//func GetVariables() *Variables {
// variablesMu.RLock()
// defer variablesMu.RUnlock()
// return globalVariables
//}
// GetSecretKey returns the current secret key
func GetSecretKey() string {
variablesMu.RLock()
defer variablesMu.RUnlock()
if globalVariables == nil {
return DefaultSecretKey
func GetSecretKey(store VariableStore) (string, error) {
if globalConfig.Server.SecretKey != nil {
return *globalConfig.Server.SecretKey, nil
}
return globalVariables.SecretKey
generatedKey, err := utility.GenerateSecretKey()
if err != nil {
return "", fmt.Errorf("failed to generate secret key: %w", err)
}
secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey)
if err != nil {
return "", fmt.Errorf("failed to get secret key: %w", err)
}
return secretKey, nil
}
// SetSecretKey updates the secret key at runtime
func SetSecretKey(key string) {
variablesMu.Lock()
defer variablesMu.Unlock()
if globalVariables != nil {
globalVariables.SecretKey = key
common.Info("Secret key updated at runtime")
}
}
//func SetSecretKey(key string) {
// variablesMu.Lock()
// defer variablesMu.Unlock()
// if globalVariables != nil {
// globalVariables.SecretKey = key
// common.Info("Secret key updated at runtime")
// }
//}
// GetOrCreateKey gets a key from store, or creates it if not exists
// - If key exists in store, returns the stored value
@ -178,7 +187,7 @@ func RefreshVariables(store VariableStore) error {
return err
}
if secretKey != "" {
globalVariables.SecretKey = secretKey
//globalVariables.SecretKey = secretKey
common.Info("Secret key refreshed from store")
}
@ -244,9 +253,9 @@ func SaveToStorage(store VariableStore) error {
}
// Save SecretKey
if !store.Set(SecretKeyRedisKey, globalVariables.SecretKey, SecretKeyTTL) {
return fmt.Errorf("failed to save secret key to store")
}
//if !store.Set(SecretKeyRedisKey, globalVariables.SecretKey, SecretKeyTTL) {
// return fmt.Errorf("failed to save secret key to store")
//}
common.Info("Variables saved to storage")
return nil

View File

@ -50,7 +50,8 @@ func NewChatService() *ChatService {
// ChatWithKBNames chat with knowledge base names
type ChatWithKBNames struct {
*entity.Chat
KBNames []string `json:"kb_names"`
KBNames []string `json:"kb_names"`
DatasetIDs []string `json:"dataset_ids"`
}
// ListChatsResponse list chats response
@ -99,10 +100,11 @@ func (s *ChatService) ListChats(userID, status, keywords string, page, pageSize
// Enrich with knowledge base names
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
for _, chat := range chats {
kbNames := s.getKBNames(chat.KBIDs)
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
Chat: chat,
KBNames: kbNames,
Chat: chat,
KBNames: kbNames,
DatasetIDs: datasetIDs,
})
}
@ -165,10 +167,11 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
// Enrich with knowledge base names
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
for _, chat := range chats {
kbNames := s.getKBNames(chat.KBIDs)
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
Chat: chat,
KBNames: kbNames,
Chat: chat,
KBNames: kbNames,
DatasetIDs: datasetIDs,
})
}
@ -178,9 +181,10 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
}, nil
}
// getKBNames gets knowledge base names by IDs
func (s *ChatService) getKBNames(kbIDs entity.JSONSlice) []string {
var names []string
// getDatasetNamesAndIDs gets knowledge base names by IDs
func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, []string) {
var names = make([]string, 0, 0)
var ids = make([]string, 0, 0)
for _, kbID := range kbIDs {
kbIDStr, ok := kbID.(string)
if !ok {
@ -193,9 +197,10 @@ func (s *ChatService) getKBNames(kbIDs entity.JSONSlice) []string {
// Only include valid KBs
if kb.Status != nil && *kb.Status == "1" {
names = append(names, kb.Name)
ids = append(ids, kbIDStr)
}
}
return names
return names, ids
}
// ParameterConfig parameter configuration in prompt_config
@ -485,7 +490,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
}
// Get KB names
kbNames := s.getKBNames(chat.KBIDs)
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
return &SetDialogResponse{
Chat: chat,
@ -525,7 +530,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
}
// Get KB names
kbNames := s.getKBNames(chat.KBIDs)
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
return &SetDialogResponse{
Chat: chat,
@ -679,10 +684,9 @@ func (s *ChatService) GetChat(userID string, chatID string) (*GetChatResponse, e
// Step 4: Build response with kb_names (same as Python _build_chat_response)
// Resolve kb_ids to kb_names
kbNames := s.getKBNames(chat.KBIDs)
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
// Build dataset_ids from kb_ids (same as Python _resolve_kb_names returns ids)
var datasetIDs []string
for _, kbID := range chat.KBIDs {
datasetID, ok := kbID.(string)
if !ok {

View File

@ -221,7 +221,7 @@ type ListChatSessionsResponse struct {
}
// ListChatSessions lists chat sessions for a dialog
func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*ListChatSessionsResponse, error) {
func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) {
// Get user's tenants
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
@ -231,7 +231,8 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*
// Check if user is the owner of the dialog
isOwner := false
for _, tenantID := range tenantIDs {
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, dialogID)
var exists bool
exists, err = s.chatSessionDAO.CheckDialogExists(tenantID, chatID)
if err != nil {
return nil, err
}
@ -243,7 +244,8 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*
// Also check with userID as tenant
if !isOwner {
exists, err := s.chatSessionDAO.CheckDialogExists(userID, dialogID)
var exists bool
exists, err = s.chatSessionDAO.CheckDialogExists(userID, chatID)
if err != nil {
return nil, err
}
@ -251,11 +253,11 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*
}
if !isOwner {
return nil, errors.New("Only owner of dialog authorized for this operation")
return nil, errors.New("only owner of dialog authorized for this operation")
}
// List chat sessions
sessions, err := s.chatSessionDAO.ListByDialogID(dialogID)
sessions, err := s.chatSessionDAO.ListByChatID(chatID)
if err != nil {
return nil, err
}

View File

@ -31,14 +31,20 @@ func NewSystemService() *SystemService {
// ConfigResponse system configuration response
type ConfigResponse struct {
RegisterEnabled int `json:"registerEnabled"`
RegisterEnabled int `json:"registerEnabled"`
DisablePasswordLogin bool `json:"disablePasswordLogin"`
}
// GetConfig get system configuration
func (s *SystemService) GetConfig() (*ConfigResponse, error) {
cfg := server.GetConfig()
registerEnabled := 1
if !cfg.Authentication.RegisterEnabled {
registerEnabled = 0
}
return &ConfigResponse{
RegisterEnabled: cfg.RegisterEnabled,
RegisterEnabled: registerEnabled,
DisablePasswordLogin: cfg.Authentication.DisablePasswordLogin,
}, nil
}

View File

@ -29,6 +29,7 @@ import (
"fmt"
"hash"
"os"
"ragflow/internal/cache"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/server"
@ -104,23 +105,23 @@ type UserResponse struct {
// Register user registration
func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) {
cfg := server.GetConfig()
if cfg.RegisterEnabled == 0 {
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
if !cfg.Authentication.RegisterEnabled {
return nil, common.CodeOperatingError, fmt.Errorf("user registration is disabled")
}
emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)
if !emailRegex.MatchString(req.Email) {
return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
return nil, common.CodeOperatingError, fmt.Errorf("invalid email address: %s", req.Email)
}
existUser, _ := s.userDAO.GetByEmail(req.Email)
if existUser != nil {
return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
return nil, common.CodeOperatingError, fmt.Errorf("email: %s has already registered", req.Email)
}
decryptedPassword, err := s.decryptPassword(req.Password)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password")
return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
}
var hashedPassword string
@ -642,8 +643,10 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error)
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
func (s *UserService) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) {
// Get secret key from config
variables := server.GetVariables()
secretKey := variables.SecretKey
secretKey, err := server.GetSecretKey(cache.Get())
if err != nil {
return nil, common.CodeUnauthorized, err
}
// Extract access token from authorization header
// Equivalent to: access_token = str(jwt.loads(authorization)) in Python

View File

@ -40,13 +40,6 @@ class TestAuthorization:
assert res["code"] == expected_code, res
assert expected_fragment in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
def test_auth_invalid_version(self, invalid_auth, expected_code, expected_fragment):
res = system_version(invalid_auth)
assert res["code"] == expected_code, res
assert expected_fragment in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
def test_auth_invalid_token_list(self, invalid_auth, expected_code, expected_fragment):

View File

@ -92,6 +92,17 @@ export default defineConfig(({ mode }) => {
changeOrigin: true,
ws: true,
},
'/api/v1/users/me/models': {
target: 'http://127.0.0.1:9380/',
changeOrigin: true,
ws: true,
},
'^(/api/v1/auth/login)|^(/api/v1/users/me)|^(/api/v1/system/config)|^(/api/v1/system/version)|^(/api/v1/tenants)|^(/api/v1/chats)|^(/api/v1/searches)|^(/api/v1/files)':
{
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/api': {
target: 'http://127.0.0.1:9380/',
changeOrigin: true,