mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-03-18 21:30:01 +08:00
### What problem does this PR solve?

1. Fix Go server date precision.
2. Use API_SCHEME_PROXY to control the web API route.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
929 lines
26 KiB
Go
929 lines
26 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package service
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"crypto/sha512"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"hash"
|
|
"os"
|
|
"ragflow/internal/common"
|
|
"ragflow/internal/server"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/pbkdf2"
|
|
"golang.org/x/crypto/scrypt"
|
|
|
|
"ragflow/internal/dao"
|
|
"ragflow/internal/model"
|
|
"ragflow/internal/utility"
|
|
)
|
|
|
|
// UserService user service
|
|
type UserService struct {
|
|
userDAO *dao.UserDAO
|
|
}
|
|
|
|
// NewUserService create user service
|
|
func NewUserService() *UserService {
|
|
return &UserService{
|
|
userDAO: dao.NewUserDAO(),
|
|
}
|
|
}
|
|
|
|
// RegisterRequest is the JSON payload consumed by UserService.Register.
type RegisterRequest struct {
	// Email is required and must pass the binding layer's email check.
	Email string `json:"email" binding:"required,email"`
	// Password is required with a minimum length of six at the binding layer.
	Password string `json:"password" binding:"required,min=6"`
	// Nickname is optional; Register also uses it to build the tenant name.
	Nickname string `json:"nickname"`
}
|
|
|
|
// LoginRequest is the JSON payload consumed by UserService.Login.
type LoginRequest struct {
	// Username is required; Login looks the account up by treating it as an email.
	Username string `json:"username" binding:"required"`
	// Password is required; Login passes it through decryptPassword before verifying.
	Password string `json:"password" binding:"required"`
}
|
|
|
|
// EmailLoginRequest is the JSON payload consumed by UserService.LoginByEmail.
type EmailLoginRequest struct {
	// Email is required and must pass the binding layer's email check.
	Email string `json:"email" binding:"required,email"`
	// Password is required; LoginByEmail passes it through decryptPassword before verifying.
	Password string `json:"password" binding:"required"`
}
|
|
|
|
// UpdateSettingsRequest is the JSON payload consumed by
// UserService.UpdateUserSettings. Every field is a pointer so that an absent
// key (nil) can be told apart from an explicit empty value.
type UpdateSettingsRequest struct {
	// Nickname replaces the user's nickname when set.
	Nickname *string `json:"nickname,omitempty"`
	// Email replaces the user's email when set; validated as an email if present.
	Email *string `json:"email,omitempty" binding:"omitempty,email"`
	// Avatar is the requested avatar value.
	Avatar *string `json:"avatar,omitempty"`
	// Language is the requested UI language.
	Language *string `json:"language,omitempty"`
	// ColorSchema is the requested color schema.
	ColorSchema *string `json:"color_schema,omitempty"`
	// Timezone is the requested timezone string.
	Timezone *string `json:"timezone,omitempty"`
}
|
|
|
|
// ChangePasswordRequest is the JSON payload consumed by
// UserService.ChangePassword. Both fields are optional pointers.
type ChangePasswordRequest struct {
	// Password is the current password, checked against the stored hash when set.
	Password *string `json:"password,omitempty"`
	// NewPassword is the replacement password, hashed and stored when set.
	NewPassword *string `json:"new_password,omitempty"`
}
|
|
|
|
// UserResponse is the reduced user view returned by GetUserByID and ListUsers.
type UserResponse struct {
	// ID is the user's identifier.
	ID string `json:"id"`
	// Email is the user's email address.
	Email string `json:"email"`
	// Nickname is the user's display name.
	Nickname string `json:"nickname"`
	// Status is copied from the user record; nil when the record has none.
	Status *string `json:"status"`
	// CreatedAt is the creation time formatted as "2006-01-02 15:04:05",
	// or empty when the record carries no creation timestamp.
	CreatedAt string `json:"created_at"`
}
|
|
|
|
// Register user registration
|
|
func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorCode, error) {
|
|
cfg := server.GetConfig()
|
|
if cfg.RegisterEnabled == 0 {
|
|
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
|
|
}
|
|
|
|
emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)
|
|
if !emailRegex.MatchString(req.Email) {
|
|
return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
|
|
}
|
|
|
|
existUser, _ := s.userDAO.GetByEmail(req.Email)
|
|
if existUser != nil {
|
|
return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
|
|
}
|
|
|
|
decryptedPassword, err := s.decryptPassword(req.Password)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password")
|
|
}
|
|
|
|
hashedPassword, err := s.HashPassword(decryptedPassword)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
|
|
userID := utility.GenerateToken()
|
|
accessToken := utility.GenerateToken()
|
|
status := "1"
|
|
loginChannel := "password"
|
|
isSuperuser := false
|
|
|
|
user := &model.User{
|
|
ID: userID,
|
|
AccessToken: &accessToken,
|
|
Email: req.Email,
|
|
Nickname: req.Nickname,
|
|
Password: &hashedPassword,
|
|
Status: &status,
|
|
IsActive: "1",
|
|
IsAuthenticated: "1",
|
|
IsAnonymous: "0",
|
|
LoginChannel: &loginChannel,
|
|
IsSuperuser: &isSuperuser,
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
user.CreateTime = &now
|
|
user.UpdateTime = &now
|
|
now_date := time.Now().Truncate(time.Second)
|
|
user.CreateDate = &now_date
|
|
user.UpdateDate = &now_date
|
|
user.LastLoginTime = &now_date
|
|
|
|
tenantName := req.Nickname + "'s Kingdom"
|
|
|
|
llmID := cfg.UserDefaultLLM.DefaultModels.ChatModel.Name
|
|
if llmID == "" {
|
|
llmID = ""
|
|
}
|
|
embdID := cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name
|
|
if embdID == "" {
|
|
embdID = ""
|
|
}
|
|
asrID := cfg.UserDefaultLLM.DefaultModels.ASRModel.Name
|
|
if asrID == "" {
|
|
asrID = ""
|
|
}
|
|
img2txtID := cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name
|
|
if img2txtID == "" {
|
|
img2txtID = ""
|
|
}
|
|
rerankID := cfg.UserDefaultLLM.DefaultModels.RerankModel.Name
|
|
if rerankID == "" {
|
|
rerankID = ""
|
|
}
|
|
|
|
tenant := &model.Tenant{
|
|
ID: userID,
|
|
Name: &tenantName,
|
|
LLMID: llmID,
|
|
EmbdID: embdID,
|
|
ASRID: asrID,
|
|
Img2TxtID: img2txtID,
|
|
RerankID: rerankID,
|
|
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
|
|
Status: &status,
|
|
}
|
|
tenant.CreateTime = &now
|
|
tenant.UpdateTime = &now
|
|
tenant.CreateDate = &now_date
|
|
tenant.UpdateDate = &now_date
|
|
|
|
userTenantID := utility.GenerateToken()
|
|
userTenant := &model.UserTenant{
|
|
ID: userTenantID,
|
|
UserID: userID,
|
|
TenantID: userID,
|
|
Role: "owner",
|
|
InvitedBy: userID,
|
|
Status: &status,
|
|
}
|
|
userTenant.CreateTime = &now
|
|
userTenant.UpdateTime = &now
|
|
userTenant.CreateDate = &now_date
|
|
userTenant.UpdateDate = &now_date
|
|
|
|
fileID := utility.GenerateToken()
|
|
rootFile := &model.File{
|
|
ID: fileID,
|
|
ParentID: fileID,
|
|
TenantID: userID,
|
|
CreatedBy: userID,
|
|
Name: "/",
|
|
Type: "folder",
|
|
Size: 0,
|
|
}
|
|
rootFile.CreateTime = &now
|
|
rootFile.UpdateTime = &now
|
|
rootFile.CreateDate = &now_date
|
|
rootFile.UpdateDate = &now_date
|
|
|
|
tenantDAO := dao.NewTenantDAO()
|
|
userTenantDAO := dao.NewUserTenantDAO()
|
|
fileDAO := dao.NewFileDAO()
|
|
|
|
if err := s.userDAO.Create(user); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
|
|
if err := tenantDAO.Create(tenant); err != nil {
|
|
err := s.userDAO.DeleteByID(userID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err)
|
|
}
|
|
|
|
if err := userTenantDAO.Create(userTenant); err != nil {
|
|
err := s.userDAO.DeleteByID(userID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
err = tenantDAO.Delete(userID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err)
|
|
}
|
|
|
|
if err := fileDAO.Create(rootFile); err != nil {
|
|
err := s.userDAO.DeleteByID(userID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
err = tenantDAO.Delete(userID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
err = userTenantDAO.Delete(userTenantID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err)
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// Login user login
|
|
func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, error) {
|
|
// Get user by email (using username field as email)
|
|
user, err := s.userDAO.GetByEmail(req.Username)
|
|
if err != nil {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("invalid email or password")
|
|
}
|
|
|
|
// Decrypt password using RSA
|
|
decryptedPassword, err := s.decryptPassword(req.Password)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to decrypt password: %w", err)
|
|
}
|
|
|
|
// Verify password
|
|
if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("invalid username or password")
|
|
}
|
|
|
|
if user.Status == nil || *user.Status != "1" {
|
|
return nil, common.CodeForbidden, fmt.Errorf("user is disabled")
|
|
}
|
|
|
|
// Generate new access token
|
|
token := utility.GenerateToken()
|
|
if err := s.UpdateUserAccessToken(user, token); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to update access token: %w", err)
|
|
}
|
|
|
|
// Update timestamp
|
|
now := time.Now().Unix()
|
|
user.UpdateTime = &now
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// LoginByEmail user login by email
|
|
// Returns user on success, or error with specific code:
|
|
// - CodeAuthenticationError (109): Email not registered or password mismatch
|
|
// - CodeServerError (500): Password decryption failure
|
|
// - CodeForbidden (403): Account disabled
|
|
func (s *UserService) LoginByEmail(req *EmailLoginRequest, adminLogin bool) (*model.User, common.ErrorCode, error) {
|
|
if !adminLogin && req.Email == "admin@ragflow.io" {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("default admin account cannot be used to login normal services")
|
|
}
|
|
|
|
user, err := s.userDAO.GetByEmail(req.Email)
|
|
if err != nil {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("Email: %s is not registered!", req.Email)
|
|
}
|
|
|
|
decryptedPassword, err := s.decryptPassword(req.Password)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("Fail to crypt password")
|
|
}
|
|
|
|
if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("Email and password do not match!")
|
|
}
|
|
|
|
if user.IsActive == "0" {
|
|
return nil, common.CodeForbidden, fmt.Errorf("This account has been disabled, please contact the administrator!")
|
|
}
|
|
|
|
// Generate new access token
|
|
token := utility.GenerateToken()
|
|
user.AccessToken = &token
|
|
|
|
now := time.Now().Unix()
|
|
user.UpdateTime = &now
|
|
now_date := time.Now().Truncate(time.Second)
|
|
user.UpdateDate = &now_date
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// GetUserByID get user by ID
|
|
func (s *UserService) GetUserByID(id uint) (*UserResponse, common.ErrorCode, error) {
|
|
user, err := s.userDAO.GetByID(id)
|
|
if err != nil {
|
|
return nil, common.CodeNotFound, err
|
|
}
|
|
|
|
return &UserResponse{
|
|
ID: user.ID,
|
|
Email: user.Email,
|
|
Nickname: user.Nickname,
|
|
Status: user.Status,
|
|
CreatedAt: func() string {
|
|
if user.CreateTime != nil {
|
|
return time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
|
|
}
|
|
return ""
|
|
}(),
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
// ListUsers list users
|
|
func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, common.ErrorCode, error) {
|
|
offset := (page - 1) * pageSize
|
|
users, total, err := s.userDAO.List(offset, pageSize)
|
|
if err != nil {
|
|
return nil, 0, common.CodeServerError, err
|
|
}
|
|
|
|
responses := make([]*UserResponse, len(users))
|
|
for i, user := range users {
|
|
responses[i] = &UserResponse{
|
|
ID: user.ID,
|
|
Email: user.Email,
|
|
Nickname: user.Nickname,
|
|
Status: user.Status,
|
|
CreatedAt: func() string {
|
|
if user.CreateTime != nil {
|
|
return time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
|
|
}
|
|
return ""
|
|
}(),
|
|
}
|
|
}
|
|
|
|
return responses, total, common.CodeSuccess, nil
|
|
}
|
|
|
|
// HashPassword generate password hash using scrypt (werkzeug compatible)
|
|
// The password should already be base64 encoded (from decrypt process)
|
|
// Werkzeug default format: scrypt:32768:8:1$base64(salt)$hex(hash)
|
|
// IMPORTANT: werkzeug uses the base64-encoded salt string as UTF-8 bytes, NOT the decoded bytes
|
|
func (s *UserService) HashPassword(password string) (string, error) {
|
|
// Generate random bytes (12 bytes will produce 16-char base64 string)
|
|
randomBytes, err := s.generateSalt()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to generate salt: %w", err)
|
|
}
|
|
|
|
// Encode to base64 string (this will be 16 characters)
|
|
saltB64 := base64.StdEncoding.EncodeToString(randomBytes)
|
|
|
|
// Use scrypt with werkzeug default parameters: N=32768, r=8, p=1, keyLen=64
|
|
// IMPORTANT: werkzeug uses the base64 string as UTF-8 bytes, NOT the decoded bytes
|
|
hash, err := scrypt.Key([]byte(password), []byte(saltB64), 32768, 8, 1, 64)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to compute scrypt hash: %w", err)
|
|
}
|
|
|
|
// Format: scrypt:n:r:p$base64(salt)$hex(hash)
|
|
return fmt.Sprintf("scrypt:32768:8:1$%s$%x", saltB64, hash), nil
|
|
}
|
|
|
|
// VerifyPassword verify password
|
|
// Supports both werkzeug pbkdf2 format (pbkdf2:sha256:iterations$salt$hash) and scrypt format
|
|
func (s *UserService) VerifyPassword(hashedPassword, password string) bool {
|
|
// Check if it's pbkdf2 format (werkzeug)
|
|
if strings.HasPrefix(hashedPassword, "pbkdf2:") {
|
|
return s.verifyPBKDF2Password(hashedPassword, password)
|
|
}
|
|
|
|
// Check if it's scrypt format
|
|
if strings.HasPrefix(hashedPassword, "scrypt:") {
|
|
return s.verifyScryptPassword(hashedPassword, password)
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// verifyPBKDF2Password verifies password using PBKDF2 (werkzeug format)
|
|
// Format: pbkdf2:sha256:iterations$salt$hash
|
|
func (s *UserService) verifyPBKDF2Password(hashedPassword, password string) bool {
|
|
parts := strings.Split(hashedPassword, "$")
|
|
if len(parts) != 3 {
|
|
return false
|
|
}
|
|
|
|
// Parse method (e.g., "pbkdf2:sha256:150000")
|
|
methodParts := strings.Split(parts[0], ":")
|
|
if len(methodParts) != 3 {
|
|
return false
|
|
}
|
|
|
|
if methodParts[0] != "pbkdf2" {
|
|
return false
|
|
}
|
|
|
|
var hashFunc func() hash.Hash
|
|
switch methodParts[1] {
|
|
case "sha256":
|
|
hashFunc = sha256.New
|
|
case "sha512":
|
|
hashFunc = sha512.New
|
|
default:
|
|
return false
|
|
}
|
|
|
|
iterations, err := strconv.Atoi(methodParts[2])
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
salt := parts[1]
|
|
expectedHash := parts[2]
|
|
|
|
// Decode salt from base64
|
|
saltBytes, err := base64.StdEncoding.DecodeString(salt)
|
|
if err != nil {
|
|
// Try hex encoding
|
|
saltBytes, err = hex.DecodeString(salt)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Generate hash using PBKDF2
|
|
key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc)
|
|
computedHash := base64.StdEncoding.EncodeToString(key)
|
|
|
|
return computedHash == expectedHash
|
|
}
|
|
|
|
// verifyScryptPassword verifies password using scrypt format
|
|
// Format: scrypt:n:r:p$base64(salt)$hex(hash)
|
|
// IMPORTANT: werkzeug uses the base64-encoded salt string as UTF-8 bytes, NOT the decoded bytes
|
|
func (s *UserService) verifyScryptPassword(hashedPassword, password string) bool {
|
|
// Parse hash format: scrypt:n:r:p$base64(salt)$hex(hash)
|
|
parts := strings.Split(hashedPassword, "$")
|
|
if len(parts) != 3 {
|
|
return false
|
|
}
|
|
|
|
params := strings.Split(parts[0], ":")
|
|
if len(params) != 4 || params[0] != "scrypt" {
|
|
return false
|
|
}
|
|
|
|
n, err := strconv.ParseUint(params[1], 10, 0)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
r, err := strconv.ParseUint(params[2], 10, 0)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
p, err := strconv.ParseUint(params[3], 10, 0)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
saltB64 := parts[1]
|
|
hashHex := parts[2]
|
|
|
|
// IMPORTANT: werkzeug uses the base64 string as UTF-8 bytes, NOT decoded bytes
|
|
// This is the key difference from standard implementations
|
|
salt := []byte(saltB64)
|
|
|
|
// Decode expected hash from hex
|
|
expectedHash, err := hex.DecodeString(hashHex)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Compute password hash
|
|
computed, err := scrypt.Key([]byte(password), salt, int(n), int(r), int(p), len(expectedHash))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Constant time comparison
|
|
return s.constantTimeCompare(expectedHash, computed)
|
|
}
|
|
|
|
// generateSalt generates a random 12-byte salt (werkzeug default)
|
|
func (s *UserService) generateSalt() ([]byte, error) {
|
|
salt := make([]byte, 12)
|
|
if _, err := rand.Read(salt); err != nil {
|
|
return nil, fmt.Errorf("failed to generate random salt: %w", err)
|
|
}
|
|
return salt, nil
|
|
}
|
|
|
|
// constantTimeCompare constant time comparison
|
|
func (s *UserService) constantTimeCompare(a, b []byte) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
|
|
var result byte
|
|
for i := 0; i < len(a); i++ {
|
|
result |= a[i] ^ b[i]
|
|
}
|
|
|
|
return result == 0
|
|
}
|
|
|
|
// loadPrivateKey loads and decrypts the RSA private key from conf/private.pem
|
|
// nolint:staticcheck // DecryptPEMBlock is deprecated but still works for traditional PEM encryption
|
|
func (s *UserService) loadPrivateKey() (*rsa.PrivateKey, error) {
|
|
// Read private key file
|
|
keyData, err := os.ReadFile("conf/private.pem")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read private key file: %w", err)
|
|
}
|
|
|
|
// Parse PEM block
|
|
block, _ := pem.Decode(keyData)
|
|
if block == nil {
|
|
return nil, errors.New("failed to decode PEM block")
|
|
}
|
|
|
|
// Decrypt the PEM block if it's encrypted
|
|
var privateKey interface{}
|
|
if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
|
|
// Decrypt using password "Welcome"
|
|
// Note: DecryptPEMBlock is deprecated but still functional for traditional PEM encryption
|
|
decryptedData, err := x509.DecryptPEMBlock(block, []byte("Welcome"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
|
|
}
|
|
|
|
// Parse the decrypted key
|
|
privateKey, err = x509.ParsePKCS1PrivateKey(decryptedData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
|
}
|
|
} else {
|
|
// Not encrypted, parse directly
|
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
|
}
|
|
}
|
|
|
|
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
|
|
if !ok {
|
|
return nil, errors.New("not an RSA private key")
|
|
}
|
|
|
|
return rsaPrivateKey, nil
|
|
}
|
|
|
|
// decryptPassword decrypts the password using RSA private key
|
|
func (s *UserService) decryptPassword(encryptedPassword string) (string, error) {
|
|
// Try to decode base64
|
|
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
|
if err != nil {
|
|
// If base64 decoding fails, assume it's already a plain password
|
|
return encryptedPassword, nil
|
|
}
|
|
|
|
// Load private key
|
|
privateKey, err := s.loadPrivateKey()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Decrypt using PKCS#1 v1.5
|
|
plaintext, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext)
|
|
if err != nil {
|
|
// If decryption fails, assume it's already a plain password
|
|
return encryptedPassword, nil
|
|
}
|
|
|
|
return string(plaintext), nil
|
|
}
|
|
|
|
// GetUserByToken gets user by authorization header
|
|
// The token parameter is the authorization header value, which needs to be decrypted
|
|
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
|
|
func (s *UserService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) {
|
|
// Get secret key from config
|
|
variables := server.GetVariables()
|
|
secretKey := variables.SecretKey
|
|
|
|
// Extract access token from authorization header
|
|
// Equivalent to: access_token = str(jwt.loads(authorization)) in Python
|
|
accessToken, err := utility.ExtractAccessToken(authorization, secretKey)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization token: %w", err)
|
|
}
|
|
|
|
// Validate token format (should be at least 32 chars, UUID format)
|
|
if len(accessToken) < 32 {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token format")
|
|
}
|
|
|
|
// Get user by access token
|
|
user, err := s.userDAO.GetByAccessToken(accessToken)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, err
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// UpdateUserAccessToken updates user's access token
|
|
func (s *UserService) UpdateUserAccessToken(user *model.User, token string) error {
|
|
return s.userDAO.UpdateAccessToken(user, token)
|
|
}
|
|
|
|
// Logout invalidates user's access token
|
|
func (s *UserService) Logout(user *model.User) (common.ErrorCode, error) {
|
|
// Invalidate token by setting it to an invalid value
|
|
// Similar to Python implementation: "INVALID_" + secrets.token_hex(16)
|
|
invalidToken := "INVALID_" + utility.GenerateToken()
|
|
err := s.UpdateUserAccessToken(user, invalidToken)
|
|
if err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// GetUserProfile returns user profile information
|
|
func (s *UserService) GetUserProfile(user *model.User) map[string]interface{} {
|
|
// Format create time and date (from database fields)
|
|
createTime := user.CreateTime
|
|
createDate := ""
|
|
if user.CreateDate != nil {
|
|
createDate = user.CreateDate.Format("2006-01-02T15:04:05")
|
|
}
|
|
|
|
// Format update time and date (from database fields)
|
|
var updateTime int64
|
|
updateDate := ""
|
|
if user.UpdateTime != nil {
|
|
updateTime = *user.UpdateTime
|
|
}
|
|
if user.UpdateDate != nil {
|
|
updateDate = user.UpdateDate.Format("2006-01-02T15:04:05")
|
|
}
|
|
|
|
// Format last login time
|
|
var lastLoginTime string
|
|
if user.LastLoginTime != nil {
|
|
lastLoginTime = user.LastLoginTime.Format("2006-01-02T15:04:05")
|
|
}
|
|
|
|
// Get access token
|
|
var accessToken string
|
|
if user.AccessToken != nil {
|
|
accessToken = *user.AccessToken
|
|
}
|
|
|
|
// Get avatar
|
|
var avatar interface{}
|
|
if user.Avatar != nil {
|
|
avatar = *user.Avatar
|
|
} else {
|
|
avatar = nil
|
|
}
|
|
|
|
// Get color schema
|
|
colorSchema := "Bright"
|
|
if user.ColorSchema != nil && *user.ColorSchema != "" {
|
|
colorSchema = *user.ColorSchema
|
|
}
|
|
|
|
// Get language
|
|
language := "English"
|
|
if user.Language != nil && *user.Language != "" {
|
|
language = *user.Language
|
|
}
|
|
|
|
// Get timezone
|
|
timezone := "UTC+8\tAsia/Shanghai"
|
|
if user.Timezone != nil && *user.Timezone != "" {
|
|
timezone = *user.Timezone
|
|
}
|
|
|
|
// Get login channel
|
|
loginChannel := "password"
|
|
if user.LoginChannel != nil && *user.LoginChannel != "" {
|
|
loginChannel = *user.LoginChannel
|
|
}
|
|
|
|
// Get password
|
|
var password string
|
|
if user.Password != nil {
|
|
password = *user.Password
|
|
}
|
|
|
|
// Get status
|
|
status := "1"
|
|
if user.Status != nil {
|
|
status = *user.Status
|
|
}
|
|
|
|
// Get is_superuser
|
|
isSuperuser := false
|
|
if user.IsSuperuser != nil {
|
|
isSuperuser = *user.IsSuperuser
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"access_token": accessToken,
|
|
"avatar": avatar,
|
|
"color_schema": colorSchema,
|
|
"create_date": createDate,
|
|
"create_time": createTime,
|
|
"email": user.Email,
|
|
"id": user.ID,
|
|
"is_active": user.IsActive,
|
|
"is_anonymous": user.IsAnonymous,
|
|
"is_authenticated": user.IsAuthenticated,
|
|
"is_superuser": isSuperuser,
|
|
"language": language,
|
|
"last_login_time": lastLoginTime,
|
|
"login_channel": loginChannel,
|
|
"nickname": user.Nickname,
|
|
"password": password,
|
|
"status": status,
|
|
"timezone": timezone,
|
|
"update_date": updateDate,
|
|
"update_time": updateTime,
|
|
}
|
|
}
|
|
|
|
// UpdateUserSettings updates user settings
|
|
func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) (common.ErrorCode, error) {
|
|
// Update fields if provided
|
|
if req.Nickname != nil {
|
|
user.Nickname = *req.Nickname
|
|
}
|
|
if req.Email != nil {
|
|
user.Email = *req.Email
|
|
}
|
|
if req.Avatar != nil {
|
|
// In Go version, avatar might be stored differently
|
|
// For now, just update if field exists
|
|
}
|
|
if req.Language != nil {
|
|
// Store language preference
|
|
}
|
|
if req.ColorSchema != nil {
|
|
// Store color schema preference
|
|
}
|
|
if req.Timezone != nil {
|
|
// Store timezone preference
|
|
}
|
|
|
|
// Save updated user
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// ChangePassword changes user password
|
|
func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) (common.ErrorCode, error) {
|
|
// If password is provided, verify current password
|
|
if req.Password != nil {
|
|
if user.Password == nil || !s.VerifyPassword(*user.Password, *req.Password) {
|
|
return common.CodeBadRequest, fmt.Errorf("current password is incorrect")
|
|
}
|
|
}
|
|
|
|
// If new password is provided, update password
|
|
if req.NewPassword != nil {
|
|
hashedPassword, err := s.HashPassword(*req.NewPassword)
|
|
if err != nil {
|
|
return common.CodeServerError, fmt.Errorf("failed to hash new password: %w", err)
|
|
}
|
|
user.Password = &hashedPassword
|
|
}
|
|
|
|
// Save updated user
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// LoginChannel describes one authentication channel returned by GetLoginChannels.
type LoginChannel struct {
	// Channel is the key of the channel in the OAuth configuration.
	Channel string `json:"channel"`
	// DisplayName is the configured label, or a title-cased Channel when none is set.
	DisplayName string `json:"display_name"`
	// Icon is the configured icon name, or "sso" when none is set.
	Icon string `json:"icon"`
}
|
|
|
|
// GetLoginChannels gets all supported authentication channels
|
|
func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, error) {
|
|
cfg := server.GetConfig()
|
|
channels := make([]*LoginChannel, 0)
|
|
|
|
for channel, oauthCfg := range cfg.OAuth {
|
|
displayName := oauthCfg.DisplayName
|
|
if displayName == "" {
|
|
displayName = strings.Title(channel)
|
|
}
|
|
|
|
icon := oauthCfg.Icon
|
|
if icon == "" {
|
|
icon = "sso"
|
|
}
|
|
|
|
channels = append(channels, &LoginChannel{
|
|
Channel: channel,
|
|
DisplayName: displayName,
|
|
Icon: icon,
|
|
})
|
|
}
|
|
|
|
return channels, common.CodeSuccess, nil
|
|
}
|
|
|
|
// SetTenantInfoRequest is the JSON payload consumed by
// UserService.SetTenantInfo. An empty model ID leaves that setting untouched.
type SetTenantInfoRequest struct {
	// TenantID selects the tenant to update.
	TenantID string `json:"tenant_id"`
	// ASRID is written to the tenant's asr_id column when non-empty.
	ASRID string `json:"asr_id"`
	// EmbdID is written to the tenant's embd_id column when non-empty.
	EmbdID string `json:"embd_id"`
	// Img2TxtID is written to the tenant's img2txt_id column when non-empty.
	Img2TxtID string `json:"img2txt_id"`
	// LLMID is written to the tenant's llm_id column when non-empty.
	LLMID string `json:"llm_id"`
	// RerankID is written to the tenant's rerank_id column when non-empty.
	RerankID string `json:"rerank_id"`
	// TTSID is written to the tenant's tts_id column when non-empty.
	TTSID string `json:"tts_id"`
}
|
|
|
|
// SetTenantInfo updates tenant model configuration
|
|
func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) error {
|
|
tenantDAO := dao.NewTenantDAO()
|
|
|
|
_, err := tenantDAO.GetByID(req.TenantID)
|
|
if err != nil {
|
|
return fmt.Errorf("tenant not found: %w", err)
|
|
}
|
|
|
|
updates := make(map[string]interface{})
|
|
if req.LLMID != "" {
|
|
updates["llm_id"] = req.LLMID
|
|
}
|
|
if req.EmbdID != "" {
|
|
updates["embd_id"] = req.EmbdID
|
|
}
|
|
if req.ASRID != "" {
|
|
updates["asr_id"] = req.ASRID
|
|
}
|
|
if req.Img2TxtID != "" {
|
|
updates["img2txt_id"] = req.Img2TxtID
|
|
}
|
|
if req.RerankID != "" {
|
|
updates["rerank_id"] = req.RerankID
|
|
}
|
|
if req.TTSID != "" {
|
|
updates["tts_id"] = req.TTSID
|
|
}
|
|
|
|
if len(updates) > 0 {
|
|
if err := tenantDAO.Update(req.TenantID, updates); err != nil {
|
|
return fmt.Errorf("failed to update tenant: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|