mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-03-27 01:09:57 +08:00
Update go cli (#13717)
### What problem does this PR solve? Go cli ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -10,8 +10,21 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create CLI instance
|
||||
cliApp, err := cli.NewCLI()
|
||||
// Parse command line arguments (skip program name)
|
||||
args, err := cli.ParseConnectionArgs(os.Args[1:])
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Show help and exit
|
||||
if args.ShowHelp {
|
||||
cli.PrintUsage()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Create CLI instance with parsed arguments
|
||||
cliApp, err := cli.NewCLIWithArgs(args)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create CLI: %v\n", err)
|
||||
os.Exit(1)
|
||||
@ -26,9 +39,18 @@ func main() {
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
// Run CLI
|
||||
if err := cliApp.Run(); err != nil {
|
||||
fmt.Printf("CLI error: %v\n", err)
|
||||
os.Exit(1)
|
||||
// Check if we have a single command to execute
|
||||
if args.Command != "" {
|
||||
// Single command mode
|
||||
if err = cliApp.RunSingleCommand(args.Command); err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
// Interactive mode
|
||||
if err = cliApp.Run(); err != nil {
|
||||
fmt.Printf("CLI error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ mysql:
|
||||
user: 'root'
|
||||
password: 'infini_rag_flow'
|
||||
host: 'localhost'
|
||||
port: 5455
|
||||
port: 3306
|
||||
max_connections: 900
|
||||
stale_timeout: 300
|
||||
max_allowed_packet: 1073741824
|
||||
|
||||
5
go.mod
5
go.mod
@ -1,6 +1,6 @@
|
||||
module ragflow
|
||||
|
||||
go 1.25
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
@ -97,7 +97,8 @@ require (
|
||||
golang.org/x/arch v0.6.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect
|
||||
golang.org/x/net v0.49.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/term v0.41.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
google.golang.org/protobuf v1.32.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@ -228,6 +228,10 @@ golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
|
||||
|
||||
@ -109,7 +109,7 @@ func (h *Handler) Health(c *gin.Context) {
|
||||
|
||||
// Ping ping endpoint
|
||||
func (h *Handler) Ping(c *gin.Context) {
|
||||
successNoData(c, "PONG")
|
||||
successNoData(c, "pong")
|
||||
}
|
||||
|
||||
// Login handle admin login
|
||||
@ -420,15 +420,15 @@ func (h *Handler) GetUserAgents(c *gin.Context) {
|
||||
success(c, agents, "")
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handle get user API keys
|
||||
func (h *Handler) GetUserAPIKeys(c *gin.Context) {
|
||||
// ListUserAPITokens handle get user API keys
|
||||
func (h *Handler) ListUserAPITokens(c *gin.Context) {
|
||||
username := c.Param("username")
|
||||
if username == "" {
|
||||
errorResponse(c, "Username is required", 400)
|
||||
return
|
||||
}
|
||||
|
||||
apiKeys, err := h.service.GetUserAPIKeys(username)
|
||||
apiKeys, err := h.service.ListUserAPITokens(username)
|
||||
if err != nil {
|
||||
errorResponse(c, err.Error(), 500)
|
||||
return
|
||||
@ -437,15 +437,15 @@ func (h *Handler) GetUserAPIKeys(c *gin.Context) {
|
||||
success(c, apiKeys, "Get user API keys")
|
||||
}
|
||||
|
||||
// GenerateUserAPIKey handle generate user API key
|
||||
func (h *Handler) GenerateUserAPIKey(c *gin.Context) {
|
||||
// GenerateUserAPIToken handle generate user API key
|
||||
func (h *Handler) GenerateUserAPIToken(c *gin.Context) {
|
||||
username := c.Param("username")
|
||||
if username == "" {
|
||||
errorResponse(c, "Username is required", 400)
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, err := h.service.GenerateUserAPIKey(username)
|
||||
apiKey, err := h.service.GenerateUserAPIToken(username)
|
||||
if err != nil {
|
||||
errorResponse(c, err.Error(), 500)
|
||||
return
|
||||
@ -454,16 +454,16 @@ func (h *Handler) GenerateUserAPIKey(c *gin.Context) {
|
||||
success(c, apiKey, "API key generated successfully")
|
||||
}
|
||||
|
||||
// DeleteUserAPIKey handle delete user API key
|
||||
func (h *Handler) DeleteUserAPIKey(c *gin.Context) {
|
||||
// DeleteUserAPIToken handle delete user API key
|
||||
func (h *Handler) DeleteUserAPIToken(c *gin.Context) {
|
||||
username := c.Param("username")
|
||||
key := c.Param("key")
|
||||
key := c.Param("token")
|
||||
if username == "" || key == "" {
|
||||
errorResponse(c, "Username and key are required", 400)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.DeleteUserAPIKey(username, key); err != nil {
|
||||
if err := h.service.DeleteUserAPIToken(username, key); err != nil {
|
||||
errorResponse(c, err.Error(), 404)
|
||||
return
|
||||
}
|
||||
|
||||
@ -67,9 +67,12 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
protected.GET("/users/:username/agents", r.handler.GetUserAgents)
|
||||
|
||||
// API Keys
|
||||
protected.GET("/users/:username/keys", r.handler.GetUserAPIKeys)
|
||||
protected.POST("/users/:username/keys", r.handler.GenerateUserAPIKey)
|
||||
protected.DELETE("/users/:username/keys/:key", r.handler.DeleteUserAPIKey)
|
||||
protected.GET("/users/:username/keys", r.handler.ListUserAPITokens)
|
||||
protected.GET("/users/:username/tokens", r.handler.ListUserAPITokens)
|
||||
protected.POST("/users/:username/keys", r.handler.GenerateUserAPIToken)
|
||||
protected.POST("/users/:username/tokens", r.handler.GenerateUserAPIToken)
|
||||
protected.DELETE("/users/:username/keys/:token", r.handler.DeleteUserAPIToken)
|
||||
protected.DELETE("/users/:username/tokens/:token", r.handler.DeleteUserAPIToken)
|
||||
|
||||
// Role management
|
||||
protected.GET("/roles", r.handler.ListRoles)
|
||||
|
||||
@ -676,7 +676,7 @@ func (s *Service) DeleteUser(username string) (*DeleteUserResult, error) {
|
||||
var userTenantCount int64
|
||||
tx.Model(&model.UserTenant{}).Where("user_id = ?", user.ID).Count(&userTenantCount)
|
||||
result.UserTenantCount = int(userTenantCount)
|
||||
|
||||
|
||||
// 15. Delete user-tenant relations
|
||||
if delErr := tx.Unscoped().Where("user_id = ?", user.ID).Delete(&model.UserTenant{}); delErr.Error != nil {
|
||||
logger.Warn("failed to delete user-tenant relations", zap.Error(delErr.Error))
|
||||
@ -868,20 +868,20 @@ func (s *Service) GetUserAgents(username string) ([]map[string]interface{}, erro
|
||||
|
||||
// API Key methods
|
||||
|
||||
// GetUserAPIKeys get user API keys
|
||||
func (s *Service) GetUserAPIKeys(username string) ([]map[string]interface{}, error) {
|
||||
// ListUserAPITokens get user API keys
|
||||
func (s *Service) ListUserAPITokens(username string) ([]map[string]interface{}, error) {
|
||||
// TODO: Implement get API keys
|
||||
return []map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
// GenerateUserAPIKey generate API key for user
|
||||
func (s *Service) GenerateUserAPIKey(username string) (map[string]interface{}, error) {
|
||||
// GenerateUserAPIToken generate API key for user
|
||||
func (s *Service) GenerateUserAPIToken(username string) (map[string]interface{}, error) {
|
||||
// TODO: Implement generate API key
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
// DeleteUserAPIKey delete user API key
|
||||
func (s *Service) DeleteUserAPIKey(username, key string) error {
|
||||
// DeleteUserAPIToken delete user API key
|
||||
func (s *Service) DeleteUserAPIToken(username, key string) error {
|
||||
// TODO: Implement delete API key
|
||||
return nil
|
||||
}
|
||||
|
||||
1101
internal/cli/admin_command.go
Normal file
1101
internal/cli/admin_command.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -34,7 +34,7 @@ type BenchmarkResult struct {
|
||||
}
|
||||
|
||||
// RunBenchmark runs a benchmark with the given concurrency and iterations
|
||||
func (c *RAGFlowClient) RunBenchmark(cmd *Command) error {
|
||||
func (c *RAGFlowClient) RunBenchmark(cmd *Command) (ResponseIf, error) {
|
||||
concurrency, ok := cmd.Params["concurrency"].(int)
|
||||
if !ok {
|
||||
concurrency = 1
|
||||
@ -47,29 +47,26 @@ func (c *RAGFlowClient) RunBenchmark(cmd *Command) error {
|
||||
|
||||
nestedCmd, ok := cmd.Params["command"].(*Command)
|
||||
if !ok {
|
||||
return fmt.Errorf("benchmark command not found")
|
||||
return nil, fmt.Errorf("benchmark command not found")
|
||||
}
|
||||
|
||||
if concurrency < 1 {
|
||||
return fmt.Errorf("concurrency must be greater than 0")
|
||||
return nil, fmt.Errorf("concurrency must be greater than 0")
|
||||
}
|
||||
|
||||
// Add iterations to the nested command
|
||||
nestedCmd.Params["iterations"] = iterations
|
||||
|
||||
if concurrency == 1 {
|
||||
return c.runBenchmarkSingle(concurrency, iterations, nestedCmd)
|
||||
return c.runBenchmarkSingle(iterations, nestedCmd)
|
||||
}
|
||||
return c.runBenchmarkConcurrent(concurrency, iterations, nestedCmd)
|
||||
}
|
||||
|
||||
// runBenchmarkSingle runs benchmark with single concurrency (sequential execution)
|
||||
func (c *RAGFlowClient) runBenchmarkSingle(concurrency, iterations int, nestedCmd *Command) error {
|
||||
func (c *RAGFlowClient) runBenchmarkSingle(iterations int, nestedCmd *Command) (*BenchmarkResponse, error) {
|
||||
commandType := nestedCmd.Type
|
||||
|
||||
startTime := time.Now()
|
||||
responseList := make([]*Response, 0, iterations)
|
||||
|
||||
// For search_on_datasets, convert dataset names to IDs first
|
||||
if commandType == "search_on_datasets" && iterations > 1 {
|
||||
datasets, _ := nestedCmd.Params["datasets"].(string)
|
||||
@ -79,7 +76,7 @@ func (c *RAGFlowClient) runBenchmarkSingle(concurrency, iterations int, nestedCm
|
||||
name = strings.TrimSpace(name)
|
||||
id, err := c.getDatasetID(name)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
datasetIDs = append(datasetIDs, id)
|
||||
}
|
||||
@ -87,86 +84,67 @@ func (c *RAGFlowClient) runBenchmarkSingle(concurrency, iterations int, nestedCm
|
||||
}
|
||||
|
||||
// Check if command supports native benchmark (iterations > 1)
|
||||
supportsNative := false
|
||||
if iterations > 1 {
|
||||
result, err := c.ExecuteCommand(nestedCmd)
|
||||
if err == nil && result != nil {
|
||||
// Command supports benchmark natively
|
||||
supportsNative = true
|
||||
duration, _ := result["duration"].(float64)
|
||||
respList, _ := result["response_list"].([]*Response)
|
||||
responseList = respList
|
||||
// convert result to BenchmarkResponse
|
||||
benchmarkResponse := result.(*BenchmarkResponse)
|
||||
benchmarkResponse.Concurrency = 1
|
||||
return benchmarkResponse, err
|
||||
}
|
||||
|
||||
// Calculate and print results
|
||||
successCount := 0
|
||||
for _, resp := range responseList {
|
||||
if isSuccess(resp, commandType) {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
result, err := c.ExecuteCommand(nestedCmd)
|
||||
if err != nil {
|
||||
fmt.Printf("fail to execute: %s", commandType)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qps := float64(0)
|
||||
if duration > 0 {
|
||||
qps = float64(iterations) / duration
|
||||
}
|
||||
|
||||
fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations)
|
||||
fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n",
|
||||
duration, qps, iterations, successCount, iterations-successCount)
|
||||
return nil
|
||||
var benchmarkResponse BenchmarkResponse
|
||||
switch result.Type() {
|
||||
case "common":
|
||||
commonResponse := result.(*CommonResponse)
|
||||
benchmarkResponse.Code = commonResponse.Code
|
||||
benchmarkResponse.Duration = commonResponse.Duration
|
||||
if commonResponse.Code == 0 {
|
||||
benchmarkResponse.SuccessCount = 1
|
||||
} else {
|
||||
benchmarkResponse.FailureCount = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Manual execution: run iterations times
|
||||
if !supportsNative {
|
||||
// Remove iterations param to avoid native benchmark
|
||||
delete(nestedCmd.Params, "iterations")
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
singleResult, err := c.ExecuteCommand(nestedCmd)
|
||||
if err != nil {
|
||||
// Command failed, add a failed response
|
||||
responseList = append(responseList, &Response{StatusCode: 0})
|
||||
continue
|
||||
}
|
||||
|
||||
// For commands that return a single response (like ping with iterations=1)
|
||||
if singleResult != nil {
|
||||
if respList, ok := singleResult["response_list"].([]*Response); ok {
|
||||
responseList = append(responseList, respList...)
|
||||
}
|
||||
} else {
|
||||
// Command executed successfully but returned no data
|
||||
// Mark as success for now
|
||||
responseList = append(responseList, &Response{StatusCode: 200, Body: []byte("pong")})
|
||||
}
|
||||
case "simple":
|
||||
simpleResponse := result.(*SimpleResponse)
|
||||
benchmarkResponse.Code = simpleResponse.Code
|
||||
benchmarkResponse.Duration = simpleResponse.Duration
|
||||
if simpleResponse.Code == 0 {
|
||||
benchmarkResponse.SuccessCount = 1
|
||||
} else {
|
||||
benchmarkResponse.FailureCount = 1
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime).Seconds()
|
||||
|
||||
successCount := 0
|
||||
for _, resp := range responseList {
|
||||
if isSuccess(resp, commandType) {
|
||||
successCount++
|
||||
case "show":
|
||||
dataResponse := result.(*CommonDataResponse)
|
||||
benchmarkResponse.Code = dataResponse.Code
|
||||
benchmarkResponse.Duration = dataResponse.Duration
|
||||
if dataResponse.Code == 0 {
|
||||
benchmarkResponse.SuccessCount = 1
|
||||
} else {
|
||||
benchmarkResponse.FailureCount = 1
|
||||
}
|
||||
case "data":
|
||||
kvResponse := result.(*KeyValueResponse)
|
||||
benchmarkResponse.Code = kvResponse.Code
|
||||
benchmarkResponse.Duration = kvResponse.Duration
|
||||
if kvResponse.Code == 0 {
|
||||
benchmarkResponse.SuccessCount = 1
|
||||
} else {
|
||||
benchmarkResponse.FailureCount = 1
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported command type: %s", result.Type())
|
||||
}
|
||||
|
||||
qps := float64(0)
|
||||
if duration > 0 {
|
||||
qps = float64(iterations) / duration
|
||||
}
|
||||
|
||||
// Print results
|
||||
fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations)
|
||||
fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n",
|
||||
duration, qps, iterations, successCount, iterations-successCount)
|
||||
|
||||
return nil
|
||||
benchmarkResponse.Concurrency = 1
|
||||
return &benchmarkResponse, nil
|
||||
}
|
||||
|
||||
// runBenchmarkConcurrent runs benchmark with multiple concurrent workers
|
||||
func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nestedCmd *Command) error {
|
||||
func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nestedCmd *Command) (*BenchmarkResponse, error) {
|
||||
results := make([]map[string]interface{}, concurrency)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@ -179,7 +157,7 @@ func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nest
|
||||
name = strings.TrimSpace(name)
|
||||
id, err := c.getDatasetID(name)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
datasetIDs = append(datasetIDs, id)
|
||||
}
|
||||
@ -228,17 +206,15 @@ func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nest
|
||||
}
|
||||
|
||||
totalCommands := iterations * concurrency
|
||||
qps := float64(0)
|
||||
if totalDuration > 0 {
|
||||
qps = float64(totalCommands) / totalDuration
|
||||
}
|
||||
|
||||
// Print results
|
||||
fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations)
|
||||
fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n",
|
||||
totalDuration, qps, totalCommands, successCount, totalCommands-successCount)
|
||||
var benchmarkResponse BenchmarkResponse
|
||||
benchmarkResponse.Duration = totalDuration
|
||||
benchmarkResponse.Code = 0
|
||||
benchmarkResponse.SuccessCount = successCount
|
||||
benchmarkResponse.FailureCount = totalCommands - successCount
|
||||
benchmarkResponse.Concurrency = concurrency
|
||||
|
||||
return nil
|
||||
return &benchmarkResponse, nil
|
||||
}
|
||||
|
||||
// executeBenchmarkSilent executes a command for benchmark without printing output
|
||||
@ -250,7 +226,7 @@ func (c *RAGFlowClient) executeBenchmarkSilent(cmd *Command, iterations int) []*
|
||||
var err error
|
||||
|
||||
switch cmd.Type {
|
||||
case "ping_server":
|
||||
case "ping":
|
||||
resp, err = c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil)
|
||||
case "list_user_datasets":
|
||||
resp, err = c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil)
|
||||
@ -290,7 +266,7 @@ func isSuccess(resp *Response, commandType string) bool {
|
||||
}
|
||||
|
||||
switch commandType {
|
||||
case "ping_server":
|
||||
case "ping":
|
||||
return resp.StatusCode == 200 && string(resp.Body) == "pong"
|
||||
case "list_user_datasets", "list_datasets", "search_on_datasets":
|
||||
// Check status code and JSON response code for dataset commands
|
||||
|
||||
@ -17,15 +17,285 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/peterh/liner"
|
||||
"golang.org/x/term"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigFile represents the rf.yml configuration file structure
|
||||
type ConfigFile struct {
|
||||
Host string `yaml:"host"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
APIToken string `yaml:"api_token"`
|
||||
}
|
||||
|
||||
// ConnectionArgs holds the parsed command line arguments
|
||||
type ConnectionArgs struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
Key string
|
||||
Type string
|
||||
Username string
|
||||
Command string
|
||||
ShowHelp bool
|
||||
}
|
||||
|
||||
// LoadDefaultConfigFile reads the rf.yml file from current directory if it exists
|
||||
func LoadDefaultConfigFile() (*ConfigFile, error) {
|
||||
// Try to read rf.yml from current directory
|
||||
data, err := os.ReadFile("rf.yml")
|
||||
if err != nil {
|
||||
// File doesn't exist, return nil without error
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config ConfigFile
|
||||
if err = yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse rf.yml: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// LoadConfigFileFromPath reads a config file from the specified path
|
||||
func LoadConfigFileFromPath(path string) (*ConfigFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %v", path, err)
|
||||
}
|
||||
|
||||
var config ConfigFile
|
||||
if err = yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file %s: %v", path, err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// parseHostPort parses a host:port string and returns host and port
|
||||
func parseHostPort(hostPort string) (string, int, error) {
|
||||
if hostPort == "" {
|
||||
return "", -1, nil
|
||||
}
|
||||
|
||||
// Split host and port
|
||||
parts := strings.Split(hostPort, ":")
|
||||
if len(parts) != 2 {
|
||||
return "", -1, fmt.Errorf("invalid host format, expected host:port, got: %s", hostPort)
|
||||
}
|
||||
|
||||
host := parts[0]
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return "", -1, fmt.Errorf("invalid port number: %s", parts[1])
|
||||
}
|
||||
|
||||
return host, port, nil
|
||||
}
|
||||
|
||||
// ParseConnectionArgs parses command line arguments similar to Python's parse_connection_args
|
||||
func ParseConnectionArgs(args []string) (*ConnectionArgs, error) {
|
||||
// First, scan args to check for help and config file
|
||||
var configFilePath string
|
||||
var hasOtherFlags bool
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
if arg == "--help" {
|
||||
return &ConnectionArgs{ShowHelp: true}, nil
|
||||
} else if arg == "-f" && i+1 < len(args) {
|
||||
configFilePath = args[i+1]
|
||||
i++
|
||||
} else if strings.HasPrefix(arg, "-") && arg != "-f" {
|
||||
// Check if it's a flag (not a command)
|
||||
if arg == "-h" || arg == "-p" || arg == "-w" || arg == "-k" ||
|
||||
arg == "-u" || arg == "-admin" || arg == "-user" {
|
||||
hasOtherFlags = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load config file with priority: -f > rf.yml > none
|
||||
var config *ConfigFile
|
||||
var err error
|
||||
|
||||
if configFilePath != "" {
|
||||
// User specified config file via -f
|
||||
config, err = LoadConfigFileFromPath(configFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Try default rf.yml
|
||||
config, err = LoadDefaultConfigFile()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
if hasOtherFlags {
|
||||
return nil, fmt.Errorf("cannot use command line flags (-h, -p, -w, -k, -u, -admin, -user) when using config file. Please use config file or command line flags, not both")
|
||||
}
|
||||
|
||||
return buildArgsFromConfig(config, args)
|
||||
}
|
||||
// Create a new flag set
|
||||
fs := flag.NewFlagSet("ragflow_cli", flag.ContinueOnError)
|
||||
|
||||
// Define flags
|
||||
host := fs.String("h", "127.0.0.1", "Admin or RAGFlow service host")
|
||||
port := fs.Int("p", -1, "Admin or RAGFlow service port (default: 9381 for admin, 9380 for user)")
|
||||
password := fs.String("w", "", "Superuser password")
|
||||
key := fs.String("k", "", "API key for authentication")
|
||||
_ = fs.String("f", "", "Path to config file (YAML format)") // Already parsed above
|
||||
_ = fs.Bool("admin", false, "Run in admin mode (default)")
|
||||
userMode := fs.Bool("user", false, "Run in user mode")
|
||||
username := fs.String("u", "", "Username (email). In admin mode defaults to admin@ragflow.io, in user mode required")
|
||||
|
||||
// Parse the arguments
|
||||
if err = fs.Parse(args); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse arguments: %v", err)
|
||||
}
|
||||
|
||||
// Otherwise, use command line flags
|
||||
return buildArgsFromFlags(host, port, password, key, userMode, username, fs.Args())
|
||||
}
|
||||
|
||||
// buildArgsFromConfig builds ConnectionArgs from config file
|
||||
func buildArgsFromConfig(config *ConfigFile, remainingArgs []string) (*ConnectionArgs, error) {
|
||||
result := &ConnectionArgs{}
|
||||
|
||||
// Parse host:port from config file
|
||||
if config.Host != "" {
|
||||
host, port, err := parseHostPort(config.Host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid host in config file: %v", err)
|
||||
}
|
||||
result.Host = host
|
||||
result.Port = port
|
||||
} else {
|
||||
result.Host = "127.0.0.1"
|
||||
}
|
||||
|
||||
// Apply auth info from config
|
||||
result.Username = config.User
|
||||
result.Password = config.Password
|
||||
result.Key = config.APIToken
|
||||
|
||||
// Determine mode: if config has auth info, use user mode
|
||||
if config.User != "" || config.APIToken != "" {
|
||||
result.Type = "user"
|
||||
} else {
|
||||
result.Type = "admin"
|
||||
result.Username = "admin@ragflow.io"
|
||||
}
|
||||
|
||||
// Set default port if not specified in config
|
||||
if result.Port == -1 {
|
||||
if result.Type == "admin" {
|
||||
result.Port = 9381
|
||||
} else {
|
||||
result.Port = 9380
|
||||
}
|
||||
}
|
||||
|
||||
// Get command from remaining args (no need for quotes or semicolon)
|
||||
if len(remainingArgs) > 0 {
|
||||
result.Command = strings.Join(remainingArgs, " ") + ";"
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// buildArgsFromFlags builds ConnectionArgs from command line flags
|
||||
func buildArgsFromFlags(host *string, port *int, password *string, key *string, userMode *bool, username *string, remainingArgs []string) (*ConnectionArgs, error) {
|
||||
result := &ConnectionArgs{
|
||||
Host: *host,
|
||||
Port: *port,
|
||||
Password: *password,
|
||||
Key: *key,
|
||||
Username: *username,
|
||||
}
|
||||
|
||||
// Determine mode
|
||||
if *userMode {
|
||||
result.Type = "user"
|
||||
} else {
|
||||
result.Type = "admin"
|
||||
}
|
||||
|
||||
// Set default port based on type if not specified
|
||||
if result.Port == -1 {
|
||||
if result.Type == "admin" {
|
||||
result.Port = 9383
|
||||
} else {
|
||||
result.Port = 9384
|
||||
}
|
||||
}
|
||||
|
||||
// Determine username based on mode
|
||||
if result.Type == "admin" && result.Username == "" {
|
||||
result.Username = "admin@ragflow.io"
|
||||
}
|
||||
|
||||
// Get command from remaining args (no need for quotes or semicolon)
|
||||
if len(remainingArgs) > 0 {
|
||||
result.Command = strings.Join(remainingArgs, " ") + ";"
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// PrintUsage prints the CLI usage information
|
||||
func PrintUsage() {
|
||||
fmt.Println(`RAGFlow CLI Client
|
||||
|
||||
Usage: ragflow_cli [options] [command]
|
||||
|
||||
Options:
|
||||
-h string Admin or RAGFlow service host (default "127.0.0.1")
|
||||
-p int Admin or RAGFlow service port (default 9381 for admin, 9380 for user)
|
||||
-w string Superuser password
|
||||
-k string API key for authentication
|
||||
-f string Path to config file (YAML format)
|
||||
-admin Run in admin mode (default)
|
||||
-user Run in user mode
|
||||
-u string Username (email). In admin mode defaults to admin@ragflow.io
|
||||
--help Show this help message
|
||||
|
||||
Configuration File:
|
||||
The CLI will automatically read rf.yml from the current directory if it exists.
|
||||
Use -f to specify a custom config file path.
|
||||
|
||||
Config file format:
|
||||
host: 127.0.0.1:9380
|
||||
user: your-email@example.com
|
||||
password: your-password
|
||||
api_token: your-api-token
|
||||
|
||||
When using a config file, you cannot use other command line flags except -help.
|
||||
The command line is only for the SQL command.
|
||||
|
||||
Commands:
|
||||
SQL commands like: LOGIN USER 'email'; LIST USERS; etc.
|
||||
`)
|
||||
}
|
||||
|
||||
// HistoryFile returns the path to the history file
|
||||
func HistoryFile() string {
|
||||
return os.Getenv("HOME") + "/" + historyFileName
|
||||
@ -39,26 +309,95 @@ type CLI struct {
|
||||
prompt string
|
||||
running bool
|
||||
line *liner.State
|
||||
args *ConnectionArgs
|
||||
}
|
||||
|
||||
// NewCLI creates a new CLI instance
|
||||
func NewCLI() (*CLI, error) {
|
||||
return NewCLIWithArgs(nil)
|
||||
}
|
||||
|
||||
// NewCLIWithArgs creates a new CLI instance with connection arguments
|
||||
func NewCLIWithArgs(args *ConnectionArgs) (*CLI, error) {
|
||||
// Create liner first
|
||||
line := liner.NewLiner()
|
||||
|
||||
// Determine server type
|
||||
serverType := "user"
|
||||
if args != nil && args.Type != "" {
|
||||
serverType = args.Type
|
||||
}
|
||||
|
||||
// Create client with password prompt using liner
|
||||
client := NewRAGFlowClient("user") // Default to user mode
|
||||
client := NewRAGFlowClient(serverType)
|
||||
client.PasswordPrompt = line.PasswordPrompt
|
||||
|
||||
// Apply connection arguments if provided
|
||||
client.HTTPClient.Host = args.Host
|
||||
client.HTTPClient.Port = args.Port
|
||||
|
||||
// Set prompt based on server type
|
||||
prompt := "RAGFlow> "
|
||||
if serverType == "admin" {
|
||||
prompt = "RAGFlow(admin)> "
|
||||
}
|
||||
|
||||
return &CLI{
|
||||
prompt: "RAGFlow> ",
|
||||
prompt: prompt,
|
||||
client: client,
|
||||
line: line,
|
||||
args: args,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run starts the interactive CLI
|
||||
func (c *CLI) Run() error {
|
||||
if c.args.Type == "admin" {
|
||||
// Allow 3 attempts for password verification
|
||||
maxAttempts := 3
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
var input string
|
||||
var err error
|
||||
|
||||
// Check if terminal supports password masking
|
||||
if term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
input, err = c.line.PasswordPrompt("Please input your password: ")
|
||||
} else {
|
||||
// Terminal doesn't support password masking, use regular prompt
|
||||
fmt.Println("Warning: This terminal does not support secure password input")
|
||||
input, err = c.line.Prompt("Please input your password (will be visible): ")
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading input: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" {
|
||||
if attempt < maxAttempts {
|
||||
fmt.Println("Password cannot be empty, please try again")
|
||||
continue
|
||||
}
|
||||
return errors.New("no password provided after 3 attempts")
|
||||
}
|
||||
|
||||
// Set the password for verification
|
||||
c.args.Password = input
|
||||
|
||||
if err = c.VerifyAuth(); err != nil {
|
||||
if attempt < maxAttempts {
|
||||
fmt.Printf("Authentication failed: %v (%d/%d attempts)\n", err, attempt, maxAttempts)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("authentication failed after %d attempts: %v", maxAttempts, err)
|
||||
}
|
||||
|
||||
// Authentication successful
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
c.running = true
|
||||
|
||||
// Load history from file
|
||||
@ -99,8 +438,8 @@ func (c *CLI) Run() error {
|
||||
c.line.AppendHistory(input)
|
||||
}
|
||||
|
||||
if err := c.execute(input); err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
if err = c.execute(input); err != nil {
|
||||
fmt.Printf("CLI error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,7 +463,11 @@ func (c *CLI) execute(input string) error {
|
||||
}
|
||||
|
||||
// Execute the command using the client
|
||||
_, err = c.client.ExecuteCommand(cmd)
|
||||
var result ResponseIf
|
||||
result, err = c.client.ExecuteCommand(cmd)
|
||||
if result != nil {
|
||||
result.PrintOut()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@ -143,14 +486,12 @@ func (c *CLI) handleMetaCommand(cmd *Command) error {
|
||||
fmt.Print("\033[H\033[2J")
|
||||
case "admin":
|
||||
c.client.ServerType = "admin"
|
||||
c.client.HTTPClient.Port = 9381
|
||||
c.prompt = "RAGFlow(admin)> "
|
||||
fmt.Println("Switched to ADMIN mode (port 9381)")
|
||||
fmt.Println("Switched to ADMIN mode")
|
||||
case "user":
|
||||
c.client.ServerType = "user"
|
||||
c.client.HTTPClient.Port = 9380
|
||||
c.prompt = "RAGFlow> "
|
||||
fmt.Println("Switched to USER mode (port 9380)")
|
||||
fmt.Println("Switched to USER mode")
|
||||
case "host":
|
||||
if len(args) == 0 {
|
||||
fmt.Printf("Current host: %s\n", c.client.HTTPClient.Host)
|
||||
@ -195,7 +536,7 @@ Meta Commands:
|
||||
\q or \quit - Exit CLI
|
||||
\c or \clear - Clear screen
|
||||
|
||||
SQL Commands (User Mode):
|
||||
Commands (User Mode):
|
||||
LOGIN USER 'email'; - Login as user
|
||||
REGISTER USER 'name' AS 'nickname' PASSWORD 'pwd'; - Register new user
|
||||
SHOW VERSION; - Show version info
|
||||
@ -205,9 +546,14 @@ SQL Commands (User Mode):
|
||||
LIST CHATS; - List user chats
|
||||
LIST MODEL PROVIDERS; - List model providers
|
||||
LIST DEFAULT MODELS; - List default models
|
||||
LIST TOKENS; - List API tokens
|
||||
CREATE TOKEN; - Create new API token
|
||||
DROP TOKEN 'token_value'; - Delete an API token
|
||||
SET TOKEN 'token_value'; - Set and validate API token
|
||||
SHOW TOKEN; - Show current API token
|
||||
UNSET TOKEN; - Remove current API token
|
||||
|
||||
SQL Commands (Admin Mode):
|
||||
LOGIN USER 'email'; - Login as admin
|
||||
Commands (Admin Mode):
|
||||
LIST USERS; - List all users
|
||||
SHOW USER 'email'; - Show user details
|
||||
CREATE USER 'email' 'password'; - Create new user
|
||||
@ -254,3 +600,34 @@ func RunInteractive() error {
|
||||
|
||||
return cli.Run()
|
||||
}
|
||||
|
||||
// RunSingleCommand executes a single command and exits
|
||||
func (c *CLI) RunSingleCommand(command string) error {
|
||||
// Execute the command
|
||||
if err := c.execute(command); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyAuth verifies authentication if needed
|
||||
func (c *CLI) VerifyAuth() error {
|
||||
if c.args == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.args.Username == "" {
|
||||
return fmt.Errorf("username is required")
|
||||
}
|
||||
|
||||
if c.args.Password == "" {
|
||||
return fmt.Errorf("password is required")
|
||||
}
|
||||
|
||||
// Create login command with username and password
|
||||
cmd := NewCommand("login_user")
|
||||
cmd.Params["email"] = c.args.Username
|
||||
cmd.Params["password"] = c.args.Password
|
||||
_, err := c.client.ExecuteCommand(cmd)
|
||||
return err
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -31,12 +31,13 @@ type HTTPClient struct {
|
||||
Host string
|
||||
Port int
|
||||
APIVersion string
|
||||
APIKey string
|
||||
APIToken string
|
||||
LoginToken string
|
||||
ConnectTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
VerifySSL bool
|
||||
client *http.Client
|
||||
useAPIToken bool
|
||||
}
|
||||
|
||||
// NewHTTPClient creates a new HTTP client
|
||||
@ -85,8 +86,8 @@ func (c *HTTPClient) Headers(authKind string, extra map[string]string) map[strin
|
||||
headers := make(map[string]string)
|
||||
switch authKind {
|
||||
case "api":
|
||||
if c.APIKey != "" {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.APIKey)
|
||||
if c.APIToken != "" {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.APIToken)
|
||||
}
|
||||
case "web", "admin":
|
||||
if c.LoginToken != "" {
|
||||
@ -104,6 +105,7 @@ type Response struct {
|
||||
StatusCode int
|
||||
Body []byte
|
||||
Headers http.Header
|
||||
Duration float64
|
||||
}
|
||||
|
||||
// JSON parses the response body as JSON
|
||||
@ -142,11 +144,14 @@ func (c *HTTPClient) Request(method, path string, useAPIBase bool, authKind stri
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
var resp *http.Response
|
||||
startTime := time.Now()
|
||||
resp, err = c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
duration := time.Since(startTime).Seconds()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@ -157,21 +162,93 @@ func (c *HTTPClient) Request(method, path string, useAPIBase bool, authKind stri
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: respBody,
|
||||
Headers: resp.Header.Clone(),
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Request makes an HTTP request
|
||||
func (c *HTTPClient) RequestWith2URL(method, webPath string, apiPath string, headers map[string]string, jsonBody map[string]interface{}) (*Response, error) {
|
||||
var path string
|
||||
var useAPIBase bool
|
||||
var authKind string
|
||||
if c.useAPIToken {
|
||||
path = apiPath
|
||||
useAPIBase = true
|
||||
authKind = "api"
|
||||
} else {
|
||||
path = webPath
|
||||
useAPIBase = false
|
||||
authKind = "web"
|
||||
}
|
||||
|
||||
url := c.BuildURL(path, useAPIBase)
|
||||
mergedHeaders := c.Headers(authKind, headers)
|
||||
|
||||
var body io.Reader
|
||||
if jsonBody != nil {
|
||||
jsonData, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body = bytes.NewReader(jsonData)
|
||||
if mergedHeaders == nil {
|
||||
mergedHeaders = make(map[string]string)
|
||||
}
|
||||
mergedHeaders["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for k, v := range mergedHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
startTime := time.Now()
|
||||
resp, err = c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
duration := time.Since(startTime).Seconds()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: respBody,
|
||||
Headers: resp.Header.Clone(),
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequestWithIterations makes multiple HTTP requests for benchmarking
|
||||
// Returns a map with "duration" (total time in seconds) and "response_list"
|
||||
func (c *HTTPClient) RequestWithIterations(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}, iterations int) (map[string]interface{}, error) {
|
||||
func (c *HTTPClient) RequestWithIterations(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}, iterations int) (*BenchmarkResponse, error) {
|
||||
response := new(BenchmarkResponse)
|
||||
|
||||
if iterations <= 1 {
|
||||
start := time.Now()
|
||||
resp, err := c.Request(method, path, useAPIBase, authKind, headers, jsonBody)
|
||||
totalDuration := time.Since(start).Seconds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"duration": 0.0,
|
||||
"response_list": []*Response{resp},
|
||||
}, nil
|
||||
|
||||
response.Code = resp.StatusCode
|
||||
response.Duration = totalDuration
|
||||
if response.Code == 0 {
|
||||
response.SuccessCount = 1
|
||||
} else {
|
||||
response.FailureCount = 1
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
url := c.BuildURL(path, useAPIBase)
|
||||
@ -232,10 +309,17 @@ func (c *HTTPClient) RequestWithIterations(method, path string, useAPIBase bool,
|
||||
totalDuration += time.Since(start).Seconds()
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"duration": totalDuration,
|
||||
"response_list": responseList,
|
||||
}, nil
|
||||
response.Code = 0
|
||||
response.Duration = totalDuration
|
||||
for _, resp := range responseList {
|
||||
if resp.StatusCode == 200 {
|
||||
response.SuccessCount++
|
||||
} else {
|
||||
response.FailureCount++
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// RequestJSON makes an HTTP request and returns JSON response
|
||||
|
||||
@ -215,6 +215,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenOn, Value: ident}
|
||||
case "SET":
|
||||
return Token{Type: TokenSet, Value: ident}
|
||||
case "UNSET":
|
||||
return Token{Type: TokenUnset, Value: ident}
|
||||
case "RESET":
|
||||
return Token{Type: TokenReset, Value: ident}
|
||||
case "VERSION":
|
||||
@ -287,6 +289,10 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenBenchmark, Value: ident}
|
||||
case "PING":
|
||||
return Token{Type: TokenPing, Value: ident}
|
||||
case "TOKEN":
|
||||
return Token{Type: TokenToken, Value: ident}
|
||||
case "TOKENS":
|
||||
return Token{Type: TokenTokens, Value: ident}
|
||||
default:
|
||||
return Token{Type: TokenIdentifier, Value: ident}
|
||||
}
|
||||
|
||||
@ -102,6 +102,8 @@ func (p *Parser) parseSQLCommand() (*Command, error) {
|
||||
return p.parseRevokeCommand()
|
||||
case TokenSet:
|
||||
return p.parseSetCommand()
|
||||
case TokenUnset:
|
||||
return p.parseUnsetCommand()
|
||||
case TokenReset:
|
||||
return p.parseResetCommand()
|
||||
case TokenGenerate:
|
||||
@ -189,18 +191,20 @@ func (p *Parser) parseLoginUser() (*Command, error) {
|
||||
cmd.Params["email"] = email
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parsePingServer() (*Command, error) {
|
||||
cmd := NewCommand("ping_server")
|
||||
cmd := NewCommand("ping")
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -208,7 +212,6 @@ func (p *Parser) parsePingServer() (*Command, error) {
|
||||
func (p *Parser) parseRegisterCommand() (*Command, error) {
|
||||
cmd := NewCommand("register_user")
|
||||
|
||||
p.nextToken() // consume REGISTER
|
||||
if err := p.expectPeek(TokenUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -245,8 +248,9 @@ func (p *Parser) parseRegisterCommand() (*Command, error) {
|
||||
cmd.Params["password"] = password
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
@ -258,38 +262,51 @@ func (p *Parser) parseListCommand() (*Command, error) {
|
||||
switch p.curToken.Type {
|
||||
case TokenServices:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_services"), nil
|
||||
case TokenUsers:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_users"), nil
|
||||
case TokenRoles:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_roles"), nil
|
||||
case TokenVars:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_variables"), nil
|
||||
case TokenConfigs:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_configs"), nil
|
||||
case TokenTokens:
|
||||
p.nextToken()
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_tokens"), nil
|
||||
case TokenEnvs:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_environments"), nil
|
||||
case TokenDatasets:
|
||||
@ -304,8 +321,9 @@ func (p *Parser) parseListCommand() (*Command, error) {
|
||||
return p.parseListDefaultModels()
|
||||
case TokenChats:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_user_chats"), nil
|
||||
case TokenFiles:
|
||||
@ -334,8 +352,9 @@ func (p *Parser) parseListDatasets() (*Command, error) {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -361,8 +380,9 @@ func (p *Parser) parseListAgents() (*Command, error) {
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -379,12 +399,13 @@ func (p *Parser) parseListKeys() (*Command, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("list_keys")
|
||||
cmd := NewCommand("list_tokens")
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -395,8 +416,9 @@ func (p *Parser) parseListModelProviders() (*Command, error) {
|
||||
return nil, fmt.Errorf("expected PROVIDERS")
|
||||
}
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_user_model_providers"), nil
|
||||
}
|
||||
@ -407,8 +429,9 @@ func (p *Parser) parseListDefaultModels() (*Command, error) {
|
||||
return nil, fmt.Errorf("expected MODELS")
|
||||
}
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("list_user_default_models"), nil
|
||||
}
|
||||
@ -433,8 +456,9 @@ func (p *Parser) parseListFiles() (*Command, error) {
|
||||
cmd.Params["dataset_name"] = datasetName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -445,18 +469,27 @@ func (p *Parser) parseShowCommand() (*Command, error) {
|
||||
switch p.curToken.Type {
|
||||
case TokenVersion:
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("show_version"), nil
|
||||
case TokenToken:
|
||||
p.nextToken()
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("show_token"), nil
|
||||
case TokenCurrent:
|
||||
p.nextToken()
|
||||
if p.curToken.Type != TokenUser {
|
||||
return nil, fmt.Errorf("expected USER after CURRENT")
|
||||
}
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("show_current_user"), nil
|
||||
case TokenUser:
|
||||
@ -485,8 +518,9 @@ func (p *Parser) parseShowUser() (*Command, error) {
|
||||
cmd := NewCommand("show_user_permission")
|
||||
cmd.Params["user_name"] = userName
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -500,8 +534,9 @@ func (p *Parser) parseShowUser() (*Command, error) {
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -517,8 +552,9 @@ func (p *Parser) parseShowRole() (*Command, error) {
|
||||
cmd.Params["role_name"] = roleName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -534,8 +570,9 @@ func (p *Parser) parseShowVariable() (*Command, error) {
|
||||
cmd.Params["var_name"] = varName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -551,8 +588,9 @@ func (p *Parser) parseShowService() (*Command, error) {
|
||||
cmd.Params["number"] = serviceNum
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -571,11 +609,24 @@ func (p *Parser) parseCreateCommand() (*Command, error) {
|
||||
return p.parseCreateDataset()
|
||||
case TokenChat:
|
||||
return p.parseCreateChat()
|
||||
case TokenToken:
|
||||
return p.parseCreateToken()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown CREATE target: %s", p.curToken.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseCreateToken() (*Command, error) {
|
||||
p.nextToken() // consume TOKEN
|
||||
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return NewCommand("create_token"), nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseCreateUser() (*Command, error) {
|
||||
p.nextToken() // consume USER
|
||||
userName, err := p.parseQuotedString()
|
||||
@ -595,8 +646,9 @@ func (p *Parser) parseCreateUser() (*Command, error) {
|
||||
cmd.Params["role"] = "user"
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -622,8 +674,9 @@ func (p *Parser) parseCreateRole() (*Command, error) {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -651,8 +704,9 @@ func (p *Parser) parseCreateModelProvider() (*Command, error) {
|
||||
cmd.Params["provider_key"] = providerKey
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -704,8 +758,9 @@ func (p *Parser) parseCreateDataset() (*Command, error) {
|
||||
return nil, fmt.Errorf("expected PARSER or PIPELINE")
|
||||
}
|
||||
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -721,8 +776,9 @@ func (p *Parser) parseCreateChat() (*Command, error) {
|
||||
cmd.Params["chat_name"] = chatName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -743,11 +799,32 @@ func (p *Parser) parseDropCommand() (*Command, error) {
|
||||
return p.parseDropChat()
|
||||
case TokenKey:
|
||||
return p.parseDropKey()
|
||||
case TokenToken:
|
||||
return p.parseDropToken()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown DROP target: %s", p.curToken.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseDropToken() (*Command, error) {
|
||||
p.nextToken() // consume TOKEN
|
||||
|
||||
tokenValue, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("drop_token")
|
||||
cmd.Params["token"] = tokenValue
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseDropUser() (*Command, error) {
|
||||
p.nextToken() // consume USER
|
||||
userName, err := p.parseQuotedString()
|
||||
@ -759,8 +836,9 @@ func (p *Parser) parseDropUser() (*Command, error) {
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -776,8 +854,9 @@ func (p *Parser) parseDropRole() (*Command, error) {
|
||||
cmd.Params["role_name"] = roleName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -798,8 +877,9 @@ func (p *Parser) parseDropModelProvider() (*Command, error) {
|
||||
cmd.Params["provider_name"] = providerName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -815,8 +895,9 @@ func (p *Parser) parseDropDataset() (*Command, error) {
|
||||
cmd.Params["dataset_name"] = datasetName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -832,15 +913,16 @@ func (p *Parser) parseDropChat() (*Command, error) {
|
||||
cmd.Params["chat_name"] = chatName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseDropKey() (*Command, error) {
|
||||
p.nextToken() // consume KEY
|
||||
key, err := p.parseQuotedString()
|
||||
token, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -856,13 +938,14 @@ func (p *Parser) parseDropKey() (*Command, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("drop_key")
|
||||
cmd.Params["key"] = key
|
||||
cmd := NewCommand("drop_token")
|
||||
cmd.Params["token"] = token
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -905,8 +988,9 @@ func (p *Parser) parseAlterUser() (*Command, error) {
|
||||
cmd.Params["password"] = password
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for SHOW TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -936,8 +1020,9 @@ func (p *Parser) parseAlterUser() (*Command, error) {
|
||||
cmd.Params["role_name"] = roleName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -954,16 +1039,17 @@ func (p *Parser) parseActivateUser() (*Command, error) {
|
||||
status := p.curToken.Value
|
||||
if status != "on" && status != "off" {
|
||||
return nil, fmt.Errorf("expected 'on' or 'off', got %s", p.curToken.Value)
|
||||
}
|
||||
}
|
||||
|
||||
cmd := NewCommand("activate_user")
|
||||
cmd.Params["user_name"] = userName
|
||||
cmd.Params["activate_status"] = status
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
@ -994,8 +1080,9 @@ func (p *Parser) parseAlterRole() (*Command, error) {
|
||||
cmd.Params["description"] = description
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1021,8 +1108,9 @@ func (p *Parser) parseGrantAdmin() (*Command, error) {
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1064,8 +1152,9 @@ func (p *Parser) parseGrantPermission() (*Command, error) {
|
||||
cmd.Params["role_name"] = roleName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1091,8 +1180,9 @@ func (p *Parser) parseRevokeAdmin() (*Command, error) {
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1134,8 +1224,9 @@ func (p *Parser) parseRevokePermission() (*Command, error) {
|
||||
cmd.Params["role_name"] = roleName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1172,6 +1263,9 @@ func (p *Parser) parseSetCommand() (*Command, error) {
|
||||
if p.curToken.Type == TokenDefault {
|
||||
return p.parseSetDefault()
|
||||
}
|
||||
if p.curToken.Type == TokenToken {
|
||||
return p.parseSetToken()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown SET target: %s", p.curToken.Value)
|
||||
}
|
||||
@ -1194,8 +1288,9 @@ func (p *Parser) parseSetVariable() (*Command, error) {
|
||||
cmd.Params["var_value"] = varValue
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1234,9 +1329,29 @@ func (p *Parser) parseSetDefault() (*Command, error) {
|
||||
cmd.Params["model_id"] = modelID
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseSetToken() (*Command, error) {
|
||||
p.nextToken() // consume TOKEN
|
||||
|
||||
tokenValue, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("set_token")
|
||||
cmd.Params["token"] = tokenValue
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
@ -1270,8 +1385,9 @@ func (p *Parser) parseResetCommand() (*Command, error) {
|
||||
cmd.Params["model_type"] = modelType
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1296,12 +1412,13 @@ func (p *Parser) parseGenerateCommand() (*Command, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("generate_key")
|
||||
cmd := NewCommand("generate_token")
|
||||
cmd.Params["user_name"] = userName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1333,8 +1450,9 @@ func (p *Parser) parseImportCommand() (*Command, error) {
|
||||
cmd.Params["dataset_name"] = datasetName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1366,8 +1484,9 @@ func (p *Parser) parseSearchCommand() (*Command, error) {
|
||||
cmd.Params["datasets"] = datasets
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1404,8 +1523,9 @@ func (p *Parser) parseParseDataset() (*Command, error) {
|
||||
cmd.Params["method"] = method
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1436,8 +1556,9 @@ func (p *Parser) parseParseDocs() (*Command, error) {
|
||||
cmd.Params["dataset_name"] = datasetName
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1482,6 +1603,8 @@ func (p *Parser) parseUserStatement() (*Command, error) {
|
||||
return p.parseDropCommand()
|
||||
case TokenSet:
|
||||
return p.parseSetCommand()
|
||||
case TokenUnset:
|
||||
return p.parseUnsetCommand()
|
||||
case TokenReset:
|
||||
return p.parseResetCommand()
|
||||
case TokenList:
|
||||
@ -1513,8 +1636,9 @@ func (p *Parser) parseStartupCommand() (*Command, error) {
|
||||
cmd.Params["number"] = serviceNum
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1535,8 +1659,9 @@ func (p *Parser) parseShutdownCommand() (*Command, error) {
|
||||
cmd.Params["number"] = serviceNum
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
@ -1557,12 +1682,28 @@ func (p *Parser) parseRestartCommand() (*Command, error) {
|
||||
cmd.Params["number"] = serviceNum
|
||||
|
||||
p.nextToken()
|
||||
if err := p.expectSemicolon(); err != nil {
|
||||
return nil, err
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseUnsetCommand() (*Command, error) {
|
||||
p.nextToken() // consume UNSET
|
||||
|
||||
if p.curToken.Type != TokenToken {
|
||||
return nil, fmt.Errorf("expected TOKEN after UNSET")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return NewCommand("unset_token"), nil
|
||||
}
|
||||
|
||||
func tokenTypeToString(t int) string {
|
||||
// Simplified for error messages
|
||||
return fmt.Sprintf("token(%d)", t)
|
||||
|
||||
@ -95,6 +95,9 @@ const (
|
||||
TokenSync
|
||||
TokenBenchmark
|
||||
TokenPing
|
||||
TokenToken
|
||||
TokenTokens
|
||||
TokenUnset
|
||||
|
||||
// Literals
|
||||
TokenIdentifier
|
||||
|
||||
548
internal/cli/user_command.go
Normal file
548
internal/cli/user_command.go
Normal file
@ -0,0 +1,548 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PingServer pings the server to check if it's alive
|
||||
// Returns benchmark result map if iterations > 1, otherwise prints status
|
||||
func (c *RAGFlowClient) PingServer(cmd *Command) (ResponseIf, error) {
|
||||
// Get iterations from command params (for benchmark)
|
||||
iterations := 1
|
||||
if val, ok := cmd.Params["iterations"].(int); ok && val > 1 {
|
||||
iterations = val
|
||||
}
|
||||
|
||||
if iterations > 1 {
|
||||
// Benchmark mode: multiple iterations
|
||||
return c.HTTPClient.RequestWithIterations("GET", "/system/ping", false, "web", nil, nil, iterations)
|
||||
}
|
||||
|
||||
// Single mode
|
||||
resp, err := c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
fmt.Println("Server is down")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to ping: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
result.Message = string(resp.Body)
|
||||
result.Code = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Show server version to show RAGFlow server version
|
||||
// Returns benchmark result map if iterations > 1, otherwise prints status
|
||||
func (c *RAGFlowClient) ShowServerVersion(cmd *Command) (ResponseIf, error) {
|
||||
// Get iterations from command params (for benchmark)
|
||||
iterations := 1
|
||||
if val, ok := cmd.Params["iterations"].(int); ok && val > 1 {
|
||||
iterations = val
|
||||
}
|
||||
|
||||
if iterations > 1 {
|
||||
// Benchmark mode: multiple iterations
|
||||
return c.HTTPClient.RequestWithIterations("GET", "/system/version", false, "web", nil, nil, iterations)
|
||||
}
|
||||
|
||||
// Single mode
|
||||
resp, err := c.HTTPClient.Request("GET", "/system/version", false, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to show version: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to show version: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result KeyValueResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("show version failed: invalid JSON (%w)", err)
|
||||
}
|
||||
result.Key = "version"
|
||||
result.Duration = resp.Duration
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) RegisterUser(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in ADMIN mode")
|
||||
}
|
||||
|
||||
// Check for benchmark iterations
|
||||
var ok bool
|
||||
_, ok = cmd.Params["iterations"].(int)
|
||||
if ok {
|
||||
return nil, fmt.Errorf("failed to register user in benchmark statement")
|
||||
}
|
||||
|
||||
var email string
|
||||
email, ok = cmd.Params["user_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no email")
|
||||
}
|
||||
|
||||
var password string
|
||||
password, ok = cmd.Params["password"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no password")
|
||||
}
|
||||
|
||||
var nickname string
|
||||
nickname, ok = cmd.Params["nickname"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no nickname")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"email": email,
|
||||
"password": password,
|
||||
"nickname": nickname,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", "/user/register", false, "admin", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register user: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to register user: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result RegisterResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("register user failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ListUserDatasets lists datasets for current user (user mode)
|
||||
// Returns (result_map, error) - result_map is non-nil for benchmark mode
|
||||
func (c *RAGFlowClient) ListUserDatasets(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
// Check for benchmark iterations
|
||||
iterations := 1
|
||||
if val, ok := cmd.Params["iterations"].(int); ok && val > 1 {
|
||||
iterations = val
|
||||
}
|
||||
|
||||
if iterations > 1 {
|
||||
// Benchmark mode - return raw result for benchmark stats
|
||||
return c.HTTPClient.RequestWithIterations("GET", "/datasets", true, "web", nil, nil, iterations)
|
||||
}
|
||||
|
||||
// Normal mode
|
||||
resp, err := c.HTTPClient.Request("GET", "/datasets", true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list datasets: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("list users failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// getDatasetID gets dataset ID by name
|
||||
func (c *RAGFlowClient) getDatasetID(datasetName string) (string, error) {
|
||||
resp, err := c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to list datasets: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("failed to list datasets: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
resJSON, err := resp.JSON()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid JSON response: %w", err)
|
||||
}
|
||||
|
||||
code, ok := resJSON["code"].(float64)
|
||||
if !ok || code != 0 {
|
||||
msg, _ := resJSON["message"].(string)
|
||||
return "", fmt.Errorf("failed to list datasets: %s", msg)
|
||||
}
|
||||
|
||||
data, ok := resJSON["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid response format")
|
||||
}
|
||||
|
||||
kbs, ok := data["kbs"].([]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid response format: kbs not found")
|
||||
}
|
||||
|
||||
for _, kb := range kbs {
|
||||
if kbMap, ok := kb.(map[string]interface{}); ok {
|
||||
if name, _ := kbMap["name"].(string); name == datasetName {
|
||||
if id, _ := kbMap["id"].(string); id != "" {
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("dataset '%s' not found", datasetName)
|
||||
}
|
||||
|
||||
// formatEmptyArray converts empty arrays to "[]" string
|
||||
func formatEmptyArray(v interface{}) string {
|
||||
if v == nil {
|
||||
return "[]"
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
if len(val) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
case []string:
|
||||
if len(val) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
case []int:
|
||||
if len(val) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// SearchOnDatasets searches for chunks in specified datasets
|
||||
// Returns (result_map, error) - result_map is non-nil for benchmark mode
|
||||
func (c *RAGFlowClient) SearchOnDatasets(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
question, ok := cmd.Params["question"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("question not provided")
|
||||
}
|
||||
|
||||
datasets, ok := cmd.Params["datasets"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("datasets not provided")
|
||||
}
|
||||
|
||||
// Parse dataset names (comma-separated) and convert to IDs
|
||||
datasetNames := strings.Split(datasets, ",")
|
||||
datasetIDs := make([]string, 0, len(datasetNames))
|
||||
for _, name := range datasetNames {
|
||||
name = strings.TrimSpace(name)
|
||||
id, err := c.getDatasetID(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
datasetIDs = append(datasetIDs, id)
|
||||
}
|
||||
|
||||
// Check for benchmark iterations
|
||||
iterations := 1
|
||||
if val, ok := cmd.Params["iterations"].(int); ok && val > 1 {
|
||||
iterations = val
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"kb_id": datasetIDs,
|
||||
"question": question,
|
||||
"similarity_threshold": 0.2,
|
||||
"vector_similarity_weight": 0.3,
|
||||
}
|
||||
|
||||
if iterations > 1 {
|
||||
// Benchmark mode - return raw result for benchmark stats
|
||||
return c.HTTPClient.RequestWithIterations("POST", "/chunk/retrieval_test", false, "web", nil, payload, iterations)
|
||||
}
|
||||
|
||||
// Normal mode
|
||||
resp, err := c.HTTPClient.Request("POST", "/chunk/retrieval_test", false, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search on datasets: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to search on datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
resJSON, err := resp.JSON()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON response: %w", err)
|
||||
}
|
||||
|
||||
code, ok := resJSON["code"].(float64)
|
||||
if !ok || code != 0 {
|
||||
msg, _ := resJSON["message"].(string)
|
||||
return nil, fmt.Errorf("failed to search on datasets: %s", msg)
|
||||
}
|
||||
|
||||
data, ok := resJSON["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response format")
|
||||
}
|
||||
|
||||
chunks, ok := data["chunks"].([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response format: chunks not found")
|
||||
}
|
||||
|
||||
// Convert to slice of maps for printing
|
||||
tableData := make([]map[string]interface{}, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
if chunkMap, ok := chunk.(map[string]interface{}); ok {
|
||||
row := map[string]interface{}{
|
||||
"id": chunkMap["chunk_id"],
|
||||
"content": chunkMap["content_with_weight"],
|
||||
"document_id": chunkMap["doc_id"],
|
||||
"dataset_id": chunkMap["kb_id"],
|
||||
"docnm_kwd": chunkMap["docnm_kwd"],
|
||||
"image_id": chunkMap["image_id"],
|
||||
"similarity": chunkMap["similarity"],
|
||||
"term_similarity": chunkMap["term_similarity"],
|
||||
"vector_similarity": chunkMap["vector_similarity"],
|
||||
}
|
||||
// Add optional fields that may be empty arrays
|
||||
if v, ok := chunkMap["doc_type_kwd"]; ok {
|
||||
row["doc_type_kwd"] = formatEmptyArray(v)
|
||||
}
|
||||
if v, ok := chunkMap["important_kwd"]; ok {
|
||||
row["important_kwd"] = formatEmptyArray(v)
|
||||
}
|
||||
if v, ok := chunkMap["mom_id"]; ok {
|
||||
row["mom_id"] = formatEmptyArray(v)
|
||||
}
|
||||
if v, ok := chunkMap["positions"]; ok {
|
||||
row["positions"] = formatEmptyArray(v)
|
||||
}
|
||||
if v, ok := chunkMap["content_ltks"]; ok {
|
||||
row["content_ltks"] = v
|
||||
}
|
||||
tableData = append(tableData, row)
|
||||
}
|
||||
}
|
||||
|
||||
PrintTableSimple(tableData)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new API token
|
||||
func (c *RAGFlowClient) CreateToken(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", "/tokens", true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to create token: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var createResult CommonDataResponse
|
||||
if err = json.Unmarshal(resp.Body, &createResult); err != nil {
|
||||
return nil, fmt.Errorf("create token failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if createResult.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", createResult.Message)
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
result.Code = 0
|
||||
result.Message = "Token created successfully"
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ListTokens lists all API tokens for the current user
|
||||
func (c *RAGFlowClient) ListTokens(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("GET", "/tokens", true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list tokens: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list tokens: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("list tokens failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// DropToken deletes an API token
|
||||
func (c *RAGFlowClient) DropToken(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
token, ok := cmd.Params["token"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token not provided")
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/tokens/%s", token), true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to drop token: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to drop token: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("drop token failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SetToken sets the API token after validating it
|
||||
func (c *RAGFlowClient) SetToken(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
token, ok := cmd.Params["token"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token not provided")
|
||||
}
|
||||
|
||||
// Save current token to restore if validation fails
|
||||
savedToken := c.HTTPClient.APIToken
|
||||
savedUseAPIToken := c.HTTPClient.useAPIToken
|
||||
|
||||
// Set the new token temporarily for validation
|
||||
c.HTTPClient.APIToken = token
|
||||
c.HTTPClient.useAPIToken = true
|
||||
|
||||
// Validate token by calling list tokens API
|
||||
resp, err := c.HTTPClient.Request("GET", "/tokens", true, "api", nil, nil)
|
||||
if err != nil {
|
||||
// Restore original token on error
|
||||
c.HTTPClient.APIToken = savedToken
|
||||
c.HTTPClient.useAPIToken = savedUseAPIToken
|
||||
return nil, fmt.Errorf("failed to validate token: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
// Restore original token on error
|
||||
c.HTTPClient.APIToken = savedToken
|
||||
c.HTTPClient.useAPIToken = savedUseAPIToken
|
||||
return nil, fmt.Errorf("token validation failed: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
// Restore original token on error
|
||||
c.HTTPClient.APIToken = savedToken
|
||||
c.HTTPClient.useAPIToken = savedUseAPIToken
|
||||
return nil, fmt.Errorf("token validation failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
// Restore original token on error
|
||||
c.HTTPClient.APIToken = savedToken
|
||||
c.HTTPClient.useAPIToken = savedUseAPIToken
|
||||
return nil, fmt.Errorf("token validation failed: %s", result.Message)
|
||||
}
|
||||
|
||||
// Token is valid, keep it set
|
||||
var successResult SimpleResponse
|
||||
successResult.Code = 0
|
||||
successResult.Message = "API token set successfully"
|
||||
successResult.Duration = resp.Duration
|
||||
return &successResult, nil
|
||||
}
|
||||
|
||||
// ShowToken displays the current API token
|
||||
func (c *RAGFlowClient) ShowToken(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
if c.HTTPClient.APIToken == "" {
|
||||
return nil, fmt.Errorf("no API token is currently set")
|
||||
}
|
||||
|
||||
//fmt.Printf("Token: %s\n", c.HTTPClient.APIToken)
|
||||
|
||||
var result CommonResponse
|
||||
result.Code = 0
|
||||
result.Message = ""
|
||||
result.Data = []map[string]interface{}{
|
||||
{
|
||||
"token": c.HTTPClient.APIToken,
|
||||
},
|
||||
}
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// UnsetToken removes the current API token
|
||||
func (c *RAGFlowClient) UnsetToken(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
if c.HTTPClient.APIToken == "" {
|
||||
return nil, fmt.Errorf("no API token is currently set")
|
||||
}
|
||||
|
||||
c.HTTPClient.APIToken = ""
|
||||
c.HTTPClient.useAPIToken = false
|
||||
|
||||
var result SimpleResponse
|
||||
result.Code = 0
|
||||
result.Message = "API token unset successfully"
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
@ -46,6 +46,16 @@ func (dao *APITokenDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetByToken gets API token by access key
|
||||
func (dao *APITokenDAO) GetUserByAPIToken(token string) (*model.APIToken, error) {
|
||||
var apiToken model.APIToken
|
||||
err := DB.Where("token = ?", token).First(&apiToken).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiToken, nil
|
||||
}
|
||||
|
||||
// DeleteByDialogIDs deletes API tokens by dialog IDs (hard delete)
|
||||
func (dao *APITokenDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error) {
|
||||
if len(dialogIDs) == 0 {
|
||||
@ -55,6 +65,12 @@ func (dao *APITokenDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error) {
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteByTenantIDAndToken deletes a specific API token by tenant ID and token value
|
||||
func (dao *APITokenDAO) DeleteByTenantIDAndToken(tenantID, token string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ? AND token = ?", tenantID, token).Delete(&model.APIToken{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// API4ConversationDAO API for conversation data access object
|
||||
type API4ConversationDAO struct{}
|
||||
|
||||
|
||||
@ -43,6 +43,15 @@ func (dao *UserDAO) GetByID(id uint) (*model.User, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (dao *UserDAO) GetByTenantID(tenantID string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := DB.Where("id = ?", tenantID).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByUsername get user by username
|
||||
func (dao *UserDAO) GetByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
|
||||
224
internal/handler/api_token.go
Normal file
224
internal/handler/api_token.go
Normal file
@ -0,0 +1,224 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ListTokens list all API tokens for the current user's tenant
|
||||
// @Summary List API Tokens
|
||||
// @Description List all API tokens for the current user's tenant
|
||||
// @Tags system
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/system/token_list [get]
|
||||
func (h *SystemHandler) ListTokens(c *gin.Context) {
|
||||
// Get current user from context
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userModel, ok := user.(*model.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "Invalid user data",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user's tenant with owner role
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
tenants, err := userTenantDAO.GetByUserIDAndRole(userModel.ID, "owner")
|
||||
if err != nil || len(tenants) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Tenant not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Get tokens for the tenant
|
||||
tokens, err := h.systemService.ListAPITokens(tenantID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "Failed to list tokens",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": tokens,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateToken creates a new API token for the current user's tenant
|
||||
// @Summary Create API Token
|
||||
// @Description Generate a new API token for the current user's tenant
|
||||
// @Tags system
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param name query string false "Name of the token"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/system/new_token [post]
|
||||
func (h *SystemHandler) CreateToken(c *gin.Context) {
|
||||
// Get current user from context
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userModel, ok := user.(*model.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "Invalid user data",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user's tenant with owner role
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
tenants, err := userTenantDAO.GetByUserIDAndRole(userModel.ID, "owner")
|
||||
if err != nil || len(tenants) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Tenant not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Parse request
|
||||
var req service.CreateAPITokenRequest
|
||||
if err := c.ShouldBind(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Invalid request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create token
|
||||
token, err := h.systemService.CreateAPIToken(tenantID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "Failed to create token",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": token,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteToken deletes an API token
|
||||
// @Summary Delete API Token
|
||||
// @Description Remove an API token for the current user's tenant
|
||||
// @Tags system
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param token path string true "The API token to remove"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/system/token/{token} [delete]
|
||||
func (h *SystemHandler) DeleteToken(c *gin.Context) {
|
||||
// Get current user from context
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userModel, ok := user.(*model.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "Invalid user data",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user's tenant with owner role
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
tenants, err := userTenantDAO.GetByUserIDAndRole(userModel.ID, "owner")
|
||||
if err != nil || len(tenants) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Tenant not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Get token from path parameter
|
||||
token := c.Param("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Token is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Delete token
|
||||
if err := h.systemService.DeleteAPIToken(tenantID, token); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "Failed to delete token",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
@ -56,12 +56,15 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
|
||||
// Get user by access token
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": code,
|
||||
"message": "Invalid access token",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
user, code, err = h.userService.GetUserByAPIToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": code,
|
||||
"message": "Invalid access token",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if *user.IsSuperuser {
|
||||
|
||||
@ -19,10 +19,9 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
// SystemHandler system handler
|
||||
|
||||
@ -36,6 +36,7 @@ type User struct {
|
||||
LoginChannel *string `gorm:"column:login_channel;index" json:"login_channel,omitempty"`
|
||||
Status *string `gorm:"column:status;size:1;default:1;index" json:"status"`
|
||||
IsSuperuser *bool `gorm:"column:is_superuser;index" json:"is_superuser,omitempty"`
|
||||
RoleID int64 `gorm:"column:role_id;index;default:1;not null;" json:"role_id,omitempty"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
|
||||
@ -116,6 +116,11 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// User set tenant info endpoint
|
||||
authorized.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo)
|
||||
|
||||
// System token endpoints (requires authentication)
|
||||
authorized.GET("/v1/system/token_list", r.systemHandler.ListTokens)
|
||||
authorized.POST("/v1/system/new_token", r.systemHandler.CreateToken)
|
||||
authorized.DELETE("/v1/system/token/:token", r.systemHandler.DeleteToken)
|
||||
|
||||
// API v1 route group
|
||||
v1 := authorized.Group("/api/v1")
|
||||
{
|
||||
@ -128,6 +133,13 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// users.GET("/:id", r.userHandler.GetUserByID)
|
||||
//}
|
||||
|
||||
apiTokens := v1.Group("/tokens")
|
||||
{
|
||||
apiTokens.POST("", r.systemHandler.CreateToken)
|
||||
apiTokens.GET("", r.systemHandler.ListTokens)
|
||||
apiTokens.DELETE("/:token", r.systemHandler.DeleteToken)
|
||||
}
|
||||
|
||||
// Document routes
|
||||
documents := v1.Group("/documents")
|
||||
{
|
||||
|
||||
107
internal/service/api_token.go
Normal file
107
internal/service/api_token.go
Normal file
@ -0,0 +1,107 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/utility"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenResponse token response
|
||||
type TokenResponse struct {
|
||||
TenantID string `json:"tenant_id"`
|
||||
Token string `json:"token"`
|
||||
DialogID *string `json:"dialog_id,omitempty"`
|
||||
Source *string `json:"source,omitempty"`
|
||||
Beta *string `json:"beta,omitempty"`
|
||||
CreateTime *int64 `json:"create_time,omitempty"`
|
||||
UpdateTime *int64 `json:"update_time,omitempty"`
|
||||
}
|
||||
|
||||
// ListAPITokens list all API tokens for a tenant
|
||||
func (s *SystemService) ListAPITokens(tenantID string) ([]*TokenResponse, error) {
|
||||
APITokenDAO := dao.NewAPITokenDAO()
|
||||
tokens, err := APITokenDAO.GetByTenantID(tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
responses := make([]*TokenResponse, len(tokens))
|
||||
for i, token := range tokens {
|
||||
responses[i] = &TokenResponse{
|
||||
TenantID: token.TenantID,
|
||||
Token: token.Token,
|
||||
DialogID: token.DialogID,
|
||||
Source: token.Source,
|
||||
Beta: token.Beta,
|
||||
CreateTime: token.CreateTime,
|
||||
UpdateTime: token.UpdateTime,
|
||||
}
|
||||
}
|
||||
|
||||
return responses, nil
|
||||
}
|
||||
|
||||
// CreateAPITokenRequest create token request
|
||||
type CreateAPITokenRequest struct {
|
||||
Name string `json:"name" form:"name"`
|
||||
}
|
||||
|
||||
// CreateAPIToken creates a new API token for a tenant
|
||||
func (s *SystemService) CreateAPIToken(tenantID string, req *CreateAPITokenRequest) (*TokenResponse, error) {
|
||||
APITokenDAO := dao.NewAPITokenDAO()
|
||||
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now()
|
||||
|
||||
// Generate token and beta values
|
||||
// token: "ragflow-" + secrets.token_urlsafe(32)
|
||||
APIToken := utility.GenerateAPIToken()
|
||||
// beta: generate_confirmation_token().replace("ragflow-", "")[:32]
|
||||
betaAPIKey := utility.GenerateBetaAPIToken(APIToken)
|
||||
|
||||
APITokenData := &model.APIToken{
|
||||
TenantID: tenantID,
|
||||
Token: APIToken,
|
||||
Beta: &betaAPIKey,
|
||||
}
|
||||
APITokenData.CreateDate = &nowDate
|
||||
APITokenData.CreateTime = &now
|
||||
|
||||
if err := APITokenDAO.Create(APITokenData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenResponse{
|
||||
TenantID: APITokenData.TenantID,
|
||||
Token: APITokenData.Token,
|
||||
DialogID: APITokenData.DialogID,
|
||||
Source: APITokenData.Source,
|
||||
Beta: APITokenData.Beta,
|
||||
CreateTime: APITokenData.CreateTime,
|
||||
UpdateTime: APITokenData.UpdateTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteAPIToken deletes an API token by tenant ID and token value
|
||||
func (s *SystemService) DeleteAPIToken(tenantID, token string) error {
|
||||
APITokenDAO := dao.NewAPITokenDAO()
|
||||
_, err := APITokenDAO.DeleteByTenantIDAndToken(tenantID, token)
|
||||
return err
|
||||
}
|
||||
@ -59,7 +59,7 @@ func NewUserService() *UserService {
|
||||
// RegisterRequest registration request
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Password string `json:"password" binding:"required,min=1"`
|
||||
Nickname string `json:"nickname"`
|
||||
}
|
||||
|
||||
@ -122,7 +122,8 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password")
|
||||
}
|
||||
|
||||
hashedPassword, err := s.HashPassword(decryptedPassword)
|
||||
var hashedPassword string
|
||||
hashedPassword, err = s.HashPassword(decryptedPassword)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
@ -227,20 +228,20 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
fileDAO := dao.NewFileDAO()
|
||||
|
||||
if err := s.userDAO.Create(user); err != nil {
|
||||
if err = s.userDAO.Create(user); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
if err := tenantDAO.Create(tenant); err != nil {
|
||||
err := s.userDAO.DeleteByID(userID)
|
||||
if err = tenantDAO.Create(tenant); err != nil {
|
||||
err = s.userDAO.DeleteByID(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err)
|
||||
}
|
||||
|
||||
if err := userTenantDAO.Create(userTenant); err != nil {
|
||||
err := s.userDAO.DeleteByID(userID)
|
||||
if err = userTenantDAO.Create(userTenant); err != nil {
|
||||
err = s.userDAO.DeleteByID(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -251,8 +252,8 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err)
|
||||
}
|
||||
|
||||
if err := fileDAO.Create(rootFile); err != nil {
|
||||
err := s.userDAO.DeleteByID(userID)
|
||||
if err = fileDAO.Create(rootFile); err != nil {
|
||||
err = s.userDAO.DeleteByID(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -922,3 +923,44 @@ func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) er
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByAPIToken gets user by access key from Authorization header
|
||||
// This is used for API token authentication
|
||||
// The authorization parameter should be in format: "Bearer <token>" or just "<token>"
|
||||
func (s *UserService) GetUserByAPIToken(authorization string) (*model.User, common.ErrorCode, error) {
|
||||
if authorization == "" {
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("authorization header is empty")
|
||||
}
|
||||
|
||||
// Split authorization header to get the token
|
||||
// Expected format: "Bearer <token>" or "<token>"
|
||||
parts := strings.Split(authorization, " ")
|
||||
var token string
|
||||
if len(parts) == 2 {
|
||||
token = parts[1]
|
||||
} else if len(parts) == 1 {
|
||||
token = parts[0]
|
||||
} else {
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization format")
|
||||
}
|
||||
|
||||
// Query API token from database
|
||||
apiTokenDAO := dao.NewAPITokenDAO()
|
||||
userToken, err := apiTokenDAO.GetUserByAPIToken(token)
|
||||
if err != nil {
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token")
|
||||
}
|
||||
|
||||
// Get user by tenant_id from API token
|
||||
user, err := s.userDAO.GetByTenantID(userToken.TenantID)
|
||||
if err != nil {
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("user not found for this access token")
|
||||
}
|
||||
|
||||
// Check if user's access_token is empty
|
||||
if user.AccessToken == nil || *user.AccessToken == "" {
|
||||
return nil, common.CodeUnauthorized, fmt.Errorf("user has empty access_token in database")
|
||||
}
|
||||
|
||||
return user, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
@ -138,3 +138,29 @@ func GenerateSecretKey() (string, error) {
|
||||
func GenerateToken() string {
|
||||
return strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
|
||||
// GenerateAPIToken generates a secure random access key
|
||||
// Equivalent to Python's generate_confirmation_token():
|
||||
// return "ragflow-" + secrets.token_urlsafe(32)
|
||||
func GenerateAPIToken() string {
|
||||
// Generate 32 random bytes
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
// Fallback to UUID if random generation fails
|
||||
return "ragflow-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
// Use URL-safe base64 encoding (same as Python's token_urlsafe)
|
||||
return "ragflow-" + base64.RawURLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// GenerateBetaAPIToken generates a beta access key
|
||||
// Equivalent to Python's: generate_confirmation_token().replace("ragflow-", "")[:32]
|
||||
func GenerateBetaAPIToken(accessKey string) string {
|
||||
// Remove "ragflow-" prefix
|
||||
withoutPrefix := strings.TrimPrefix(accessKey, "ragflow-")
|
||||
// Take first 32 characters
|
||||
if len(withoutPrefix) > 32 {
|
||||
return withoutPrefix[:32]
|
||||
}
|
||||
return withoutPrefix
|
||||
}
|
||||
|
||||
@ -49,6 +49,21 @@ export default defineConfig(({ mode }) => {
|
||||
},
|
||||
},
|
||||
hybrid: {
|
||||
'/v1/system/token_list': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/system/new_token': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/system/token': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/system/config': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
|
||||
Reference in New Issue
Block a user