mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-04-22 19:57:47 +08:00
Feat:Using Go to implement user registration logic (#13431)
### What problem does this PR solve? Feat:Using Go to implement user registration logic ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
192
.agents/rules/named.md
Normal file
192
.agents/rules/named.md
Normal file
@ -0,0 +1,192 @@
|
||||
# Go Naming Best Practices
|
||||
|
||||
## 1. Package Naming
|
||||
|
||||
- **All lowercase, no underscores**: `package user`, not `package userService` or `package user_service`
|
||||
- **Short and meaningful**: `package http`, `package json`, `package dao`
|
||||
- **Avoid plurals**: `package user` not `package users`
|
||||
- **Avoid generic names**: Avoid `package util`, `package common`, `package base`
|
||||
|
||||
```go
|
||||
// Recommended
|
||||
package user
|
||||
package handler
|
||||
package service
|
||||
|
||||
// Not recommended
|
||||
package UserService
|
||||
package user_service
|
||||
package utils
|
||||
```
|
||||
|
||||
## 2. File Naming
|
||||
|
||||
- **All lowercase, underscore separated**: `user_handler.go`, `user_service.go`
|
||||
- **Test files**: `user_handler_test.go`
|
||||
- **Platform-specific**: `user_linux.go`, `user_windows.go`
|
||||
|
||||
```
|
||||
user/
|
||||
├── user_handler.go
|
||||
├── user_service.go
|
||||
├── user_dao.go
|
||||
└── user_test.go
|
||||
```
|
||||
|
||||
## 3. Directory Naming
|
||||
|
||||
- **All lowercase, no underscores or hyphens**: `internal/`, `pkg/`, `cmd/`
|
||||
- **Short and descriptive**: `handler/`, `service/`, `dao/`
|
||||
|
||||
```
|
||||
project/
|
||||
├── cmd/ # Main entry point
|
||||
│ └── server_main.go
|
||||
├── internal/ # Private code
|
||||
│ ├── handler/
|
||||
│ ├── service/
|
||||
│ ├── dao/
|
||||
│ ├── model/
|
||||
│ └── middleware/
|
||||
├── pkg/ # Public code
|
||||
└── api/ # API definitions
|
||||
```
|
||||
|
||||
## 4. Interface Naming
|
||||
|
||||
- **Single-method interfaces end with "-er"**: `Reader`, `Writer`, `Handler`
|
||||
- **Verb form**: `Reader`, `Executor`, `Validator`
|
||||
|
||||
```go
|
||||
// Recommended
|
||||
type Reader interface {
|
||||
Read(p []byte) (n int, err error)
|
||||
}
|
||||
|
||||
type UserService interface {
|
||||
Register(req *RegisterRequest) (*User, error)
|
||||
Login(req *LoginRequest) (*User, error)
|
||||
}
|
||||
|
||||
// Not recommended
|
||||
type UserInterface interface {}
|
||||
type IUserService interface {}
|
||||
```
|
||||
|
||||
## 5. Struct Naming
|
||||
|
||||
- **CamelCase**: `UserService`, `UserHandler`
|
||||
- **Avoid redundant prefixes**: `User` not `UserModel`
|
||||
|
||||
```go
|
||||
// Recommended
|
||||
type UserService struct {}
|
||||
type UserHandler struct {}
|
||||
type RegisterRequest struct {}
|
||||
|
||||
// Not recommended
|
||||
type user_service struct {}
|
||||
type SUserService struct {}
|
||||
type UserModel struct {}
|
||||
```
|
||||
|
||||
## 6. Method/Function Naming
|
||||
|
||||
- **CamelCase**
|
||||
- **Start with verb**: `GetUser`, `CreateUser`, `DeleteUser`
|
||||
- **Boolean returns use Is/Has/Can prefix**: `IsValid`, `HasPermission`
|
||||
|
||||
```go
|
||||
// Recommended
|
||||
func (s *UserService) Register(req *RegisterRequest) (*User, error)
|
||||
func (s *UserService) GetUserByID(id uint) (*User, error)
|
||||
func (s *UserService) IsEmailExists(email string) bool
|
||||
|
||||
// Not recommended
|
||||
func (s *UserService) register_user()
|
||||
func (s *UserService) get_user_by_id()
|
||||
func (s *UserService) CheckEmailExists() // Should use Is/Has
|
||||
```
|
||||
|
||||
## 7. Constant Naming
|
||||
|
||||
- **CamelCase**: `const MaxRetryCount = 3`
|
||||
- **Enum constants**: `const StatusActive = "active"`
|
||||
|
||||
```go
|
||||
// Recommended
|
||||
const (
|
||||
StatusActive = "1"
|
||||
StatusInactive = "0"
|
||||
MaxRetryCount = 3
|
||||
)
|
||||
|
||||
// Not recommended
|
||||
const (
|
||||
STATUS_ACTIVE = "1" // Not all uppercase
|
||||
status_active = "1" // Not all lowercase
|
||||
)
|
||||
```
|
||||
|
||||
## 8. Error Variable Naming
|
||||
|
||||
- **Start with "Err"**: `ErrNotFound`, `ErrInvalidInput`
|
||||
|
||||
```go
|
||||
// Recommended
|
||||
var (
|
||||
ErrNotFound = errors.New("not found")
|
||||
ErrInvalidInput = errors.New("invalid input")
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
)
|
||||
```
|
||||
|
||||
## 9. Acronyms Keep Consistent Case
|
||||
|
||||
```go
|
||||
// Recommended
|
||||
type HTTPHandler struct {}
|
||||
var URL string
|
||||
func GetHTTPClient() {}
|
||||
func ParseJSON() {}
|
||||
|
||||
// Not recommended
|
||||
type HttpHandler struct {}
|
||||
var Url string
|
||||
func GetHttpClient() {}
|
||||
```
|
||||
|
||||
## 10. Project Structure Naming
|
||||
|
||||
```
|
||||
project-name/
|
||||
├── cmd/ # Main programs
|
||||
│ └── app_name/
|
||||
│ └── main.go
|
||||
├── internal/ # Private code
|
||||
│ ├── handler/ # HTTP handlers
|
||||
│ ├── service/ # Business logic
|
||||
│ ├── repository/ # Data access
|
||||
│ ├── model/ # Data models
|
||||
│ └── config/ # Configuration
|
||||
├── pkg/ # Public code
|
||||
├── api/ # API definitions
|
||||
├── configs/ # Config files
|
||||
├── scripts/ # Scripts
|
||||
├── docs/ # Documentation
|
||||
├── go.mod
|
||||
└── go.sum
|
||||
```
|
||||
|
||||
## Summary Table
|
||||
|
||||
| Type | Rule | Example |
|
||||
| -------------- | ----------------------------------- | ------------------- |
|
||||
| Package | All lowercase, no underscores | `package user` |
|
||||
| File | All lowercase, underscore separated | `user_service.go` |
|
||||
| Directory | All lowercase, no separators | `internal/handler/` |
|
||||
| Struct | CamelCase, capitalized first letter | `UserService` |
|
||||
| Interface | CamelCase, -er suffix | `Reader`, `Writer` |
|
||||
| Method | CamelCase, verb prefix | `GetUserByID` |
|
||||
| Constant | CamelCase | `MaxRetryCount` |
|
||||
| Error Variable | Err prefix | `ErrNotFound` |
|
||||
6
.agents/skills/go-naming/SKILL.md
Normal file
6
.agents/skills/go-naming/SKILL.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
name: go-naming
|
||||
description: Go naming conventions and best practices. Use this skill when working with Go code and need to name packages, files, directories, structs, interfaces, functions, variables, or constants. Provides comprehensive naming guidelines following Go community standards.
|
||||
---
|
||||
|
||||
Strictly follow the naming conventions in [rules/named.md](rules/named.md)
|
||||
3
go.mod
3
go.mod
@ -61,7 +61,8 @@ require (
|
||||
golang.org/x/arch v0.6.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/term v0.40.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
google.golang.org/protobuf v1.32.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@ -156,6 +156,10 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
|
||||
|
||||
40
internal/common/error_code.go
Normal file
40
internal/common/error_code.go
Normal file
@ -0,0 +1,40 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package common
|
||||
|
||||
type ErrorCode int
|
||||
|
||||
const (
|
||||
CodeSuccess ErrorCode = 0
|
||||
CodeNotEffective ErrorCode = 10
|
||||
CodeExceptionError ErrorCode = 100
|
||||
CodeArgumentError ErrorCode = 101
|
||||
CodeDataError ErrorCode = 102
|
||||
CodeOperatingError ErrorCode = 103
|
||||
CodeTimeoutError ErrorCode = 104
|
||||
CodeConnectionError ErrorCode = 105
|
||||
CodeRunning ErrorCode = 106
|
||||
CodeResourceExhausted ErrorCode = 107
|
||||
CodePermissionError ErrorCode = 108
|
||||
CodeAuthenticationError ErrorCode = 109
|
||||
CodeBadRequest ErrorCode = 400
|
||||
CodeUnauthorized ErrorCode = 401
|
||||
CodeForbidden ErrorCode = 403
|
||||
CodeNotFound ErrorCode = 404
|
||||
CodeConflict ErrorCode = 409
|
||||
CodeServerError ErrorCode = 500
|
||||
)
|
||||
@ -18,6 +18,7 @@ package dao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/server"
|
||||
"time"
|
||||
|
||||
@ -77,9 +78,15 @@ func InitDB() error {
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
// Auto migrate
|
||||
//if err := DB.AutoMigrate(&model.User{}, &model.Document{}); err != nil {
|
||||
// return fmt.Errorf("failed to migrate database: %w", err)
|
||||
//}
|
||||
if err := DB.AutoMigrate(
|
||||
&model.User{},
|
||||
&model.Tenant{},
|
||||
&model.UserTenant{},
|
||||
&model.File{},
|
||||
&model.File2Document{},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Database connected and migrated successfully")
|
||||
return nil
|
||||
|
||||
@ -195,6 +195,11 @@ func (dao *FileDAO) GetAllParentFolders(startID string) ([]*model.File, error) {
|
||||
return parentFolders, nil
|
||||
}
|
||||
|
||||
// Create creates a new file
|
||||
func (dao *FileDAO) Create(file *model.File) error {
|
||||
return DB.Create(file).Error
|
||||
}
|
||||
|
||||
// generateUUID generates a UUID
|
||||
func generateUUID() string {
|
||||
id := uuid.New().String()
|
||||
|
||||
@ -88,3 +88,13 @@ func (dao *TenantDAO) GetByID(id string) (*model.Tenant, error) {
|
||||
}
|
||||
return &tenant, nil
|
||||
}
|
||||
|
||||
// Create creates a new tenant
|
||||
func (dao *TenantDAO) Create(tenant *model.Tenant) error {
|
||||
return DB.Create(tenant).Error
|
||||
}
|
||||
|
||||
// Delete deletes a tenant by ID (soft delete)
|
||||
func (dao *TenantDAO) Delete(id string) error {
|
||||
return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error
|
||||
}
|
||||
|
||||
@ -101,3 +101,8 @@ func (dao *UserDAO) List(offset, limit int) ([]*model.User, int64, error) {
|
||||
func (dao *UserDAO) Delete(id uint) error {
|
||||
return DB.Delete(&model.User{}, id).Error
|
||||
}
|
||||
|
||||
// DeleteByID delete user by string ID
|
||||
func (dao *UserDAO) DeleteByID(id string) error {
|
||||
return DB.Model(&model.User{}).Where("id = ?", id).Update("status", "0").Error
|
||||
}
|
||||
|
||||
@ -59,11 +59,11 @@ func (h *ChatHandler) ListChats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -112,11 +112,11 @@ func (h *ChatHandler) ListChatsNext(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -196,11 +196,11 @@ func (h *ChatHandler) SetDialog(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -268,11 +268,11 @@ func (h *ChatHandler) RemoveChats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -61,11 +61,11 @@ func (h *ChatSessionHandler) SetChatSession(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -124,11 +124,11 @@ func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -190,11 +190,11 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -270,11 +270,11 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -59,11 +59,11 @@ func (h *ChunkHandler) RetrievalTest(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -58,11 +58,11 @@ func (h *ConnectorHandler) ListConnectors(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -65,11 +65,11 @@ func (h *FileHandler) ListFiles(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -141,11 +141,11 @@ func (h *FileHandler) GetRootFolder(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -189,11 +189,11 @@ func (h *FileHandler) GetParentFolder(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token (for validation)
|
||||
_, err := h.userService.GetUserByToken(token)
|
||||
_, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -246,11 +246,11 @@ func (h *FileHandler) GetAllParentFolders(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token (for validation)
|
||||
_, err := h.userService.GetUserByToken(token)
|
||||
_, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -130,11 +130,11 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -71,10 +71,11 @@ func (h *LLMHandler) GetMyLLMs(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -127,10 +128,11 @@ func (h *LLMHandler) Factories(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
_, err := h.userService.GetUserByToken(token)
|
||||
_, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -207,11 +209,11 @@ func (h *LLMHandler) ListApp(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -65,11 +65,11 @@ func (h *SearchHandler) ListSearchApps(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -58,10 +58,11 @@ func (h *TenantHandler) TenantInfo(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -109,11 +110,11 @@ func (h *TenantHandler) TenantList(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -17,7 +17,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/utility"
|
||||
"strconv"
|
||||
@ -47,31 +49,51 @@ func NewUserHandler(userService *service.UserService) *UserHandler {
|
||||
// @Produce json
|
||||
// @Param request body service.RegisterRequest true "registration info"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/users/register [post]
|
||||
// @Router /v1/user/register [post]
|
||||
func (h *UserHandler) Register(c *gin.Context) {
|
||||
var req service.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.Register(&req)
|
||||
user, code, err := h.userService.Register(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "Failed to generate auth token",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Authorization", authToken)
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "*")
|
||||
c.Header("Access-Control-Allow-Headers", "*")
|
||||
c.Header("Access-Control-Expose-Headers", "Authorization")
|
||||
|
||||
profile := h.userService.GetUserProfile(user)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "registration successful",
|
||||
"data": gin.H{
|
||||
"id": user.ID,
|
||||
"nickname": user.Nickname,
|
||||
"email": user.Email,
|
||||
},
|
||||
"code": common.CodeSuccess,
|
||||
"message": fmt.Sprintf("%s, welcome aboard!", req.Nickname),
|
||||
"data": profile,
|
||||
})
|
||||
}
|
||||
|
||||
@ -87,18 +109,20 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
func (h *UserHandler) Login(c *gin.Context) {
|
||||
var req service.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.Login(&req)
|
||||
user, code, err := h.userService.Login(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -114,7 +138,7 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
c.Header("Access-Control-Expose-Headers", "Authorization")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "Welcome back!",
|
||||
"data": user,
|
||||
})
|
||||
@ -132,18 +156,20 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
func (h *UserHandler) LoginByEmail(c *gin.Context) {
|
||||
var req service.EmailLoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.LoginByEmail(&req)
|
||||
user, code, err := h.userService.LoginByEmail(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -151,21 +177,26 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) {
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
|
||||
|
||||
// Set Authorization header with access_token
|
||||
if user.AccessToken != nil {
|
||||
c.Header("Authorization", authToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "Failed to generate auth token",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Set CORS headers
|
||||
|
||||
c.Header("Authorization", authToken)
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "*")
|
||||
c.Header("Access-Control-Allow-Headers", "*")
|
||||
c.Header("Access-Control-Expose-Headers", "Authorization")
|
||||
|
||||
profile := h.userService.GetUserProfile(user)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "Welcome back!",
|
||||
"data": user,
|
||||
"data": profile,
|
||||
})
|
||||
}
|
||||
|
||||
@ -182,22 +213,28 @@ func (h *UserHandler) GetUserByID(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid user id",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": "invalid user id",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUserByID(uint(id))
|
||||
user, code, err := h.userService.GetUserByID(uint(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "user not found",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": user,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": user,
|
||||
})
|
||||
}
|
||||
|
||||
@ -222,15 +259,19 @@ func (h *UserHandler) ListUsers(c *gin.Context) {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
users, total, err := h.userService.ListUsers(page, pageSize)
|
||||
users, total, code, err := h.userService.ListUsers(page, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "failed to get users",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"items": users,
|
||||
"total": total,
|
||||
@ -253,34 +294,38 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Missing Authorization header",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Invalid access token",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Logout user
|
||||
if err := h.userService.Logout(user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
code, err = h.userService.Logout(user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"code": common.CodeSuccess,
|
||||
"data": true,
|
||||
"message": "success",
|
||||
})
|
||||
@ -299,19 +344,21 @@ func (h *UserHandler) Info(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Missing Authorization header",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"error": "Invalid access token",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -320,8 +367,9 @@ func (h *UserHandler) Info(c *gin.Context) {
|
||||
profile := h.userService.GetUserProfile(user)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": profile,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": profile,
|
||||
})
|
||||
}
|
||||
|
||||
@ -339,18 +387,21 @@ func (h *UserHandler) Setting(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Missing Authorization header",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid access token",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -358,22 +409,29 @@ func (h *UserHandler) Setting(c *gin.Context) {
|
||||
// Parse request
|
||||
var req service.UpdateSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update user settings
|
||||
if err := h.userService.UpdateUserSettings(user, &req); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
code, err = h.userService.UpdateUserSettings(user, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "settings updated successfully",
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
|
||||
@ -391,18 +449,21 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
// Extract token from request
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeUnauthorized,
|
||||
"message": "Missing Authorization header",
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by token
|
||||
user, err := h.userService.GetUserByToken(token)
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid access token",
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -410,22 +471,29 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
// Parse request
|
||||
var req service.ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Change password
|
||||
if err := h.userService.ChangePassword(user, &req); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
code, err = h.userService.ChangePassword(user, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "password changed successfully",
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
|
||||
@ -438,10 +506,10 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/user/login/channels [get]
|
||||
func (h *UserHandler) GetLoginChannels(c *gin.Context) {
|
||||
channels, err := h.userService.GetLoginChannels()
|
||||
channels, code, err := h.userService.GetLoginChannels()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": "Load channels failure, error: " + err.Error(),
|
||||
"data": []interface{}{},
|
||||
})
|
||||
@ -449,7 +517,7 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": channels,
|
||||
})
|
||||
|
||||
@ -86,6 +86,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
|
||||
// User login by email endpoint
|
||||
engine.POST("/v1/user/login", r.userHandler.LoginByEmail)
|
||||
engine.POST("/v1/user/register", r.userHandler.Register)
|
||||
// User login channels endpoint
|
||||
engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels)
|
||||
// User logout endpoint
|
||||
|
||||
@ -25,7 +25,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -52,9 +54,8 @@ func NewUserService() *UserService {
|
||||
|
||||
// RegisterRequest registration request
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=50"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Nickname string `json:"nickname"`
|
||||
}
|
||||
|
||||
@ -96,125 +97,220 @@ type UserResponse struct {
|
||||
}
|
||||
|
||||
// Register user registration
|
||||
func (s *UserService) Register(req *RegisterRequest) (*model.User, error) {
|
||||
// Check if email exists
|
||||
func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorCode, error) {
|
||||
cfg := server.GetConfig()
|
||||
if cfg.RegisterEnabled == 0 {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
|
||||
}
|
||||
|
||||
emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)
|
||||
if !emailRegex.MatchString(req.Email) {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
|
||||
}
|
||||
|
||||
existUser, _ := s.userDAO.GetByEmail(req.Email)
|
||||
if existUser != nil {
|
||||
return nil, errors.New("email already exists")
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
|
||||
}
|
||||
|
||||
// Generate password hash
|
||||
hashedPassword, err := s.HashPassword(req.Password)
|
||||
decryptedPassword, err := s.decryptPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password")
|
||||
}
|
||||
|
||||
// Create user
|
||||
status := "1"
|
||||
user := &model.User{
|
||||
Password: &hashedPassword,
|
||||
Email: req.Email,
|
||||
Nickname: req.Nickname,
|
||||
Status: &status,
|
||||
hashedPassword, err := s.HashPassword(decryptedPassword)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
userID := s.GenerateToken()
|
||||
accessToken := s.GenerateToken()
|
||||
status := "1"
|
||||
loginChannel := "password"
|
||||
isSuperuser := false
|
||||
|
||||
user := &model.User{
|
||||
ID: userID,
|
||||
AccessToken: &accessToken,
|
||||
Email: req.Email,
|
||||
Nickname: req.Nickname,
|
||||
Password: &hashedPassword,
|
||||
Status: &status,
|
||||
IsActive: "1",
|
||||
IsAuthenticated: "1",
|
||||
IsAnonymous: "0",
|
||||
LoginChannel: &loginChannel,
|
||||
IsSuperuser: &isSuperuser,
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
user.CreateTime = now
|
||||
user.UpdateTime = &now
|
||||
now_date := time.Now()
|
||||
user.CreateDate = &now_date
|
||||
user.UpdateDate = &now_date
|
||||
user.LastLoginTime = &now_date
|
||||
|
||||
tenantName := req.Nickname + "'s Kingdom"
|
||||
tenant := &model.Tenant{
|
||||
ID: userID,
|
||||
Name: &tenantName,
|
||||
LLMID: cfg.Server.Mode,
|
||||
EmbDID: cfg.Server.Mode,
|
||||
ASRID: cfg.Server.Mode,
|
||||
Img2TxtID: cfg.Server.Mode,
|
||||
RerankID: cfg.Server.Mode,
|
||||
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
|
||||
}
|
||||
tenant.CreateTime = now
|
||||
tenant.UpdateTime = &now
|
||||
tenant.CreateDate = &now_date
|
||||
tenant.UpdateDate = &now_date
|
||||
|
||||
userTenantID := s.GenerateToken()
|
||||
userTenant := &model.UserTenant{
|
||||
ID: userTenantID,
|
||||
UserID: userID,
|
||||
TenantID: userID,
|
||||
Role: "owner",
|
||||
InvitedBy: userID,
|
||||
Status: &status,
|
||||
}
|
||||
userTenant.CreateTime = now
|
||||
userTenant.UpdateTime = &now
|
||||
userTenant.CreateDate = &now_date
|
||||
userTenant.UpdateDate = &now_date
|
||||
|
||||
fileID := s.GenerateToken()
|
||||
rootFile := &model.File{
|
||||
ID: fileID,
|
||||
ParentID: fileID,
|
||||
TenantID: userID,
|
||||
CreatedBy: userID,
|
||||
Name: "/",
|
||||
Type: "folder",
|
||||
Size: 0,
|
||||
}
|
||||
rootFile.CreateTime = now
|
||||
rootFile.UpdateTime = &now
|
||||
rootFile.CreateDate = &now_date
|
||||
rootFile.UpdateDate = &now_date
|
||||
|
||||
tenantDAO := dao.NewTenantDAO()
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
fileDAO := dao.NewFileDAO()
|
||||
|
||||
if err := s.userDAO.Create(user); err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
if err := tenantDAO.Create(tenant); err != nil {
|
||||
s.userDAO.DeleteByID(userID)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err)
|
||||
}
|
||||
|
||||
if err := userTenantDAO.Create(userTenant); err != nil {
|
||||
s.userDAO.DeleteByID(userID)
|
||||
tenantDAO.Delete(userID)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err)
|
||||
}
|
||||
|
||||
if err := fileDAO.Create(rootFile); err != nil {
|
||||
s.userDAO.DeleteByID(userID)
|
||||
tenantDAO.Delete(userID)
|
||||
userTenantDAO.Delete(userTenantID)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err)
|
||||
}
|
||||
|
||||
return user, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// Login user login
|
||||
func (s *UserService) Login(req *LoginRequest) (*model.User, error) {
|
||||
func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, error) {
|
||||
// Get user by email (using username field as email)
|
||||
user, err := s.userDAO.GetByEmail(req.Username)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid email or password")
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("invalid email or password")
|
||||
}
|
||||
|
||||
// Decrypt password using RSA
|
||||
decryptedPassword, err := s.decryptPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) {
|
||||
return nil, errors.New("invalid username or password")
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("invalid username or password")
|
||||
}
|
||||
|
||||
// Check user status
|
||||
if user.Status == nil || *user.Status != "1" {
|
||||
return nil, errors.New("user is disabled")
|
||||
return nil, common.CodeForbidden, fmt.Errorf("user is disabled")
|
||||
}
|
||||
|
||||
// Generate new access token
|
||||
token := s.GenerateToken()
|
||||
if err := s.UpdateUserAccessToken(user, token); err != nil {
|
||||
return nil, fmt.Errorf("failed to update access token: %w", err)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to update access token: %w", err)
|
||||
}
|
||||
|
||||
// Update timestamp
|
||||
now := time.Now().Unix()
|
||||
user.UpdateTime = &now
|
||||
if err := s.userDAO.Update(user); err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
return user, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// LoginByEmail user login by email
|
||||
func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, error) {
|
||||
// Check for default admin account
|
||||
// Returns user on success, or error with specific code:
|
||||
// - CodeAuthenticationError (109): Email not registered or password mismatch
|
||||
// - CodeServerError (500): Password decryption failure
|
||||
// - CodeForbidden (403): Account disabled
|
||||
func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common.ErrorCode, error) {
|
||||
if req.Email == "admin@ragflow.io" {
|
||||
return nil, errors.New("default admin account cannot be used to login normal services")
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("default admin account cannot be used to login normal services")
|
||||
}
|
||||
|
||||
// Get user by email
|
||||
user, err := s.userDAO.GetByEmail(req.Email)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid email or password")
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("Email: %s is not registered!", req.Email)
|
||||
}
|
||||
|
||||
// Decrypt password using RSA
|
||||
decryptedPassword, err := s.decryptPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return nil, common.CodeServerError, fmt.Errorf("Fail to crypt password")
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) {
|
||||
return nil, errors.New("invalid email or password")
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("Email and password do not match!")
|
||||
}
|
||||
|
||||
// Check user status
|
||||
if user.Status == nil || *user.Status != "1" {
|
||||
return nil, errors.New("user is disabled")
|
||||
if user.IsActive == "0" {
|
||||
return nil, common.CodeForbidden, fmt.Errorf("This account has been disabled, please contact the administrator!")
|
||||
}
|
||||
|
||||
// Generate new access token
|
||||
token := s.GenerateToken()
|
||||
user.AccessToken = &token
|
||||
|
||||
// Update timestamp
|
||||
now := time.Now().Unix()
|
||||
user.UpdateTime = &now
|
||||
now_date := time.Now()
|
||||
user.UpdateDate = &now_date
|
||||
if err := s.userDAO.Update(user); err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
return user, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// GetUserByID get user by ID
|
||||
func (s *UserService) GetUserByID(id uint) (*UserResponse, error) {
|
||||
func (s *UserService) GetUserByID(id uint) (*UserResponse, common.ErrorCode, error) {
|
||||
user, err := s.userDAO.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, common.CodeNotFound, err
|
||||
}
|
||||
|
||||
return &UserResponse{
|
||||
@ -223,15 +319,15 @@ func (s *UserService) GetUserByID(id uint) (*UserResponse, error) {
|
||||
Nickname: user.Nickname,
|
||||
Status: user.Status,
|
||||
CreatedAt: time.Unix(user.CreateTime, 0).Format("2006-01-02 15:04:05"),
|
||||
}, nil
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// ListUsers list users
|
||||
func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, error) {
|
||||
func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, common.ErrorCode, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
users, total, err := s.userDAO.List(offset, pageSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, 0, common.CodeServerError, err
|
||||
}
|
||||
|
||||
responses := make([]*UserResponse, len(users))
|
||||
@ -245,7 +341,7 @@ func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, err
|
||||
}
|
||||
}
|
||||
|
||||
return responses, total, nil
|
||||
return responses, total, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// HashPassword generate password hash
|
||||
@ -399,7 +495,7 @@ func (s *UserService) GenerateToken() string {
|
||||
// GetUserByToken gets user by authorization header
|
||||
// The token parameter is the authorization header value, which needs to be decrypted
|
||||
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
|
||||
func (s *UserService) GetUserByToken(authorization string) (*model.User, error) {
|
||||
func (s *UserService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) {
|
||||
// Get secret key from config
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
@ -408,16 +504,21 @@ func (s *UserService) GetUserByToken(authorization string) (*model.User, error)
|
||||
// Equivalent to: access_token = str(jwt.loads(authorization)) in Python
|
||||
accessToken, err := utility.ExtractAccessToken(authorization, secretKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid authorization token: %w", err)
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization token: %w", err)
|
||||
}
|
||||
|
||||
// Validate token format (should be at least 32 chars, UUID format)
|
||||
if len(accessToken) < 32 {
|
||||
return nil, errors.New("invalid access token format")
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token format")
|
||||
}
|
||||
|
||||
// Get user by access token
|
||||
return s.userDAO.GetByAccessToken(accessToken)
|
||||
user, err := s.userDAO.GetByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
return nil, common.CodeUnauthorized, err
|
||||
}
|
||||
|
||||
return user, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// UpdateUserAccessToken updates user's access token
|
||||
@ -426,11 +527,15 @@ func (s *UserService) UpdateUserAccessToken(user *model.User, token string) erro
|
||||
}
|
||||
|
||||
// Logout invalidates user's access token
|
||||
func (s *UserService) Logout(user *model.User) error {
|
||||
func (s *UserService) Logout(user *model.User) (common.ErrorCode, error) {
|
||||
// Invalidate token by setting it to an invalid value
|
||||
// Similar to Python implementation: "INVALID_" + secrets.token_hex(16)
|
||||
invalidToken := "INVALID_" + s.GenerateToken()
|
||||
return s.UpdateUserAccessToken(user, invalidToken)
|
||||
err := s.UpdateUserAccessToken(user, invalidToken)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// GetUserProfile returns user profile information
|
||||
@ -539,7 +644,7 @@ func (s *UserService) GetUserProfile(user *model.User) map[string]interface{} {
|
||||
}
|
||||
|
||||
// UpdateUserSettings updates user settings
|
||||
func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) error {
|
||||
func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) (common.ErrorCode, error) {
|
||||
// Update fields if provided
|
||||
if req.Nickname != nil {
|
||||
user.Nickname = *req.Nickname
|
||||
@ -562,15 +667,18 @@ func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRe
|
||||
}
|
||||
|
||||
// Save updated user
|
||||
return s.userDAO.Update(user)
|
||||
if err := s.userDAO.Update(user); err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// ChangePassword changes user password
|
||||
func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) error {
|
||||
func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) (common.ErrorCode, error) {
|
||||
// If password is provided, verify current password
|
||||
if req.Password != nil {
|
||||
if user.Password == nil || !s.VerifyPassword(*user.Password, *req.Password) {
|
||||
return errors.New("current password is incorrect")
|
||||
return common.CodeBadRequest, fmt.Errorf("current password is incorrect")
|
||||
}
|
||||
}
|
||||
|
||||
@ -578,13 +686,16 @@ func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordReques
|
||||
if req.NewPassword != nil {
|
||||
hashedPassword, err := s.HashPassword(*req.NewPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash new password: %w", err)
|
||||
return common.CodeServerError, fmt.Errorf("failed to hash new password: %w", err)
|
||||
}
|
||||
user.Password = &hashedPassword
|
||||
}
|
||||
|
||||
// Save updated user
|
||||
return s.userDAO.Update(user)
|
||||
if err := s.userDAO.Update(user); err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// LoginChannel represents a login channel response
|
||||
@ -595,7 +706,7 @@ type LoginChannel struct {
|
||||
}
|
||||
|
||||
// GetLoginChannels gets all supported authentication channels
|
||||
func (s *UserService) GetLoginChannels() ([]*LoginChannel, error) {
|
||||
func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, error) {
|
||||
cfg := server.GetConfig()
|
||||
channels := make([]*LoginChannel, 0)
|
||||
|
||||
@ -617,5 +728,5 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, error) {
|
||||
})
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
return channels, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user