mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-06 02:07:49 +08:00
Fix: Enhanced user management functionality and cascading data deletion. (#13594)
### What problem does this PR solve? Fix: Enhanced user management functionality and cascading data deletion. Added tenant and related data initialization functionality during user creation, including tenants, user-tenant relationships, LLM configuration, and root folder. Added cascading deletion logic for user deletion, ensuring that all associated data is cleaned up simultaneously when a user is deleted. Implemented a Werkzeug-compatible password hash algorithm (scrypt) and verification functionality. Added multiple DAO methods to support batch data operations and cascading deletion. Improved user login processing and added token signing functionality. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -26,40 +26,84 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
// CheckWerkzeugPassword verifies a password against a werkzeug password hash
|
||||
// Format: pbkdf2:sha256:iterations$salt$hash
|
||||
// Supports both pbkdf2 and scrypt formats
|
||||
func CheckWerkzeugPassword(password, hashStr string) bool {
|
||||
if strings.HasPrefix(hashStr, "scrypt:") {
|
||||
return checkScryptPassword(password, hashStr)
|
||||
}
|
||||
if strings.HasPrefix(hashStr, "pbkdf2:") {
|
||||
return checkPBKDF2Password(password, hashStr)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkScryptPassword verifies password using scrypt format
|
||||
// Format: scrypt:n:r:p$base64(salt)$hex(hash)
|
||||
// IMPORTANT: werkzeug uses the base64-encoded salt string as UTF-8 bytes, NOT the decoded bytes
|
||||
func checkScryptPassword(password, hashStr string) bool {
|
||||
parts := strings.Split(hashStr, "$")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
params := strings.Split(parts[0], ":")
|
||||
if len(params) != 4 || params[0] != "scrypt" {
|
||||
return false
|
||||
}
|
||||
|
||||
n, err := strconv.ParseUint(params[1], 10, 0)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
r, err := strconv.ParseUint(params[2], 10, 0)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
p, err := strconv.ParseUint(params[3], 10, 0)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
saltB64 := parts[1]
|
||||
hashHex := parts[2]
|
||||
|
||||
// IMPORTANT: werkzeug uses the base64 string as UTF-8 bytes, NOT decoded bytes
|
||||
// This is the key difference from standard implementations
|
||||
salt := []byte(saltB64)
|
||||
|
||||
// Decode hash from hex
|
||||
expectedHash, err := hex.DecodeString(hashHex)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
computed, err := scrypt.Key([]byte(password), salt, int(n), int(r), int(p), len(expectedHash))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return constantTimeCompare(expectedHash, computed)
|
||||
}
|
||||
|
||||
// checkPBKDF2Password verifies password using PBKDF2 format
|
||||
// Format: pbkdf2:sha256:iterations$base64(salt)$base64(hash)
|
||||
func checkPBKDF2Password(password, hashStr string) bool {
|
||||
parts := strings.Split(hashStr, "$")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse method (e.g., "pbkdf2:sha256:150000")
|
||||
methodParts := strings.Split(parts[0], ":")
|
||||
if len(methodParts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
if methodParts[0] != "pbkdf2" {
|
||||
return false
|
||||
}
|
||||
|
||||
var hashFunc func() hash.Hash
|
||||
switch methodParts[1] {
|
||||
case "sha256":
|
||||
hashFunc = sha256.New
|
||||
case "sha512":
|
||||
// sha512 not supported in this implementation
|
||||
return false
|
||||
default:
|
||||
if len(methodParts) != 3 || methodParts[0] != "pbkdf2" {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -71,48 +115,58 @@ func CheckWerkzeugPassword(password, hashStr string) bool {
|
||||
salt := parts[1]
|
||||
expectedHash := parts[2]
|
||||
|
||||
// Decode salt from base64
|
||||
saltBytes, err := base64.StdEncoding.DecodeString(salt)
|
||||
if err != nil {
|
||||
// Try hex encoding
|
||||
saltBytes, err = hex.DecodeString(salt)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Generate hash using PBKDF2
|
||||
key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc)
|
||||
key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, sha256.New)
|
||||
computedHash := base64.StdEncoding.EncodeToString(key)
|
||||
|
||||
return computedHash == expectedHash
|
||||
}
|
||||
|
||||
// IsWerkzeugHash checks if a hash is in werkzeug format
|
||||
func IsWerkzeugHash(hashStr string) bool {
|
||||
return strings.HasPrefix(hashStr, "pbkdf2:")
|
||||
// constantTimeCompare performs constant time comparison
|
||||
func constantTimeCompare(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
var result byte
|
||||
for i := 0; i < len(a); i++ {
|
||||
result |= a[i] ^ b[i]
|
||||
}
|
||||
return result == 0
|
||||
}
|
||||
|
||||
// GenerateWerkzeugPasswordHash generates a werkzeug-compatible password hash
|
||||
func GenerateWerkzeugPasswordHash(password string, iterations int) (string, error) {
|
||||
if iterations == 0 {
|
||||
iterations = 150000
|
||||
}
|
||||
// IsWerkzeugHash checks if a hash is in werkzeug format
|
||||
func IsWerkzeugHash(hashStr string) bool {
|
||||
return strings.HasPrefix(hashStr, "scrypt:") || strings.HasPrefix(hashStr, "pbkdf2:")
|
||||
}
|
||||
|
||||
// Generate random salt
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
// GenerateWerkzeugPasswordHash generates a werkzeug-compatible password hash using scrypt
|
||||
// This matches Python werkzeug's default behavior
|
||||
func GenerateWerkzeugPasswordHash(password string, iterations int) (string, error) {
|
||||
// Generate random bytes (12 bytes will produce 16-char base64 string)
|
||||
randomBytes := make([]byte, 12)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate hash using PBKDF2-SHA256
|
||||
key := pbkdf2.Key([]byte(password), salt, iterations, 32, sha256.New)
|
||||
// Encode to base64 string (this will be 16 characters)
|
||||
saltB64 := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
|
||||
// Format: pbkdf2:sha256:iterations$base64(salt)$base64(hash)
|
||||
saltB64 := base64.StdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.StdEncoding.EncodeToString(key)
|
||||
// Use scrypt with werkzeug default parameters: N=32768, r=8, p=1, keyLen=64
|
||||
// IMPORTANT: werkzeug uses the base64 string as UTF-8 bytes, NOT the decoded bytes
|
||||
hash, err := scrypt.Key([]byte(password), []byte(saltB64), 32768, 8, 1, 64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("pbkdf2:sha256:%d$%s$%s", iterations, saltB64, hashB64), nil
|
||||
// Format: scrypt:n:r:p$base64(salt)$hex(hash)
|
||||
return fmt.Sprintf("scrypt:32768:8:1$%s$%x", saltB64, hash), nil
|
||||
}
|
||||
|
||||
// DecryptPassword decrypts the password using RSA private key
|
||||
|
||||
Reference in New Issue
Block a user