mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-30 20:47:29 +08:00
Go: fix auth issue in hybrid mode (#14611)
### What problem does this PR solve? Since secret key get and set logic is updated, the go server also need to update. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -39,7 +39,6 @@ async def ping():
|
||||
return "pong", 200
|
||||
|
||||
@manager.route("/system/version", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def version():
|
||||
"""
|
||||
Get the current version of the application.
|
||||
|
||||
@ -174,7 +174,8 @@ def _get_or_create_secret_key():
|
||||
|
||||
generated_key = secrets.token_hex(32)
|
||||
secret_key = REDIS_CONN.get_or_create_secret_key("ragflow:system:secret_key", generated_key)
|
||||
logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
|
||||
if generated_key == secret_key:
|
||||
logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
|
||||
return secret_key
|
||||
|
||||
class StorageFactory:
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/cache"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/server"
|
||||
@ -153,8 +154,15 @@ func (h *Handler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
secretKey, err := server.GetSecretKey(cache.Get())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@ -53,20 +53,20 @@ func (dao *ChatSessionDAO) DeleteByID(id string) error {
|
||||
return DB.Where("id = ?", id).Delete(&entity.ChatSession{}).Error
|
||||
}
|
||||
|
||||
// ListByDialogID lists chat sessions by dialog ID
|
||||
func (dao *ChatSessionDAO) ListByDialogID(dialogID string) ([]*entity.ChatSession, error) {
|
||||
// ListByChatID lists chat sessions by chat ID
|
||||
func (dao *ChatSessionDAO) ListByChatID(chatID string) ([]*entity.ChatSession, error) {
|
||||
var convs []*entity.ChatSession
|
||||
err := DB.Where("dialog_id = ?", dialogID).
|
||||
err := DB.Where("dialog_id = ?", chatID).
|
||||
Order("create_time DESC").
|
||||
Find(&convs).Error
|
||||
return convs, err
|
||||
}
|
||||
|
||||
// CheckDialogExists checks if a dialog exists with given tenant_id and dialog_id
|
||||
func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, error) {
|
||||
func (dao *ChatSessionDAO) CheckDialogExists(tenantID, chatID string) (bool, error) {
|
||||
var count int64
|
||||
err := DB.Model(&entity.Chat{}).
|
||||
Where("tenant_id = ? AND id = ? AND status = ?", tenantID, dialogID, "1").
|
||||
Where("tenant_id = ? AND id = ? AND status = ?", tenantID, chatID, "1").
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -75,9 +75,9 @@ func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, e
|
||||
}
|
||||
|
||||
// GetDialogByID gets dialog by ID
|
||||
func (dao *ChatSessionDAO) GetDialogByID(dialogID string) (*entity.Chat, error) {
|
||||
func (dao *ChatSessionDAO) GetDialogByID(chatID string) (*entity.Chat, error) {
|
||||
var dialog entity.Chat
|
||||
err := DB.Where("id = ? AND status = ?", dialogID, "1").First(&dialog).Error
|
||||
err := DB.Where("id = ? AND status = ?", chatID, "1").First(&dialog).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -148,9 +148,9 @@ func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) {
|
||||
// @Tags chat_session
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param dialog_id query string true "dialog ID"
|
||||
// @Param chat_id query string true "chat ID"
|
||||
// @Success 200 {object} service.ListChatSessionsResponse
|
||||
// @Router /v1/conversation/list [get]
|
||||
// @Router /api/v1/chats/<chat_id>/sessions [get]
|
||||
func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
@ -159,18 +159,18 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) {
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Get dialog_id from query parameter
|
||||
dialogID := c.Query("dialog_id")
|
||||
if dialogID == "" {
|
||||
// Get chat_id from query parameter
|
||||
chatID := c.Param("chat_id")
|
||||
if chatID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "dialog_id is required",
|
||||
"message": "chat_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to list chat sessions
|
||||
result, err := h.chatSessionService.ListChatSessions(userID, dialogID)
|
||||
result, err := h.chatSessionService.ListChatSessions(userID, chatID)
|
||||
if err != nil {
|
||||
// Check if it's an authorization error
|
||||
if err.Error() == "Only owner of dialog authorized for this operation" {
|
||||
|
||||
@ -19,6 +19,7 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/cache"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/server/local"
|
||||
@ -72,8 +73,15 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
secretKey, err := server.GetSecretKey(cache.Get())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@ -129,8 +137,15 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Sign the access_token using itsdangerous (compatible with Python)
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
secretKey, err := server.GetSecretKey(cache.Get())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@ -197,8 +212,15 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
secretKey, err := server.GetSecretKey(cache.Get())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": fmt.Sprintf("Failed to get secret key: %s", err.Error()),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@ -90,15 +90,15 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
|
||||
// System endpoints
|
||||
engine.GET("/v1/system/ping", r.systemHandler.Ping)
|
||||
engine.GET("/v1/system/config", r.systemHandler.GetConfig)
|
||||
engine.GET("/api/v1/system/config", r.systemHandler.GetConfig)
|
||||
engine.GET("/v1/system/configs", r.systemHandler.GetConfigs)
|
||||
engine.GET("/v1/system/version", r.systemHandler.GetVersion)
|
||||
engine.GET("/api/v1/system/version", r.systemHandler.GetVersion)
|
||||
engine.POST("/v1/user/register", r.userHandler.Register)
|
||||
// User login channels endpoint
|
||||
engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels)
|
||||
engine.GET("/api/v1/auth/login/channels", r.userHandler.GetLoginChannels)
|
||||
|
||||
// User login by email endpoint
|
||||
engine.POST("/v1/user/login", r.userHandler.LoginByEmail)
|
||||
engine.POST("/api/v1/auth/login", r.userHandler.LoginByEmail)
|
||||
|
||||
// User logout endpoint
|
||||
engine.GET("/v1/user/logout", r.userHandler.Logout)
|
||||
@ -123,14 +123,25 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// API v1 route group
|
||||
v1 := authorized.Group("/api/v1")
|
||||
{
|
||||
// User routes
|
||||
//users := v1.Group("/users")
|
||||
//{
|
||||
// users.POST("/register", r.userHandler.Register)
|
||||
// users.POST("/login", r.userHandler.Login)
|
||||
// users.GET("", r.userHandler.ListUsers)
|
||||
// users.GET("/:id", r.userHandler.GetUserByID)
|
||||
//}
|
||||
// Auth routes
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
// User logout endpoint
|
||||
auth.GET("/logout", r.userHandler.Logout)
|
||||
}
|
||||
|
||||
// Users routes
|
||||
users := v1.Group("/users")
|
||||
{
|
||||
users.GET("/me", r.userHandler.Info)
|
||||
// User settings endpoint
|
||||
users.PATCH("/me", r.userHandler.Setting)
|
||||
}
|
||||
|
||||
tenants := v1.Group("/tenants")
|
||||
{
|
||||
tenants.GET("", r.tenantHandler.TenantList)
|
||||
}
|
||||
|
||||
// Document routes
|
||||
documents := v1.Group("/documents")
|
||||
@ -142,7 +153,15 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
documents.DELETE("/:id", r.documentHandler.DeleteDocument)
|
||||
}
|
||||
|
||||
// RESTful dataset routes
|
||||
// Chat routes
|
||||
chats := v1.Group("/chats")
|
||||
{
|
||||
chats.GET("", r.chatHandler.ListChats)
|
||||
chats.GET("/:chat_id", r.chatHandler.GetChat)
|
||||
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
|
||||
}
|
||||
|
||||
// Dataset routes
|
||||
datasets := v1.Group("/datasets")
|
||||
{
|
||||
datasets.GET("", r.datasetsHandler.ListDatasets)
|
||||
@ -150,6 +169,26 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets.DELETE("", r.datasetsHandler.DeleteDatasets)
|
||||
}
|
||||
|
||||
// Search routes
|
||||
searches := v1.Group("/searches")
|
||||
{
|
||||
searches.GET("", r.searchHandler.ListSearches)
|
||||
searches.POST("", r.searchHandler.CreateSearch)
|
||||
searches.GET("/:search_id", r.searchHandler.GetSearch)
|
||||
searches.PUT("/:search_id", r.searchHandler.UpdateSearch)
|
||||
searches.DELETE("/:search_id", r.searchHandler.DeleteSearch)
|
||||
}
|
||||
|
||||
file := v1.Group("/files")
|
||||
{
|
||||
file.POST("", r.fileHandler.UploadFile)
|
||||
file.GET("", r.fileHandler.ListFiles)
|
||||
file.DELETE("", r.fileHandler.DeleteFiles)
|
||||
file.POST("/move", r.fileHandler.MoveFiles)
|
||||
file.GET("/:id/ancestors", r.fileHandler.GetFileAncestors)
|
||||
file.GET("/:id", r.fileHandler.Download)
|
||||
}
|
||||
|
||||
// Author routes
|
||||
authors := v1.Group("/authors")
|
||||
{
|
||||
@ -167,62 +206,37 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
memory.GET("/:memory_id", r.memoryHandler.GetMemoryMessages)
|
||||
}
|
||||
|
||||
// TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine
|
||||
// message := v1.Group("/messages")
|
||||
// {
|
||||
// message.POST("", r.memoryHandler.AddMessage)
|
||||
// message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage)
|
||||
// message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage)
|
||||
// message.GET("/search", r.memoryHandler.SearchMessage)
|
||||
// message.GET("", r.memoryHandler.GetMessages)
|
||||
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
|
||||
// }
|
||||
// TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine
|
||||
// message := v1.Group("/messages")
|
||||
// {
|
||||
// message.POST("", r.memoryHandler.AddMessage)
|
||||
// message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage)
|
||||
// message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage)
|
||||
// message.GET("/search", r.memoryHandler.SearchMessage)
|
||||
// message.GET("", r.memoryHandler.GetMessages)
|
||||
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
|
||||
// }
|
||||
|
||||
// Skill search routes
|
||||
skills := v1.Group("/skills")
|
||||
{
|
||||
// Skill Space management
|
||||
skills.GET("/spaces", r.skillSearchHandler.ListSpaces)
|
||||
skills.POST("/spaces", r.skillSearchHandler.CreateSpace)
|
||||
skills.GET("/spaces/:space_id", r.skillSearchHandler.GetSpace)
|
||||
skills.PUT("/spaces/:space_id", r.skillSearchHandler.UpdateSpace)
|
||||
skills.DELETE("/spaces/:space_id", r.skillSearchHandler.DeleteSpace)
|
||||
skills.GET("/space/by-folder", r.skillSearchHandler.GetSpaceByFolder)
|
||||
|
||||
// Skill search config
|
||||
skills.GET("/config", r.skillSearchHandler.GetConfig)
|
||||
skills.POST("/config", r.skillSearchHandler.UpdateConfig)
|
||||
|
||||
// Skill search and indexing
|
||||
skills.POST("/search", r.skillSearchHandler.Search)
|
||||
skills.POST("/index", r.skillSearchHandler.IndexSkills)
|
||||
skills.DELETE("/index", r.skillSearchHandler.DeleteSkillIndex)
|
||||
skills.POST("/reindex", r.skillSearchHandler.Reindex)
|
||||
}
|
||||
|
||||
chats := v1.Group("/chats")
|
||||
// Skill search routes
|
||||
skills := v1.Group("/skills")
|
||||
{
|
||||
chats.GET("", r.chatHandler.ListChats)
|
||||
chats.GET("/:chat_id", r.chatHandler.GetChat)
|
||||
}
|
||||
// Skill Space management
|
||||
skills.GET("/spaces", r.skillSearchHandler.ListSpaces)
|
||||
skills.POST("/spaces", r.skillSearchHandler.CreateSpace)
|
||||
skills.GET("/spaces/:space_id", r.skillSearchHandler.GetSpace)
|
||||
skills.PUT("/spaces/:space_id", r.skillSearchHandler.UpdateSpace)
|
||||
skills.DELETE("/spaces/:space_id", r.skillSearchHandler.DeleteSpace)
|
||||
skills.GET("/space/by-folder", r.skillSearchHandler.GetSpaceByFolder)
|
||||
|
||||
searches := v1.Group("/searches")
|
||||
{
|
||||
searches.GET("", r.searchHandler.ListSearches)
|
||||
searches.POST("", r.searchHandler.CreateSearch)
|
||||
searches.GET("/:search_id", r.searchHandler.GetSearch)
|
||||
searches.PUT("/:search_id", r.searchHandler.UpdateSearch)
|
||||
searches.DELETE("/:search_id", r.searchHandler.DeleteSearch)
|
||||
}
|
||||
// Skill search config
|
||||
skills.GET("/config", r.skillSearchHandler.GetConfig)
|
||||
skills.POST("/config", r.skillSearchHandler.UpdateConfig)
|
||||
|
||||
file := v1.Group("/files")
|
||||
{
|
||||
file.POST("", r.fileHandler.UploadFile)
|
||||
file.GET("", r.fileHandler.ListFiles)
|
||||
file.DELETE("", r.fileHandler.DeleteFiles)
|
||||
file.POST("/move", r.fileHandler.MoveFiles)
|
||||
file.GET("/:id/ancestors", r.fileHandler.GetFileAncestors)
|
||||
file.GET("/:id", r.fileHandler.Download)
|
||||
// Skill search and indexing
|
||||
skills.POST("/search", r.skillSearchHandler.Search)
|
||||
skills.POST("/index", r.skillSearchHandler.IndexSkills)
|
||||
skills.DELETE("/index", r.skillSearchHandler.DeleteSkillIndex)
|
||||
skills.POST("/reindex", r.skillSearchHandler.Reindex)
|
||||
}
|
||||
|
||||
// provider pool route group
|
||||
@ -256,7 +270,6 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
|
||||
system := v1.Group("/system")
|
||||
{
|
||||
system.GET("/version", r.systemHandler.GetVersion)
|
||||
system.GET("/configs", r.systemHandler.GetConfigs)
|
||||
log := system.Group("/log")
|
||||
{
|
||||
|
||||
@ -36,6 +36,7 @@ const DefaultConnectTimeout = 5 * time.Second
|
||||
// Config application configuration
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Authentication AuthenticationConfig `mapstructure:"authentication"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
@ -55,6 +56,11 @@ type AdminConfig struct {
|
||||
Port int `mapstructure:"http_port"`
|
||||
}
|
||||
|
||||
type AuthenticationConfig struct {
|
||||
DisablePasswordLogin bool `mapstructure:"disable_password_login"`
|
||||
RegisterEnabled bool `mapstructure:"register_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSuperUser struct {
|
||||
Email string `mapstructure:"email"`
|
||||
Password string `mapstructure:"password"`
|
||||
@ -91,8 +97,9 @@ type OAuthConfig struct {
|
||||
|
||||
// ServerConfig server configuration
|
||||
type ServerConfig struct {
|
||||
Mode string `mapstructure:"mode"` // debug, release
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug, release
|
||||
Port int `mapstructure:"port"`
|
||||
SecretKey *string `mapstructure:"secret_key"`
|
||||
}
|
||||
|
||||
// DatabaseConfig database configuration
|
||||
@ -372,6 +379,31 @@ func Init(configPath string) error {
|
||||
}
|
||||
|
||||
func FromEnvironments() error {
|
||||
// Secret key
|
||||
if envVal := os.Getenv("RAGFLOW_SECRET_KEY"); envVal != "" {
|
||||
globalConfig.Server.SecretKey = &envVal
|
||||
}
|
||||
|
||||
// Load REGISTER_ENABLED from environment variable (default: true)
|
||||
if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" {
|
||||
str := strings.ToLower(envVal)
|
||||
if str == "true" || str == "1" || str == "yes" {
|
||||
globalConfig.Authentication.RegisterEnabled = true
|
||||
} else {
|
||||
globalConfig.Authentication.RegisterEnabled = false
|
||||
}
|
||||
}
|
||||
|
||||
// Load DISABLE_PASSWORD_LOGIN from environment variable (default: false)
|
||||
if envVal := os.Getenv("DISABLE_PASSWORD_LOGIN"); envVal != "" {
|
||||
str := strings.ToLower(envVal)
|
||||
if str == "true" || str == "1" || str == "yes" {
|
||||
globalConfig.Authentication.DisablePasswordLogin = true
|
||||
} else {
|
||||
globalConfig.Authentication.DisablePasswordLogin = false
|
||||
}
|
||||
}
|
||||
|
||||
// Doc engine
|
||||
docEngine := strings.ToLower(os.Getenv("DOC_ENGINE"))
|
||||
switch docEngine {
|
||||
@ -535,14 +567,23 @@ func FromConfigFile(configPath string) error {
|
||||
globalConfig.Admin.Port += 2
|
||||
}
|
||||
|
||||
// Load REGISTER_ENABLED from environment variable (default: 1)
|
||||
registerEnabled := 1
|
||||
if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" {
|
||||
if parsed, err := strconv.Atoi(envVal); err == nil {
|
||||
registerEnabled = parsed
|
||||
// authentication section
|
||||
if globalConfig != nil {
|
||||
// Try to map from mysql section
|
||||
globalConfig.Authentication.DisablePasswordLogin = false
|
||||
globalConfig.Authentication.RegisterEnabled = true
|
||||
if v.IsSet("authentication") {
|
||||
authenticationConfig := v.Sub("authentication")
|
||||
if authenticationConfig != nil {
|
||||
if authenticationConfig.IsSet("disable_password_login") {
|
||||
globalConfig.Authentication.DisablePasswordLogin = authenticationConfig.GetBool("disable_password_login")
|
||||
}
|
||||
if authenticationConfig.IsSet("enable_register") {
|
||||
globalConfig.Authentication.RegisterEnabled = authenticationConfig.GetBool("enable_register")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
globalConfig.RegisterEnabled = registerEnabled
|
||||
|
||||
// If we loaded service_conf.yaml, map mysql fields to DatabaseConfig
|
||||
if globalConfig != nil && globalConfig.Database.Host == "" {
|
||||
@ -573,6 +614,10 @@ func FromConfigFile(configPath string) error {
|
||||
if globalConfig.Server.Mode == "" {
|
||||
globalConfig.Server.Mode = "release"
|
||||
}
|
||||
secretKey := ragflowConfig.GetString("secret_key")
|
||||
if secretKey != "" {
|
||||
globalConfig.Server.SecretKey = &secretKey
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ import (
|
||||
// Variables holds all runtime variables that can be changed during system operation
|
||||
// Unlike Config, these can be modified at runtime
|
||||
type Variables struct {
|
||||
SecretKey string `json:"secret_key"`
|
||||
//SecretKey string `json:"secret_key"`
|
||||
}
|
||||
|
||||
// VariableStore interface for persistent storage (e.g., Redis)
|
||||
@ -62,19 +62,20 @@ func InitVariables(store VariableStore) error {
|
||||
variablesOnce.Do(func() {
|
||||
globalVariables = &Variables{}
|
||||
|
||||
generatedKey, err := utility.GenerateSecretKey()
|
||||
if err != nil {
|
||||
initErr = fmt.Errorf("failed to generate secret key: %w", err)
|
||||
}
|
||||
|
||||
// Initialize SecretKey
|
||||
secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey)
|
||||
if err != nil {
|
||||
initErr = fmt.Errorf("failed to initialize secret key: %w", err)
|
||||
} else {
|
||||
globalVariables.SecretKey = secretKey
|
||||
common.Info("Secret key initialized from store")
|
||||
}
|
||||
//// secret key
|
||||
//generatedKey, err := utility.GenerateSecretKey()
|
||||
//if err != nil {
|
||||
// initErr = fmt.Errorf("failed to generate secret key: %w", err)
|
||||
//}
|
||||
//
|
||||
//// Initialize SecretKey
|
||||
//secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey)
|
||||
//if err != nil {
|
||||
// initErr = fmt.Errorf("failed to initialize secret key: %w", err)
|
||||
//} else {
|
||||
// globalVariables.SecretKey = secretKey
|
||||
// common.Info("Secret key initialized from store")
|
||||
//}
|
||||
|
||||
common.Info("Server variables initialized successfully")
|
||||
})
|
||||
@ -82,31 +83,39 @@ func InitVariables(store VariableStore) error {
|
||||
}
|
||||
|
||||
// GetVariables returns the global variables instance
|
||||
func GetVariables() *Variables {
|
||||
variablesMu.RLock()
|
||||
defer variablesMu.RUnlock()
|
||||
return globalVariables
|
||||
}
|
||||
//func GetVariables() *Variables {
|
||||
// variablesMu.RLock()
|
||||
// defer variablesMu.RUnlock()
|
||||
// return globalVariables
|
||||
//}
|
||||
|
||||
// GetSecretKey returns the current secret key
|
||||
func GetSecretKey() string {
|
||||
variablesMu.RLock()
|
||||
defer variablesMu.RUnlock()
|
||||
if globalVariables == nil {
|
||||
return DefaultSecretKey
|
||||
func GetSecretKey(store VariableStore) (string, error) {
|
||||
if globalConfig.Server.SecretKey != nil {
|
||||
return *globalConfig.Server.SecretKey, nil
|
||||
}
|
||||
return globalVariables.SecretKey
|
||||
|
||||
generatedKey, err := utility.GenerateSecretKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate secret key: %w", err)
|
||||
}
|
||||
|
||||
secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get secret key: %w", err)
|
||||
}
|
||||
return secretKey, nil
|
||||
}
|
||||
|
||||
// SetSecretKey updates the secret key at runtime
|
||||
func SetSecretKey(key string) {
|
||||
variablesMu.Lock()
|
||||
defer variablesMu.Unlock()
|
||||
if globalVariables != nil {
|
||||
globalVariables.SecretKey = key
|
||||
common.Info("Secret key updated at runtime")
|
||||
}
|
||||
}
|
||||
//func SetSecretKey(key string) {
|
||||
// variablesMu.Lock()
|
||||
// defer variablesMu.Unlock()
|
||||
// if globalVariables != nil {
|
||||
// globalVariables.SecretKey = key
|
||||
// common.Info("Secret key updated at runtime")
|
||||
// }
|
||||
//}
|
||||
|
||||
// GetOrCreateKey gets a key from store, or creates it if not exists
|
||||
// - If key exists in store, returns the stored value
|
||||
@ -178,7 +187,7 @@ func RefreshVariables(store VariableStore) error {
|
||||
return err
|
||||
}
|
||||
if secretKey != "" {
|
||||
globalVariables.SecretKey = secretKey
|
||||
//globalVariables.SecretKey = secretKey
|
||||
common.Info("Secret key refreshed from store")
|
||||
}
|
||||
|
||||
@ -244,9 +253,9 @@ func SaveToStorage(store VariableStore) error {
|
||||
}
|
||||
|
||||
// Save SecretKey
|
||||
if !store.Set(SecretKeyRedisKey, globalVariables.SecretKey, SecretKeyTTL) {
|
||||
return fmt.Errorf("failed to save secret key to store")
|
||||
}
|
||||
//if !store.Set(SecretKeyRedisKey, globalVariables.SecretKey, SecretKeyTTL) {
|
||||
// return fmt.Errorf("failed to save secret key to store")
|
||||
//}
|
||||
|
||||
common.Info("Variables saved to storage")
|
||||
return nil
|
||||
|
||||
@ -50,7 +50,8 @@ func NewChatService() *ChatService {
|
||||
// ChatWithKBNames chat with knowledge base names
|
||||
type ChatWithKBNames struct {
|
||||
*entity.Chat
|
||||
KBNames []string `json:"kb_names"`
|
||||
KBNames []string `json:"kb_names"`
|
||||
DatasetIDs []string `json:"dataset_ids"`
|
||||
}
|
||||
|
||||
// ListChatsResponse list chats response
|
||||
@ -99,10 +100,11 @@ func (s *ChatService) ListChats(userID, status, keywords string, page, pageSize
|
||||
// Enrich with knowledge base names
|
||||
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
kbNames := s.getKBNames(chat.KBIDs)
|
||||
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
|
||||
Chat: chat,
|
||||
KBNames: kbNames,
|
||||
Chat: chat,
|
||||
KBNames: kbNames,
|
||||
DatasetIDs: datasetIDs,
|
||||
})
|
||||
}
|
||||
|
||||
@ -165,10 +167,11 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
|
||||
// Enrich with knowledge base names
|
||||
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
kbNames := s.getKBNames(chat.KBIDs)
|
||||
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
|
||||
Chat: chat,
|
||||
KBNames: kbNames,
|
||||
Chat: chat,
|
||||
KBNames: kbNames,
|
||||
DatasetIDs: datasetIDs,
|
||||
})
|
||||
}
|
||||
|
||||
@ -178,9 +181,10 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getKBNames gets knowledge base names by IDs
|
||||
func (s *ChatService) getKBNames(kbIDs entity.JSONSlice) []string {
|
||||
var names []string
|
||||
// getDatasetNamesAndIDs gets knowledge base names by IDs
|
||||
func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, []string) {
|
||||
var names = make([]string, 0, 0)
|
||||
var ids = make([]string, 0, 0)
|
||||
for _, kbID := range kbIDs {
|
||||
kbIDStr, ok := kbID.(string)
|
||||
if !ok {
|
||||
@ -193,9 +197,10 @@ func (s *ChatService) getKBNames(kbIDs entity.JSONSlice) []string {
|
||||
// Only include valid KBs
|
||||
if kb.Status != nil && *kb.Status == "1" {
|
||||
names = append(names, kb.Name)
|
||||
ids = append(ids, kbIDStr)
|
||||
}
|
||||
}
|
||||
return names
|
||||
return names, ids
|
||||
}
|
||||
|
||||
// ParameterConfig parameter configuration in prompt_config
|
||||
@ -485,7 +490,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
|
||||
}
|
||||
|
||||
// Get KB names
|
||||
kbNames := s.getKBNames(chat.KBIDs)
|
||||
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
|
||||
return &SetDialogResponse{
|
||||
Chat: chat,
|
||||
@ -525,7 +530,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo
|
||||
}
|
||||
|
||||
// Get KB names
|
||||
kbNames := s.getKBNames(chat.KBIDs)
|
||||
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
|
||||
return &SetDialogResponse{
|
||||
Chat: chat,
|
||||
@ -679,10 +684,9 @@ func (s *ChatService) GetChat(userID string, chatID string) (*GetChatResponse, e
|
||||
|
||||
// Step 4: Build response with kb_names (same as Python _build_chat_response)
|
||||
// Resolve kb_ids to kb_names
|
||||
kbNames := s.getKBNames(chat.KBIDs)
|
||||
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
|
||||
// Build dataset_ids from kb_ids (same as Python _resolve_kb_names returns ids)
|
||||
var datasetIDs []string
|
||||
for _, kbID := range chat.KBIDs {
|
||||
datasetID, ok := kbID.(string)
|
||||
if !ok {
|
||||
|
||||
@ -221,7 +221,7 @@ type ListChatSessionsResponse struct {
|
||||
}
|
||||
|
||||
// ListChatSessions lists chat sessions for a dialog
|
||||
func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*ListChatSessionsResponse, error) {
|
||||
func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) {
|
||||
// Get user's tenants
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
@ -231,7 +231,8 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*
|
||||
// Check if user is the owner of the dialog
|
||||
isOwner := false
|
||||
for _, tenantID := range tenantIDs {
|
||||
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, dialogID)
|
||||
var exists bool
|
||||
exists, err = s.chatSessionDAO.CheckDialogExists(tenantID, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -243,7 +244,8 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*
|
||||
|
||||
// Also check with userID as tenant
|
||||
if !isOwner {
|
||||
exists, err := s.chatSessionDAO.CheckDialogExists(userID, dialogID)
|
||||
var exists bool
|
||||
exists, err = s.chatSessionDAO.CheckDialogExists(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -251,11 +253,11 @@ func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*
|
||||
}
|
||||
|
||||
if !isOwner {
|
||||
return nil, errors.New("Only owner of dialog authorized for this operation")
|
||||
return nil, errors.New("only owner of dialog authorized for this operation")
|
||||
}
|
||||
|
||||
// List chat sessions
|
||||
sessions, err := s.chatSessionDAO.ListByDialogID(dialogID)
|
||||
sessions, err := s.chatSessionDAO.ListByChatID(chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -31,14 +31,20 @@ func NewSystemService() *SystemService {
|
||||
|
||||
// ConfigResponse system configuration response
|
||||
type ConfigResponse struct {
|
||||
RegisterEnabled int `json:"registerEnabled"`
|
||||
RegisterEnabled int `json:"registerEnabled"`
|
||||
DisablePasswordLogin bool `json:"disablePasswordLogin"`
|
||||
}
|
||||
|
||||
// GetConfig get system configuration
|
||||
func (s *SystemService) GetConfig() (*ConfigResponse, error) {
|
||||
cfg := server.GetConfig()
|
||||
registerEnabled := 1
|
||||
if !cfg.Authentication.RegisterEnabled {
|
||||
registerEnabled = 0
|
||||
}
|
||||
return &ConfigResponse{
|
||||
RegisterEnabled: cfg.RegisterEnabled,
|
||||
RegisterEnabled: registerEnabled,
|
||||
DisablePasswordLogin: cfg.Authentication.DisablePasswordLogin,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ import (
|
||||
"fmt"
|
||||
"hash"
|
||||
"os"
|
||||
"ragflow/internal/cache"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/server"
|
||||
@ -104,23 +105,23 @@ type UserResponse struct {
|
||||
// Register user registration
|
||||
func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) {
|
||||
cfg := server.GetConfig()
|
||||
if cfg.RegisterEnabled == 0 {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
|
||||
if !cfg.Authentication.RegisterEnabled {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("user registration is disabled")
|
||||
}
|
||||
|
||||
emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)
|
||||
if !emailRegex.MatchString(req.Email) {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("invalid email address: %s", req.Email)
|
||||
}
|
||||
|
||||
existUser, _ := s.userDAO.GetByEmail(req.Email)
|
||||
if existUser != nil {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("email: %s has already registered", req.Email)
|
||||
}
|
||||
|
||||
decryptedPassword, err := s.decryptPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password")
|
||||
return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
|
||||
}
|
||||
|
||||
var hashedPassword string
|
||||
@ -642,8 +643,10 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error)
|
||||
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
|
||||
func (s *UserService) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) {
|
||||
// Get secret key from config
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
secretKey, err := server.GetSecretKey(cache.Get())
|
||||
if err != nil {
|
||||
return nil, common.CodeUnauthorized, err
|
||||
}
|
||||
|
||||
// Extract access token from authorization header
|
||||
// Equivalent to: access_token = str(jwt.loads(authorization)) in Python
|
||||
|
||||
@ -40,13 +40,6 @@ class TestAuthorization:
|
||||
assert res["code"] == expected_code, res
|
||||
assert expected_fragment in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
|
||||
def test_auth_invalid_version(self, invalid_auth, expected_code, expected_fragment):
|
||||
res = system_version(invalid_auth)
|
||||
assert res["code"] == expected_code, res
|
||||
assert expected_fragment in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
|
||||
def test_auth_invalid_token_list(self, invalid_auth, expected_code, expected_fragment):
|
||||
|
||||
@ -92,6 +92,17 @@ export default defineConfig(({ mode }) => {
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/users/me/models': {
|
||||
target: 'http://127.0.0.1:9380/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'^(/api/v1/auth/login)|^(/api/v1/users/me)|^(/api/v1/system/config)|^(/api/v1/system/version)|^(/api/v1/tenants)|^(/api/v1/chats)|^(/api/v1/searches)|^(/api/v1/files)':
|
||||
{
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api': {
|
||||
target: 'http://127.0.0.1:9380/',
|
||||
changeOrigin: true,
|
||||
|
||||
Reference in New Issue
Block a user