From b308cd3a02712bdeae2ee4195f42679377f4e3d6 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Tue, 24 Mar 2026 20:08:36 +0800 Subject: [PATCH] Update go cli (#13717) ### What problem does this PR solve? Go cli ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- cmd/ragflow_cli.go | 34 +- conf/service_conf.yaml | 2 +- go.mod | 5 +- go.sum | 4 + internal/admin/handler.go | 22 +- internal/admin/router.go | 9 +- internal/admin/service.go | 14 +- internal/cli/admin_command.go | 1101 +++++++++++++++++++++++ internal/cli/benchmark.go | 156 ++-- internal/cli/cli.go | 401 ++++++++- internal/cli/client.go | 922 +++++-------------- internal/cli/http_client.go | 110 ++- internal/cli/lexer.go | 6 + internal/cli/parser.go | 369 +++++--- internal/cli/types.go | 3 + internal/cli/user_command.go | 548 +++++++++++ internal/dao/api_token.go | 16 + internal/dao/user.go | 9 + internal/handler/api_token.go | 224 +++++ internal/handler/auth.go | 15 +- internal/handler/system.go | 3 +- internal/model/{api.go => api_token.go} | 0 internal/model/user.go | 1 + internal/router/router.go | 12 + internal/service/api_token.go | 107 +++ internal/service/user.go | 60 +- internal/utility/token.go | 26 + web/vite.config.ts | 15 + 28 files changed, 3239 insertions(+), 955 deletions(-) create mode 100644 internal/cli/admin_command.go create mode 100644 internal/cli/user_command.go create mode 100644 internal/handler/api_token.go rename internal/model/{api.go => api_token.go} (100%) create mode 100644 internal/service/api_token.go diff --git a/cmd/ragflow_cli.go b/cmd/ragflow_cli.go index 7af88e3ac..374f2df3e 100644 --- a/cmd/ragflow_cli.go +++ b/cmd/ragflow_cli.go @@ -10,8 +10,21 @@ import ( ) func main() { - // Create CLI instance - cliApp, err := cli.NewCLI() + // Parse command line arguments (skip program name) + args, err := cli.ParseConnectionArgs(os.Args[1:]) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + // Show help and exit + if args.ShowHelp { + cli.PrintUsage() + os.Exit(0) + } + + // Create CLI instance with parsed arguments + cliApp, err := cli.NewCLIWithArgs(args) if err != nil { fmt.Printf("Failed to create CLI: %v\n", err) os.Exit(1) @@ -26,9 +39,18 @@ func main() { os.Exit(0) }() - // Run CLI - if err := cliApp.Run(); err != nil { - fmt.Printf("CLI error: %v\n", err) - os.Exit(1) + // Check if we have a single command to execute + if args.Command != "" { + // Single command mode + if err = cliApp.RunSingleCommand(args.Command); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + } else { + // Interactive mode + if err = cliApp.Run(); err != nil { + fmt.Printf("CLI error: %v\n", err) + os.Exit(1) + } } } diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 6029f514f..d024f1719 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -9,7 +9,7 @@ mysql: user: 'root' password: 'infini_rag_flow' host: 'localhost' - port: 5455 + port: 3306 max_connections: 900 stale_timeout: 300 max_allowed_packet: 1073741824 diff --git a/go.mod b/go.mod index 8959f065a..1f28e819a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module ragflow -go 1.25 +go 1.25.0 require ( github.com/aws/aws-sdk-go-v2 v1.41.3 @@ -97,7 +97,8 @@ require ( golang.org/x/arch v0.6.0 // indirect golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect golang.org/x/net v0.49.0 // indirect - golang.org/x/sys v0.41.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/term v0.41.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/protobuf v1.32.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index dd3937594..b2d5571fa 100644 --- a/go.sum +++ b/go.sum @@ -228,6 +228,10 @@ golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 7bdf99f63..c90e9ea6d 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -109,7 +109,7 @@ func (h *Handler) Health(c *gin.Context) { // Ping ping endpoint func (h *Handler) Ping(c *gin.Context) { - successNoData(c, "PONG") + successNoData(c, "pong") } // Login handle admin login @@ -420,15 +420,15 @@ func (h *Handler) GetUserAgents(c *gin.Context) { success(c, agents, "") } -// GetUserAPIKeys handle get user API keys -func (h *Handler) GetUserAPIKeys(c *gin.Context) { +// ListUserAPITokens handle get user API keys +func (h *Handler) ListUserAPITokens(c *gin.Context) { username := c.Param("username") if username == "" { errorResponse(c, "Username is required", 400) return } - apiKeys, err := h.service.GetUserAPIKeys(username) + apiKeys, err := h.service.ListUserAPITokens(username) if err != nil { errorResponse(c, err.Error(), 500) return @@ -437,15 +437,15 @@ func (h *Handler) GetUserAPIKeys(c *gin.Context) { success(c, apiKeys, "Get user API keys") } -// GenerateUserAPIKey handle generate user API key -func (h *Handler) GenerateUserAPIKey(c *gin.Context) { +// GenerateUserAPIToken handle generate user API key +func (h *Handler) GenerateUserAPIToken(c *gin.Context) { username := c.Param("username") if username == "" { errorResponse(c, "Username is required", 400) return } - apiKey, err := h.service.GenerateUserAPIKey(username) + apiKey, err := h.service.GenerateUserAPIToken(username) if err != nil { errorResponse(c, err.Error(), 500) return @@ -454,16 +454,16 @@ func (h *Handler) GenerateUserAPIKey(c *gin.Context) { success(c, apiKey, "API key generated successfully") } -// DeleteUserAPIKey handle delete user API key -func (h *Handler) DeleteUserAPIKey(c *gin.Context) { +// DeleteUserAPIToken handle delete user API key +func (h *Handler) DeleteUserAPIToken(c *gin.Context) { username := c.Param("username") - key := c.Param("key") + key := c.Param("token") if username == "" || key == "" { errorResponse(c, "Username and key are required", 400) return } - if err := h.service.DeleteUserAPIKey(username, key); err != nil { + if err := h.service.DeleteUserAPIToken(username, key); err != nil { errorResponse(c, err.Error(), 404) return } diff --git a/internal/admin/router.go b/internal/admin/router.go index dbfe7d8cb..413b5e4b6 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -67,9 +67,12 @@ func (r *Router) Setup(engine *gin.Engine) { protected.GET("/users/:username/agents", r.handler.GetUserAgents) // API Keys - protected.GET("/users/:username/keys", r.handler.GetUserAPIKeys) - protected.POST("/users/:username/keys", r.handler.GenerateUserAPIKey) - protected.DELETE("/users/:username/keys/:key", r.handler.DeleteUserAPIKey) + protected.GET("/users/:username/keys", r.handler.ListUserAPITokens) + protected.GET("/users/:username/tokens", r.handler.ListUserAPITokens) + protected.POST("/users/:username/keys", r.handler.GenerateUserAPIToken) + protected.POST("/users/:username/tokens", r.handler.GenerateUserAPIToken) + protected.DELETE("/users/:username/keys/:token", r.handler.DeleteUserAPIToken) + protected.DELETE("/users/:username/tokens/:token", r.handler.DeleteUserAPIToken) // Role management protected.GET("/roles", r.handler.ListRoles) diff --git a/internal/admin/service.go b/internal/admin/service.go index 076537db6..de20805bc 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -676,7 +676,7 @@ func (s *Service) DeleteUser(username string) (*DeleteUserResult, error) { var userTenantCount int64 tx.Model(&model.UserTenant{}).Where("user_id = ?", user.ID).Count(&userTenantCount) result.UserTenantCount = int(userTenantCount) - + // 15. Delete user-tenant relations if delErr := tx.Unscoped().Where("user_id = ?", user.ID).Delete(&model.UserTenant{}); delErr.Error != nil { logger.Warn("failed to delete user-tenant relations", zap.Error(delErr.Error)) @@ -868,20 +868,20 @@ func (s *Service) GetUserAgents(username string) ([]map[string]interface{}, erro // API Key methods -// GetUserAPIKeys get user API keys -func (s *Service) GetUserAPIKeys(username string) ([]map[string]interface{}, error) { +// ListUserAPITokens get user API keys +func (s *Service) ListUserAPITokens(username string) ([]map[string]interface{}, error) { // TODO: Implement get API keys return []map[string]interface{}{}, nil } -// GenerateUserAPIKey generate API key for user -func (s *Service) GenerateUserAPIKey(username string) (map[string]interface{}, error) { +// GenerateUserAPIToken generate API key for user +func (s *Service) GenerateUserAPIToken(username string) (map[string]interface{}, error) { // TODO: Implement generate API key return map[string]interface{}{}, nil } -// DeleteUserAPIKey delete user API key -func (s *Service) DeleteUserAPIKey(username, key string) error { +// DeleteUserAPIToken delete user API key +func (s *Service) DeleteUserAPIToken(username, key string) error { // TODO: Implement delete API key return nil } diff --git a/internal/cli/admin_command.go b/internal/cli/admin_command.go new file mode 100644 index 000000000..ed6dac458 --- /dev/null +++ b/internal/cli/admin_command.go @@ -0,0 +1,1101 @@ +package cli + +import ( + "encoding/json" + "fmt" + "net/url" +) + +// PingServer pings the server to check if it's alive +// Returns benchmark result map if iterations > 1, otherwise prints status +func (c *RAGFlowClient) PingAdmin(cmd *Command) (ResponseIf, error) { + // Get iterations from command params (for benchmark) + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode: multiple iterations + return c.HTTPClient.RequestWithIterations("GET", "/admin/ping", false, "web", nil, nil, iterations) + } + + // Single mode + resp, err := c.HTTPClient.Request("GET", "/admin/ping", true, "web", nil, nil) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Server is down") + return nil, err + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to ping: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list users failed: invalid JSON (%w)", err) + } + result.Duration = resp.Duration + return &result, nil +} + +// Show admin version to show RAGFlow admin version +// Returns benchmark result map if iterations > 1, otherwise prints status +func (c *RAGFlowClient) ShowAdminVersion(cmd *Command) (ResponseIf, error) { + // Get iterations from command params (for benchmark) + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode: multiple iterations + return c.HTTPClient.RequestWithIterations("GET", "/admin/version", false, "web", nil, nil, iterations) + } + + // Single mode + resp, err := c.HTTPClient.Request("GET", "/admin/version", true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show admin version: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show admin version: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show admin version failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// ListRoles to list roles (admin mode only) +func (c *RAGFlowClient) ListRoles(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", "/admin/roles", true, "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", "/admin/roles", true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list roles: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list roles: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list roles failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + for _, user := range result.Data { + delete(user, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// ShowRole to show role (admin mode only) +func (c *RAGFlowClient) ShowRole(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + roleName := cmd.Params["role_name"].(string) + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + endPoint := fmt.Sprintf("/admin/roles/%s/", roleName) + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", endPoint, true, "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", endPoint, true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show role: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show role: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show role failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +// CreateRole creates a new role (admin mode only) +func (c *RAGFlowClient) CreateRole(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + roleName, ok := cmd.Params["role_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + description, ok := cmd.Params["description"].(string) + payload := map[string]interface{}{ + "role_name": roleName, + } + if ok { + payload["description"] = description + } + + resp, err := c.HTTPClient.Request("POST", "/admin/roles", true, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to create role: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to create role: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("create role failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// DropRole deletes the role (admin mode only) +func (c *RAGFlowClient) DropRole(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + roleName, ok := cmd.Params["role_name"].(string) + if !ok { + return nil, fmt.Errorf("role_name not provided") + } + + resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/admin/roles/%s", roleName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to drop role: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to drop role: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("drop role failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// AlterRole alters the role rights (admin mode only) +func (c *RAGFlowClient) AlterRole(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + roleName, ok := cmd.Params["role_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + description, ok := cmd.Params["description"].(string) + payload := map[string]interface{}{ + "role_name": roleName, + } + if ok { + payload["description"] = description + } + + resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/roles/%s", roleName), true, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to alter role: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to alter role: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("alter role failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// GrantAdmin grants admin privileges to a user (admin mode only) +func (c *RAGFlowClient) GrantAdmin(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/users/%s/admin", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to grant admin: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to grant admin: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("grant admin failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// RevokeAdmin revokes admin privileges from a user (admin mode only) +func (c *RAGFlowClient) RevokeAdmin(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/admin/users/%s/admin", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to revoke admin: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to revoke admin: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("revoke admin failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// CreateUser creates a new user (admin mode only) +func (c *RAGFlowClient) CreateUser(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + password, ok := cmd.Params["password"].(string) + if !ok { + return nil, fmt.Errorf("password not provided") + } + + // Encrypt password using RSA + encryptedPassword, err := EncryptPassword(password) + if err != nil { + return nil, fmt.Errorf("failed to encrypt password: %w", err) + } + + payload := map[string]interface{}{ + "username": userName, + "password": encryptedPassword, + "role": "user", + } + + resp, err := c.HTTPClient.Request("POST", "/admin/users", true, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to create user: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to create user: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("create user failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// ActivateUser activates or deactivates a user (admin mode only) +func (c *RAGFlowClient) ActivateUser(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + activateStatus, ok := cmd.Params["activate_status"].(string) + if !ok { + return nil, fmt.Errorf("activate_status not provided") + } + + // Validate activate_status + if activateStatus != "on" && activateStatus != "off" { + return nil, fmt.Errorf("activate_status must be 'on' or 'off'") + } + + payload := map[string]interface{}{ + "activate_status": activateStatus, + } + + resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/users/%s/activate", userName), true, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to update user status: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to update user status: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("update user status failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// AlterUserPassword changes a user's password (admin mode only) +func (c *RAGFlowClient) AlterUserPassword(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + password, ok := cmd.Params["password"].(string) + if !ok { + return nil, fmt.Errorf("password not provided") + } + + // Encrypt password using RSA + encryptedPassword, err := EncryptPassword(password) + if err != nil { + return nil, fmt.Errorf("failed to encrypt password: %w", err) + } + + payload := map[string]interface{}{ + "new_password": encryptedPassword, + } + + resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/users/%s/password", userName), true, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to change user password: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to change user password: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("change user password failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +type listServicesResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` +} + +// ListServices lists all services (admin mode only) +func (c *RAGFlowClient) ListServices(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", "/admin/services", true, "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", "/admin/services", true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list services: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list services: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list users failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + for _, user := range result.Data { + delete(user, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// Show service show service (admin mode only) +func (c *RAGFlowClient) ShowService(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + serviceIndex := cmd.Params["number"].(int) + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + endPoint := fmt.Sprintf("/admin/services/%d", serviceIndex) + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", endPoint, true, "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", endPoint, true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show service: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show service: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show service failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +// ListUsers lists all users (admin mode only) +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) ListUsers(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", "/admin/users", true, "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", "/admin/users", true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list users: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list users: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list users failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + for _, user := range result.Data { + delete(user, "create_date") + } + + result.Duration = resp.Duration + return &result, nil +} + +// DropUser deletes a user (admin mode only) +func (c *RAGFlowClient) DropUser(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/admin/users/%s", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to drop user: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to drop user: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("drop user failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// Show user show user (admin mode only) +func (c *RAGFlowClient) ShowUser(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show user: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show user: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonDataResponse + + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show user failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// ListDatasets lists datasets for a specific user (admin mode) +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) ListDatasets(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", fmt.Sprintf("/admin/users/%s/datasets", userName), true, "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/datasets", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list datasets: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + resJSON, err := resp.JSON() + if err != nil { + return nil, fmt.Errorf("invalid JSON response: %w", err) + } + + data, ok := resJSON["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format") + } + + // Convert to slice of maps and remove avatar + tableData := make([]map[string]interface{}, 0, len(data)) + for _, item := range data { + if itemMap, ok := item.(map[string]interface{}); ok { + delete(itemMap, "avatar") + tableData = append(tableData, itemMap) + } + } + + PrintTableSimple(tableData) + return nil, nil +} + +// ListAgents lists agents for a specific user (admin mode) +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) ListAgents(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", fmt.Sprintf("/admin/users/%s/agents", userName), true, "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/agents", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list agents: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list agents: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + resJSON, err := resp.JSON() + if err != nil { + return nil, fmt.Errorf("invalid JSON response: %w", err) + } + + data, ok := resJSON["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format") + } + + // Convert to slice of maps and remove avatar + tableData := make([]map[string]interface{}, 0, len(data)) + for _, item := range data { + if itemMap, ok := item.(map[string]interface{}); ok { + delete(itemMap, "avatar") + tableData = append(tableData, itemMap) + } + } + + PrintTableSimple(tableData) + return nil, nil +} + +// GrantPermission grants permission to a role (admin mode only) +func (c *RAGFlowClient) GrantPermission(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/keys", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list tokens: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list tokens: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list tokens failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + // Remove extra field from data + for _, item := range result.Data { + delete(item, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// RevokePermission revokes permission from a role (admin mode only) +func (c *RAGFlowClient) RevokePermission(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + roleName, ok := cmd.Params["role_name"].(string) + if !ok { + return nil, fmt.Errorf("role_name not provided") + } + + resource, ok := cmd.Params["resource"].(string) + if !ok { + return nil, fmt.Errorf("resource not provided") + } + + actionsRaw, ok := cmd.Params["actions"].([]interface{}) + if !ok { + return nil, fmt.Errorf("actions not provided") + } + + actions := make([]string, 0, len(actionsRaw)) + for _, action := range actionsRaw { + if actionStr, ok := action.(string); ok { + actions = append(actions, actionStr) + } + } + + payload := map[string]interface{}{ + "resource": resource, + "actions": actions, + } + + resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/admin/roles/%s/permission", roleName), true, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to revoke permission: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to revoke permission: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("revoke permission failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + // Remove extra field from data + for _, item := range result.Data { + delete(item, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// AlterUserRole alters user's role (admin mode only) +func (c *RAGFlowClient) AlterUserRole(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + roleName, ok := cmd.Params["role_name"].(string) + if !ok { + return nil, fmt.Errorf("role_name not provided") + } + + payload := map[string]interface{}{ + "role_name": roleName, + } + + resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/users/%s/role", userName), true, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to alter user role: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to alter user role: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("alter user role failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + // Remove extra field from data + for _, item := range result.Data { + delete(item, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// ShowUserPermission shows user's permissions (admin mode only) +func (c *RAGFlowClient) ShowUserPermission(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/permission", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show user permission: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show user permission: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show user permission failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + // Remove extra field from data + for _, item := range result.Data { + delete(item, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// CreateAdminToken generates an API token for a user (admin mode only) +func (c *RAGFlowClient) CreateAdminToken(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("POST", fmt.Sprintf("/admin/users/%s/keys", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to generate token: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to generate token: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("generate token failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + // Remove extra field from data + for _, item := range result.Data { + delete(item, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// ListAdminTokens lists all API tokens for a user (admin mode only) +func (c *RAGFlowClient) ListAdminTokens(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/keys", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list tokens: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list tokens: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list tokens failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + // Remove extra field from data + for _, item := range result.Data { + delete(item, "extra") + } + + result.Duration = resp.Duration + return &result, nil +} + +// DropToken drops an API token for a user (admin mode only) +func (c *RAGFlowClient) DropAdminToken(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + token, ok := cmd.Params["token"].(string) + if !ok { + return nil, fmt.Errorf("token not provided") + } + + // URL encode the token to handle special characters + encodedToken := url.QueryEscape(token) + + resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/admin/users/%s/keys/%s", userName, encodedToken), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to drop token: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to drop token: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("drop token failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} diff --git a/internal/cli/benchmark.go b/internal/cli/benchmark.go index 872c830e3..ab4d025c3 100644 --- a/internal/cli/benchmark.go +++ b/internal/cli/benchmark.go @@ -34,7 +34,7 @@ type BenchmarkResult struct { } // RunBenchmark runs a benchmark with the given concurrency and iterations -func (c *RAGFlowClient) RunBenchmark(cmd *Command) error { +func (c *RAGFlowClient) RunBenchmark(cmd *Command) (ResponseIf, error) { concurrency, ok := cmd.Params["concurrency"].(int) if !ok { concurrency = 1 @@ -47,29 +47,26 @@ func (c *RAGFlowClient) RunBenchmark(cmd *Command) error { nestedCmd, ok := cmd.Params["command"].(*Command) if !ok { - return fmt.Errorf("benchmark command not found") + return nil, fmt.Errorf("benchmark command not found") } if concurrency < 1 { - return fmt.Errorf("concurrency must be greater than 0") + return nil, fmt.Errorf("concurrency must be greater than 0") } // Add iterations to the nested command nestedCmd.Params["iterations"] = iterations if concurrency == 1 { - return c.runBenchmarkSingle(concurrency, iterations, nestedCmd) + return c.runBenchmarkSingle(iterations, nestedCmd) } return c.runBenchmarkConcurrent(concurrency, iterations, nestedCmd) } // runBenchmarkSingle runs benchmark with single concurrency (sequential execution) -func (c *RAGFlowClient) runBenchmarkSingle(concurrency, iterations int, nestedCmd *Command) error { +func (c *RAGFlowClient) runBenchmarkSingle(iterations int, nestedCmd *Command) (*BenchmarkResponse, error) { commandType := nestedCmd.Type - startTime := time.Now() - responseList := make([]*Response, 0, iterations) - // For search_on_datasets, convert dataset names to IDs first if commandType == "search_on_datasets" && iterations > 1 { datasets, _ := nestedCmd.Params["datasets"].(string) @@ -79,7 +76,7 @@ func (c *RAGFlowClient) runBenchmarkSingle(concurrency, iterations int, nestedCm name = strings.TrimSpace(name) id, err := c.getDatasetID(name) if err != nil { - return err + return nil, err } datasetIDs = append(datasetIDs, id) } @@ -87,86 +84,67 @@ func (c *RAGFlowClient) runBenchmarkSingle(concurrency, iterations int, nestedCm } // Check if command supports native benchmark (iterations > 1) - supportsNative := false if iterations > 1 { result, err := c.ExecuteCommand(nestedCmd) - if err == nil && result != nil { - // Command supports benchmark natively - supportsNative = true - duration, _ := result["duration"].(float64) - respList, _ := result["response_list"].([]*Response) - responseList = respList + // convert result to BenchmarkResponse + benchmarkResponse := result.(*BenchmarkResponse) + benchmarkResponse.Concurrency = 1 + return benchmarkResponse, err + } - // Calculate and print results - successCount := 0 - for _, resp := range responseList { - if isSuccess(resp, commandType) { - successCount++ - } - } + result, err := c.ExecuteCommand(nestedCmd) + if err != nil { + fmt.Printf("fail to execute: %s", commandType) + return nil, err + } - qps := float64(0) - if duration > 0 { - qps = float64(iterations) / duration - } - - fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations) - fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n", - duration, qps, iterations, successCount, iterations-successCount) - return nil + var benchmarkResponse BenchmarkResponse + switch result.Type() { + case "common": + commonResponse := result.(*CommonResponse) + benchmarkResponse.Code = commonResponse.Code + benchmarkResponse.Duration = commonResponse.Duration + if commonResponse.Code == 0 { + benchmarkResponse.SuccessCount = 1 + } else { + benchmarkResponse.FailureCount = 1 } - } - - // Manual execution: run iterations times - if !supportsNative { - // Remove iterations param to avoid native benchmark - delete(nestedCmd.Params, "iterations") - - for i := 0; i < iterations; i++ { - singleResult, err := c.ExecuteCommand(nestedCmd) - if err != nil { - // Command failed, add a failed response - responseList = append(responseList, &Response{StatusCode: 0}) - continue - } - - // For commands that return a single response (like ping with iterations=1) - if singleResult != nil { - if respList, ok := singleResult["response_list"].([]*Response); ok { - responseList = append(responseList, respList...) - } - } else { - // Command executed successfully but returned no data - // Mark as success for now - responseList = append(responseList, &Response{StatusCode: 200, Body: []byte("pong")}) - } + case "simple": + simpleResponse := result.(*SimpleResponse) + benchmarkResponse.Code = simpleResponse.Code + benchmarkResponse.Duration = simpleResponse.Duration + if simpleResponse.Code == 0 { + benchmarkResponse.SuccessCount = 1 + } else { + benchmarkResponse.FailureCount = 1 } - } - - duration := time.Since(startTime).Seconds() - - successCount := 0 - for _, resp := range responseList { - if isSuccess(resp, commandType) { - successCount++ + case "show": + dataResponse := result.(*CommonDataResponse) + benchmarkResponse.Code = dataResponse.Code + benchmarkResponse.Duration = dataResponse.Duration + if dataResponse.Code == 0 { + benchmarkResponse.SuccessCount = 1 + } else { + benchmarkResponse.FailureCount = 1 } + case "data": + kvResponse := result.(*KeyValueResponse) + benchmarkResponse.Code = kvResponse.Code + benchmarkResponse.Duration = kvResponse.Duration + if kvResponse.Code == 0 { + benchmarkResponse.SuccessCount = 1 + } else { + benchmarkResponse.FailureCount = 1 + } + default: + return nil, fmt.Errorf("unsupported command type: %s", result.Type()) } - - qps := float64(0) - if duration > 0 { - qps = float64(iterations) / duration - } - - // Print results - fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations) - fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n", - duration, qps, iterations, successCount, iterations-successCount) - - return nil + benchmarkResponse.Concurrency = 1 + return &benchmarkResponse, nil } // runBenchmarkConcurrent runs benchmark with multiple concurrent workers -func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nestedCmd *Command) error { +func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nestedCmd *Command) (*BenchmarkResponse, error) { results := make([]map[string]interface{}, concurrency) var wg sync.WaitGroup @@ -179,7 +157,7 @@ func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nest name = strings.TrimSpace(name) id, err := c.getDatasetID(name) if err != nil { - return err + return nil, err } datasetIDs = append(datasetIDs, id) } @@ -228,17 +206,15 @@ func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nest } totalCommands := iterations * concurrency - qps := float64(0) - if totalDuration > 0 { - qps = float64(totalCommands) / totalDuration - } - // Print results - fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations) - fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n", - totalDuration, qps, totalCommands, successCount, totalCommands-successCount) + var benchmarkResponse BenchmarkResponse + benchmarkResponse.Duration = totalDuration + benchmarkResponse.Code = 0 + benchmarkResponse.SuccessCount = successCount + benchmarkResponse.FailureCount = totalCommands - successCount + benchmarkResponse.Concurrency = concurrency - return nil + return &benchmarkResponse, nil } // executeBenchmarkSilent executes a command for benchmark without printing output @@ -250,7 +226,7 @@ func (c *RAGFlowClient) executeBenchmarkSilent(cmd *Command, iterations int) []* var err error switch cmd.Type { - case "ping_server": + case "ping": resp, err = c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil) case "list_user_datasets": resp, err = c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil) @@ -290,7 +266,7 @@ func isSuccess(resp *Response, commandType string) bool { } switch commandType { - case "ping_server": + case "ping": return resp.StatusCode == 200 && string(resp.Body) == "pong" case "list_user_datasets", "list_datasets", "search_on_datasets": // Check status code and JSON response code for dataset commands diff --git a/internal/cli/cli.go b/internal/cli/cli.go index b74344815..fbedb8f0f 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -17,15 +17,285 @@ package cli import ( + "errors" + "flag" "fmt" "os" "os/signal" "strconv" "strings" "syscall" + "github.com/peterh/liner" + "golang.org/x/term" + "gopkg.in/yaml.v3" ) +// ConfigFile represents the rf.yml configuration file structure +type ConfigFile struct { + Host string `yaml:"host"` + User string `yaml:"user"` + Password string `yaml:"password"` + APIToken string `yaml:"api_token"` +} + +// ConnectionArgs holds the parsed command line arguments +type ConnectionArgs struct { + Host string + Port int + Password string + Key string + Type string + Username string + Command string + ShowHelp bool +} + +// LoadDefaultConfigFile reads the rf.yml file from current directory if it exists +func LoadDefaultConfigFile() (*ConfigFile, error) { + // Try to read rf.yml from current directory + data, err := os.ReadFile("rf.yml") + if err != nil { + // File doesn't exist, return nil without error + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + var config ConfigFile + if err = yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to parse rf.yml: %v", err) + } + + return &config, nil +} + +// LoadConfigFileFromPath reads a config file from the specified path +func LoadConfigFileFromPath(path string) (*ConfigFile, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read config file %s: %v", path, err) + } + + var config ConfigFile + if err = yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to parse config file %s: %v", path, err) + } + + return &config, nil +} + +// parseHostPort parses a host:port string and returns host and port +func parseHostPort(hostPort string) (string, int, error) { + if hostPort == "" { + return "", -1, nil + } + + // Split host and port + parts := strings.Split(hostPort, ":") + if len(parts) != 2 { + return "", -1, fmt.Errorf("invalid host format, expected host:port, got: %s", hostPort) + } + + host := parts[0] + port, err := strconv.Atoi(parts[1]) + if err != nil { + return "", -1, fmt.Errorf("invalid port number: %s", parts[1]) + } + + return host, port, nil +} + +// ParseConnectionArgs parses command line arguments similar to Python's parse_connection_args +func ParseConnectionArgs(args []string) (*ConnectionArgs, error) { + // First, scan args to check for help and config file + var configFilePath string + var hasOtherFlags bool + + for i := 0; i < len(args); i++ { + arg := args[i] + if arg == "--help" { + return &ConnectionArgs{ShowHelp: true}, nil + } else if arg == "-f" && i+1 < len(args) { + configFilePath = args[i+1] + i++ + } else if strings.HasPrefix(arg, "-") && arg != "-f" { + // Check if it's a flag (not a command) + if arg == "-h" || arg == "-p" || arg == "-w" || arg == "-k" || + arg == "-u" || arg == "-admin" || arg == "-user" { + hasOtherFlags = true + } + } + } + + // Load config file with priority: -f > rf.yml > none + var config *ConfigFile + var err error + + if configFilePath != "" { + // User specified config file via -f + config, err = LoadConfigFileFromPath(configFilePath) + if err != nil { + return nil, err + } + } else { + // Try default rf.yml + config, err = LoadDefaultConfigFile() + if err != nil { + return nil, err + } + } + + if config != nil { + if hasOtherFlags { + return nil, fmt.Errorf("cannot use command line flags (-h, -p, -w, -k, -u, -admin, -user) when using config file. Please use config file or command line flags, not both") + } + + return buildArgsFromConfig(config, args) + } + // Create a new flag set + fs := flag.NewFlagSet("ragflow_cli", flag.ContinueOnError) + + // Define flags + host := fs.String("h", "127.0.0.1", "Admin or RAGFlow service host") + port := fs.Int("p", -1, "Admin or RAGFlow service port (default: 9381 for admin, 9380 for user)") + password := fs.String("w", "", "Superuser password") + key := fs.String("k", "", "API key for authentication") + _ = fs.String("f", "", "Path to config file (YAML format)") // Already parsed above + _ = fs.Bool("admin", false, "Run in admin mode (default)") + userMode := fs.Bool("user", false, "Run in user mode") + username := fs.String("u", "", "Username (email). In admin mode defaults to admin@ragflow.io, in user mode required") + + // Parse the arguments + if err = fs.Parse(args); err != nil { + return nil, fmt.Errorf("failed to parse arguments: %v", err) + } + + // Otherwise, use command line flags + return buildArgsFromFlags(host, port, password, key, userMode, username, fs.Args()) +} + +// buildArgsFromConfig builds ConnectionArgs from config file +func buildArgsFromConfig(config *ConfigFile, remainingArgs []string) (*ConnectionArgs, error) { + result := &ConnectionArgs{} + + // Parse host:port from config file + if config.Host != "" { + host, port, err := parseHostPort(config.Host) + if err != nil { + return nil, fmt.Errorf("invalid host in config file: %v", err) + } + result.Host = host + result.Port = port + } else { + result.Host = "127.0.0.1" + } + + // Apply auth info from config + result.Username = config.User + result.Password = config.Password + result.Key = config.APIToken + + // Determine mode: if config has auth info, use user mode + if config.User != "" || config.APIToken != "" { + result.Type = "user" + } else { + result.Type = "admin" + result.Username = "admin@ragflow.io" + } + + // Set default port if not specified in config + if result.Port == -1 { + if result.Type == "admin" { + result.Port = 9381 + } else { + result.Port = 9380 + } + } + + // Get command from remaining args (no need for quotes or semicolon) + if len(remainingArgs) > 0 { + result.Command = strings.Join(remainingArgs, " ") + ";" + } + + return result, nil +} + +// buildArgsFromFlags builds ConnectionArgs from command line flags +func buildArgsFromFlags(host *string, port *int, password *string, key *string, userMode *bool, username *string, remainingArgs []string) (*ConnectionArgs, error) { + result := &ConnectionArgs{ + Host: *host, + Port: *port, + Password: *password, + Key: *key, + Username: *username, + } + + // Determine mode + if *userMode { + result.Type = "user" + } else { + result.Type = "admin" + } + + // Set default port based on type if not specified + if result.Port == -1 { + if result.Type == "admin" { + result.Port = 9383 + } else { + result.Port = 9384 + } + } + + // Determine username based on mode + if result.Type == "admin" && result.Username == "" { + result.Username = "admin@ragflow.io" + } + + // Get command from remaining args (no need for quotes or semicolon) + if len(remainingArgs) > 0 { + result.Command = strings.Join(remainingArgs, " ") + ";" + } + + return result, nil +} + +// PrintUsage prints the CLI usage information +func PrintUsage() { + fmt.Println(`RAGFlow CLI Client + +Usage: ragflow_cli [options] [command] + +Options: + -h string Admin or RAGFlow service host (default "127.0.0.1") + -p int Admin or RAGFlow service port (default 9381 for admin, 9380 for user) + -w string Superuser password + -k string API key for authentication + -f string Path to config file (YAML format) + -admin Run in admin mode (default) + -user Run in user mode + -u string Username (email). In admin mode defaults to admin@ragflow.io + --help Show this help message + +Configuration File: + The CLI will automatically read rf.yml from the current directory if it exists. + Use -f to specify a custom config file path. + + Config file format: + host: 127.0.0.1:9380 + user: your-email@example.com + password: your-password + api_token: your-api-token + + When using a config file, you cannot use other command line flags except -help. + The command line is only for the SQL command. + +Commands: + SQL commands like: LOGIN USER 'email'; LIST USERS; etc. +`) +} + // HistoryFile returns the path to the history file func HistoryFile() string { return os.Getenv("HOME") + "/" + historyFileName @@ -39,26 +309,95 @@ type CLI struct { prompt string running bool line *liner.State + args *ConnectionArgs } // NewCLI creates a new CLI instance func NewCLI() (*CLI, error) { + return NewCLIWithArgs(nil) +} + +// NewCLIWithArgs creates a new CLI instance with connection arguments +func NewCLIWithArgs(args *ConnectionArgs) (*CLI, error) { // Create liner first line := liner.NewLiner() + // Determine server type + serverType := "user" + if args != nil && args.Type != "" { + serverType = args.Type + } + // Create client with password prompt using liner - client := NewRAGFlowClient("user") // Default to user mode + client := NewRAGFlowClient(serverType) client.PasswordPrompt = line.PasswordPrompt + // Apply connection arguments if provided + client.HTTPClient.Host = args.Host + client.HTTPClient.Port = args.Port + + // Set prompt based on server type + prompt := "RAGFlow> " + if serverType == "admin" { + prompt = "RAGFlow(admin)> " + } + return &CLI{ - prompt: "RAGFlow> ", + prompt: prompt, client: client, line: line, + args: args, }, nil } // Run starts the interactive CLI func (c *CLI) Run() error { + if c.args.Type == "admin" { + // Allow 3 attempts for password verification + maxAttempts := 3 + for attempt := 1; attempt <= maxAttempts; attempt++ { + var input string + var err error + + // Check if terminal supports password masking + if term.IsTerminal(int(os.Stdin.Fd())) { + input, err = c.line.PasswordPrompt("Please input your password: ") + } else { + // Terminal doesn't support password masking, use regular prompt + fmt.Println("Warning: This terminal does not support secure password input") + input, err = c.line.Prompt("Please input your password (will be visible): ") + } + if err != nil { + fmt.Printf("Error reading input: %v\n", err) + return err + } + + input = strings.TrimSpace(input) + + if input == "" { + if attempt < maxAttempts { + fmt.Println("Password cannot be empty, please try again") + continue + } + return errors.New("no password provided after 3 attempts") + } + + // Set the password for verification + c.args.Password = input + + if err = c.VerifyAuth(); err != nil { + if attempt < maxAttempts { + fmt.Printf("Authentication failed: %v (%d/%d attempts)\n", err, attempt, maxAttempts) + continue + } + return fmt.Errorf("authentication failed after %d attempts: %v", maxAttempts, err) + } + + // Authentication successful + break + } + } + c.running = true // Load history from file @@ -99,8 +438,8 @@ func (c *CLI) Run() error { c.line.AppendHistory(input) } - if err := c.execute(input); err != nil { - fmt.Printf("Error: %v\n", err) + if err = c.execute(input); err != nil { + fmt.Printf("CLI error: %v\n", err) } } @@ -124,7 +463,11 @@ func (c *CLI) execute(input string) error { } // Execute the command using the client - _, err = c.client.ExecuteCommand(cmd) + var result ResponseIf + result, err = c.client.ExecuteCommand(cmd) + if result != nil { + result.PrintOut() + } return err } @@ -143,14 +486,12 @@ func (c *CLI) handleMetaCommand(cmd *Command) error { fmt.Print("\033[H\033[2J") case "admin": c.client.ServerType = "admin" - c.client.HTTPClient.Port = 9381 c.prompt = "RAGFlow(admin)> " - fmt.Println("Switched to ADMIN mode (port 9381)") + fmt.Println("Switched to ADMIN mode") case "user": c.client.ServerType = "user" - c.client.HTTPClient.Port = 9380 c.prompt = "RAGFlow> " - fmt.Println("Switched to USER mode (port 9380)") + fmt.Println("Switched to USER mode") case "host": if len(args) == 0 { fmt.Printf("Current host: %s\n", c.client.HTTPClient.Host) @@ -195,7 +536,7 @@ Meta Commands: \q or \quit - Exit CLI \c or \clear - Clear screen -SQL Commands (User Mode): +Commands (User Mode): LOGIN USER 'email'; - Login as user REGISTER USER 'name' AS 'nickname' PASSWORD 'pwd'; - Register new user SHOW VERSION; - Show version info @@ -205,9 +546,14 @@ SQL Commands (User Mode): LIST CHATS; - List user chats LIST MODEL PROVIDERS; - List model providers LIST DEFAULT MODELS; - List default models + LIST TOKENS; - List API tokens + CREATE TOKEN; - Create new API token + DROP TOKEN 'token_value'; - Delete an API token + SET TOKEN 'token_value'; - Set and validate API token + SHOW TOKEN; - Show current API token + UNSET TOKEN; - Remove current API token -SQL Commands (Admin Mode): - LOGIN USER 'email'; - Login as admin +Commands (Admin Mode): LIST USERS; - List all users SHOW USER 'email'; - Show user details CREATE USER 'email' 'password'; - Create new user @@ -254,3 +600,34 @@ func RunInteractive() error { return cli.Run() } + +// RunSingleCommand executes a single command and exits +func (c *CLI) RunSingleCommand(command string) error { + // Execute the command + if err := c.execute(command); err != nil { + return err + } + return nil +} + +// VerifyAuth verifies authentication if needed +func (c *CLI) VerifyAuth() error { + if c.args == nil { + return nil + } + + if c.args.Username == "" { + return fmt.Errorf("username is required") + } + + if c.args.Password == "" { + return fmt.Errorf("password is required") + } + + // Create login command with username and password + cmd := NewCommand("login_user") + cmd.Params["email"] = c.args.Username + cmd.Params["password"] = c.args.Password + _, err := c.client.ExecuteCommand(cmd) + return err +} diff --git a/internal/cli/client.go b/internal/cli/client.go index af5457d09..4801feacd 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -18,6 +18,7 @@ package cli import ( "bufio" + "encoding/json" "fmt" "os" "os/exec" @@ -83,7 +84,7 @@ func (c *RAGFlowClient) LoginUserInteractive(username, password string) error { resJSON, err := resp.JSON() if err == nil { // Admin mode returns {"code":0,"message":"PONG"} - if msg, ok := resJSON["message"].(string); !ok || msg != "PONG" { + if msg, ok := resJSON["message"].(string); !ok || msg != "pong" { fmt.Println("Server is down") return fmt.Errorf("server is down") } @@ -150,7 +151,7 @@ func (c *RAGFlowClient) LoginUser(cmd *Command) error { resJSON, err := resp.JSON() if err == nil { // Admin mode returns {"code":0,"message":"PONG"} - if msg, ok := resJSON["message"].(string); !ok || msg != "PONG" { + if msg, ok := resJSON["message"].(string); !ok || msg != "pong" { fmt.Println("Server is down") return fmt.Errorf("server is down") } @@ -167,23 +168,16 @@ func (c *RAGFlowClient) LoginUser(cmd *Command) error { return fmt.Errorf("email not provided") } - // Get password from user input (hidden) - var password string - if c.PasswordPrompt != nil { - pwd, err := c.PasswordPrompt(fmt.Sprintf("password for %s: ", email)) - if err != nil { - return fmt.Errorf("failed to read password: %w", err) - } - password = pwd - } else { + password, ok := cmd.Params["password"].(string) + if !ok { + // Get password from user input (hidden) fmt.Printf("password for %s: ", email) - pwd, err := readPassword() + password, err = readPassword() if err != nil { return fmt.Errorf("failed to read password: %w", err) } - password = pwd + password = strings.TrimSpace(password) } - password = strings.TrimSpace(password) // Login token, err := c.loginUser(email, password) @@ -223,15 +217,13 @@ func (c *RAGFlowClient) loginUser(email, password string) (string, error) { return "", err } - resJSON, err := resp.JSON() - if err != nil { - return "", fmt.Errorf("login failed: invalid JSON response (%w)", err) + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return "", fmt.Errorf("login failed: invalid JSON (%w)", err) } - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return "", fmt.Errorf("login failed: %s", msg) + if result.Code != 0 { + return "", fmt.Errorf("login failed: %s", result.Message) } token := resp.Headers.Get("Authorization") @@ -242,160 +234,6 @@ func (c *RAGFlowClient) loginUser(email, password string) (string, error) { return token, nil } -// PingServer pings the server to check if it's alive -// Returns benchmark result map if iterations > 1, otherwise prints status -func (c *RAGFlowClient) PingServer(cmd *Command) (map[string]interface{}, error) { - // Get iterations from command params (for benchmark) - iterations := 1 - if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { - iterations = val - } - - if iterations > 1 { - // Benchmark mode: multiple iterations - result, err := c.HTTPClient.RequestWithIterations("GET", "/system/ping", false, "web", nil, nil, iterations) - if err != nil { - return nil, err - } - return result, nil - } - - // Single ping mode - resp, err := c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil) - if err != nil { - fmt.Printf("Error: %v\n", err) - fmt.Println("Server is down") - return nil, err - } - - if resp.StatusCode == 200 && string(resp.Body) == "pong" { - fmt.Println("Server is alive") - } else { - fmt.Printf("Error: %d\n", resp.StatusCode) - } - return nil, nil -} - -// ListUserDatasets lists datasets for current user (user mode) -// Returns (result_map, error) - result_map is non-nil for benchmark mode -func (c *RAGFlowClient) ListUserDatasets(cmd *Command) (map[string]interface{}, error) { - if c.ServerType != "user" { - return nil, fmt.Errorf("this command is only allowed in USER mode") - } - - // Check for benchmark iterations - iterations := 1 - if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { - iterations = val - } - - if iterations > 1 { - // Benchmark mode - return raw result for benchmark stats - return c.HTTPClient.RequestWithIterations("POST", "/kb/list", false, "web", nil, nil, iterations) - } - - // Normal mode - resp, err := c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil) - if err != nil { - return nil, fmt.Errorf("failed to list datasets: %w", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return nil, fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return nil, fmt.Errorf("failed to list datasets: %s", msg) - } - - data, ok := resJSON["data"].(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid response format") - } - - kbs, ok := data["kbs"].([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid response format: kbs not found") - } - - // Convert to slice of maps - tableData := make([]map[string]interface{}, 0, len(kbs)) - for _, kb := range kbs { - if kbMap, ok := kb.(map[string]interface{}); ok { - // Remove avatar field - delete(kbMap, "avatar") - tableData = append(tableData, kbMap) - } - } - - PrintTableSimple(tableData) - return nil, nil -} - -// ListDatasets lists datasets for a specific user (admin mode) -// Returns (result_map, error) - result_map is non-nil for benchmark mode -func (c *RAGFlowClient) ListDatasets(cmd *Command) (map[string]interface{}, error) { - if c.ServerType != "admin" { - return nil, fmt.Errorf("this command is only allowed in ADMIN mode") - } - - userName, ok := cmd.Params["user_name"].(string) - if !ok { - return nil, fmt.Errorf("user_name not provided") - } - - // Check for benchmark iterations - iterations := 1 - if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { - iterations = val - } - - if iterations > 1 { - // Benchmark mode - return raw result for benchmark stats - return c.HTTPClient.RequestWithIterations("GET", fmt.Sprintf("/admin/users/%s/datasets", userName), true, "admin", nil, nil, iterations) - } - - fmt.Printf("Listing all datasets of user: %s\n", userName) - - resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/datasets", userName), true, "admin", nil, nil) - if err != nil { - return nil, fmt.Errorf("failed to list datasets: %w", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return nil, fmt.Errorf("invalid JSON response: %w", err) - } - - data, ok := resJSON["data"].([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid response format") - } - - // Convert to slice of maps and remove avatar - tableData := make([]map[string]interface{}, 0, len(data)) - for _, item := range data { - if itemMap, ok := item.(map[string]interface{}); ok { - delete(itemMap, "avatar") - tableData = append(tableData, itemMap) - } - } - - PrintTableSimple(tableData) - return nil, nil -} - // readPassword reads password from terminal without echoing func readPassword() (string, error) { // Check if stdin is a terminal by trying to get terminal size @@ -444,354 +282,100 @@ func readPasswordFallback() (string, error) { return strings.TrimSpace(password), nil } -// getDatasetID gets dataset ID by name -func (c *RAGFlowClient) getDatasetID(datasetName string) (string, error) { - resp, err := c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil) - if err != nil { - return "", fmt.Errorf("failed to list datasets: %w", err) - } - - if resp.StatusCode != 200 { - return "", fmt.Errorf("failed to list datasets: HTTP %d", resp.StatusCode) - } - - resJSON, err := resp.JSON() - if err != nil { - return "", fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return "", fmt.Errorf("failed to list datasets: %s", msg) - } - - data, ok := resJSON["data"].(map[string]interface{}) - if !ok { - return "", fmt.Errorf("invalid response format") - } - - kbs, ok := data["kbs"].([]interface{}) - if !ok { - return "", fmt.Errorf("invalid response format: kbs not found") - } - - for _, kb := range kbs { - if kbMap, ok := kb.(map[string]interface{}); ok { - if name, _ := kbMap["name"].(string); name == datasetName { - if id, _ := kbMap["id"].(string); id != "" { - return id, nil - } - } - } - } - - return "", fmt.Errorf("dataset '%s' not found", datasetName) -} - -// formatEmptyArray converts empty arrays to "[]" string -func formatEmptyArray(v interface{}) string { - if v == nil { - return "[]" - } - switch val := v.(type) { - case []interface{}: - if len(val) == 0 { - return "[]" - } - case []string: - if len(val) == 0 { - return "[]" - } - case []int: - if len(val) == 0 { - return "[]" - } - } - return fmt.Sprintf("%v", v) -} - -// SearchOnDatasets searches for chunks in specified datasets -// Returns (result_map, error) - result_map is non-nil for benchmark mode -func (c *RAGFlowClient) SearchOnDatasets(cmd *Command) (map[string]interface{}, error) { - if c.ServerType != "user" { - return nil, fmt.Errorf("this command is only allowed in USER mode") - } - - question, ok := cmd.Params["question"].(string) - if !ok { - return nil, fmt.Errorf("question not provided") - } - - datasets, ok := cmd.Params["datasets"].(string) - if !ok { - return nil, fmt.Errorf("datasets not provided") - } - - // Parse dataset names (comma-separated) and convert to IDs - datasetNames := strings.Split(datasets, ",") - datasetIDs := make([]string, 0, len(datasetNames)) - for _, name := range datasetNames { - name = strings.TrimSpace(name) - id, err := c.getDatasetID(name) - if err != nil { - return nil, err - } - datasetIDs = append(datasetIDs, id) - } - - // Check for benchmark iterations - iterations := 1 - if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { - iterations = val - } - - payload := map[string]interface{}{ - "kb_id": datasetIDs, - "question": question, - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.3, - } - - if iterations > 1 { - // Benchmark mode - return raw result for benchmark stats - return c.HTTPClient.RequestWithIterations("POST", "/chunk/retrieval_test", false, "web", nil, payload, iterations) - } - - // Normal mode - resp, err := c.HTTPClient.Request("POST", "/chunk/retrieval_test", false, "web", nil, payload) - if err != nil { - return nil, fmt.Errorf("failed to search on datasets: %w", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("failed to search on datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return nil, fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return nil, fmt.Errorf("failed to search on datasets: %s", msg) - } - - data, ok := resJSON["data"].(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid response format") - } - - chunks, ok := data["chunks"].([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid response format: chunks not found") - } - - // Convert to slice of maps for printing - tableData := make([]map[string]interface{}, 0, len(chunks)) - for _, chunk := range chunks { - if chunkMap, ok := chunk.(map[string]interface{}); ok { - row := map[string]interface{}{ - "id": chunkMap["chunk_id"], - "content": chunkMap["content_with_weight"], - "document_id": chunkMap["doc_id"], - "dataset_id": chunkMap["kb_id"], - "docnm_kwd": chunkMap["docnm_kwd"], - "image_id": chunkMap["image_id"], - "similarity": chunkMap["similarity"], - "term_similarity": chunkMap["term_similarity"], - "vector_similarity": chunkMap["vector_similarity"], - } - // Add optional fields that may be empty arrays - if v, ok := chunkMap["doc_type_kwd"]; ok { - row["doc_type_kwd"] = formatEmptyArray(v) - } - if v, ok := chunkMap["important_kwd"]; ok { - row["important_kwd"] = formatEmptyArray(v) - } - if v, ok := chunkMap["mom_id"]; ok { - row["mom_id"] = formatEmptyArray(v) - } - if v, ok := chunkMap["positions"]; ok { - row["positions"] = formatEmptyArray(v) - } - if v, ok := chunkMap["content_ltks"]; ok { - row["content_ltks"] = v - } - tableData = append(tableData, row) - } - } - - PrintTableSimple(tableData) - return nil, nil -} - // ExecuteCommand executes a parsed command // Returns benchmark result map for commands that support it (e.g., ping_server with iterations > 1) -func (c *RAGFlowClient) ExecuteCommand(cmd *Command) (map[string]interface{}, error) { +func (c *RAGFlowClient) ExecuteCommand(cmd *Command) (ResponseIf, error) { + switch c.ServerType { + case "admin": + // Admin mode: execute command with admin privileges + return c.ExecuteAdminCommand(cmd) + case "user": + // User mode: execute command with user privileges + return c.ExecuteUserCommand(cmd) + default: + return nil, fmt.Errorf("invalid server type: %s", c.ServerType) + } +} + +func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) { switch cmd.Type { case "login_user": return nil, c.LoginUser(cmd) - case "ping_server": - return c.PingServer(cmd) + case "ping": + return c.PingAdmin(cmd) case "benchmark": - return nil, c.RunBenchmark(cmd) + return c.RunBenchmark(cmd) case "list_user_datasets": return c.ListUserDatasets(cmd) - case "list_datasets": - return c.ListDatasets(cmd) - case "search_on_datasets": - return c.SearchOnDatasets(cmd) case "list_users": return c.ListUsers(cmd) + case "list_services": + return c.ListServices(cmd) case "grant_admin": - return nil, c.GrantAdmin(cmd) + return c.GrantAdmin(cmd) case "revoke_admin": - return nil, c.RevokeAdmin(cmd) - case "show_current_user": - return c.ShowCurrentUser(cmd) + return c.RevokeAdmin(cmd) case "create_user": - return nil, c.CreateUser(cmd) + return c.CreateUser(cmd) case "activate_user": - return nil, c.ActivateUser(cmd) + return c.ActivateUser(cmd) case "alter_user": - return nil, c.AlterUserPassword(cmd) + return c.AlterUserPassword(cmd) case "drop_user": - return nil, c.DropUser(cmd) + return c.DropUser(cmd) + case "show_service": + return c.ShowService(cmd) + case "show_version": + return c.ShowAdminVersion(cmd) + case "show_user": + return c.ShowUser(cmd) + case "list_datasets": + return c.ListDatasets(cmd) + case "list_agents": + return c.ListAgents(cmd) + case "create_token": + return c.CreateAdminToken(cmd) + case "list_tokens": + return c.ListAdminTokens(cmd) + case "drop_token": + return c.DropAdminToken(cmd) // TODO: Implement other commands default: return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type) } } - -// ListUsers lists all users (admin mode only) -// Returns (result_map, error) - result_map is non-nil for benchmark mode -func (c *RAGFlowClient) ListUsers(cmd *Command) (map[string]interface{}, error) { - if c.ServerType != "admin" { - return nil, fmt.Errorf("this command is only allowed in ADMIN mode") +func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { + switch cmd.Type { + case "register_user": + return c.RegisterUser(cmd) + case "login_user": + return nil, c.LoginUser(cmd) + case "ping": + return c.PingServer(cmd) + case "benchmark": + return c.RunBenchmark(cmd) + case "list_user_datasets": + return c.ListUserDatasets(cmd) + case "search_on_datasets": + return c.SearchOnDatasets(cmd) + case "create_token": + return c.CreateToken(cmd) + case "list_tokens": + return c.ListTokens(cmd) + case "drop_token": + return c.DropToken(cmd) + case "set_token": + return c.SetToken(cmd) + case "show_token": + return c.ShowToken(cmd) + case "unset_token": + return c.UnsetToken(cmd) + case "show_version": + return c.ShowServerVersion(cmd) + // TODO: Implement other commands + default: + return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type) } - - // Check for benchmark iterations - iterations := 1 - if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { - iterations = val - } - - if iterations > 1 { - // Benchmark mode - return raw result for benchmark stats - return c.HTTPClient.RequestWithIterations("GET", "/admin/users", true, "admin", nil, nil, iterations) - } - - resp, err := c.HTTPClient.Request("GET", "/admin/users", true, "admin", nil, nil) - if err != nil { - return nil, fmt.Errorf("failed to list users: %w", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("failed to list users: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return nil, fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return nil, fmt.Errorf("failed to list users: %s", msg) - } - - data, ok := resJSON["data"].([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid response format") - } - - // Convert to slice of maps and remove sensitive fields - tableData := make([]map[string]interface{}, 0, len(data)) - for _, item := range data { - if itemMap, ok := item.(map[string]interface{}); ok { - // Remove sensitive fields - delete(itemMap, "password") - delete(itemMap, "access_token") - tableData = append(tableData, itemMap) - } - } - - PrintTableSimple(tableData) - return nil, nil -} - -// GrantAdmin grants admin privileges to a user (admin mode only) -func (c *RAGFlowClient) GrantAdmin(cmd *Command) error { - if c.ServerType != "admin" { - return fmt.Errorf("this command is only allowed in ADMIN mode") - } - - userName, ok := cmd.Params["user_name"].(string) - if !ok { - return fmt.Errorf("user_name not provided") - } - - resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/users/%s/admin", userName), true, "admin", nil, nil) - if err != nil { - return fmt.Errorf("failed to grant admin: %w", err) - } - - if resp.StatusCode != 200 { - return fmt.Errorf("failed to grant admin: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return fmt.Errorf("failed to grant admin: %s", msg) - } - - fmt.Printf("Admin role granted to user: %s\n", userName) - return nil -} - -// RevokeAdmin revokes admin privileges from a user (admin mode only) -func (c *RAGFlowClient) RevokeAdmin(cmd *Command) error { - if c.ServerType != "admin" { - return fmt.Errorf("this command is only allowed in ADMIN mode") - } - - userName, ok := cmd.Params["user_name"].(string) - if !ok { - return fmt.Errorf("user_name not provided") - } - - resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/admin/users/%s/admin", userName), true, "admin", nil, nil) - if err != nil { - return fmt.Errorf("failed to revoke admin: %w", err) - } - - if resp.StatusCode != 200 { - return fmt.Errorf("failed to revoke admin: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return fmt.Errorf("failed to revoke admin: %s", msg) - } - - fmt.Printf("Admin role revoked from user: %s\n", userName) - return nil } // ShowCurrentUser shows the current logged-in user information @@ -803,188 +387,168 @@ func (c *RAGFlowClient) ShowCurrentUser(cmd *Command) (map[string]interface{}, e return nil, fmt.Errorf("command 'SHOW CURRENT USER' is not yet implemented") } -// CreateUser creates a new user (admin mode only) -func (c *RAGFlowClient) CreateUser(cmd *Command) error { - if c.ServerType != "admin" { - return fmt.Errorf("this command is only allowed in ADMIN mode") - } - - userName, ok := cmd.Params["user_name"].(string) - if !ok { - return fmt.Errorf("user_name not provided") - } - - password, ok := cmd.Params["password"].(string) - if !ok { - return fmt.Errorf("password not provided") - } - - // Encrypt password using RSA - encryptedPassword, err := EncryptPassword(password) - if err != nil { - return fmt.Errorf("failed to encrypt password: %w", err) - } - - payload := map[string]interface{}{ - "username": userName, - "password": encryptedPassword, - "role": "user", - } - - resp, err := c.HTTPClient.Request("POST", "/admin/users", true, "admin", nil, payload) - if err != nil { - return fmt.Errorf("failed to create user: %w", err) - } - - if resp.StatusCode != 200 { - return fmt.Errorf("failed to create user: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return fmt.Errorf("failed to create user: %s", msg) - } - - fmt.Printf("User created successfully: %s\n", userName) - return nil +type ResponseIf interface { + Type() string + PrintOut() + TimeCost() float64 } -// ActivateUser activates or deactivates a user (admin mode only) -func (c *RAGFlowClient) ActivateUser(cmd *Command) error { - if c.ServerType != "admin" { - return fmt.Errorf("this command is only allowed in ADMIN mode") - } - - userName, ok := cmd.Params["user_name"].(string) - if !ok { - return fmt.Errorf("user_name not provided") - } - - activateStatus, ok := cmd.Params["activate_status"].(string) - if !ok { - return fmt.Errorf("activate_status not provided") - } - - // Validate activate_status - if activateStatus != "on" && activateStatus != "off" { - return fmt.Errorf("activate_status must be 'on' or 'off'") - } - - payload := map[string]interface{}{ - "activate_status": activateStatus, - } - - resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/users/%s/activate", userName), true, "admin", nil, payload) - if err != nil { - return fmt.Errorf("failed to update user activate status: %w", err) - } - - if resp.StatusCode != 200 { - return fmt.Errorf("failed to update user activate status: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return fmt.Errorf("failed to update user activate status: %s", msg) - } - - fmt.Printf("User '%s' activate status set to '%s'\n", userName, activateStatus) - return nil +type CommonResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 } -// AlterUserPassword changes a user's password (admin mode only) -func (c *RAGFlowClient) AlterUserPassword(cmd *Command) error { - if c.ServerType != "admin" { - return fmt.Errorf("this command is only allowed in ADMIN mode") - } - - userName, ok := cmd.Params["user_name"].(string) - if !ok { - return fmt.Errorf("user_name not provided") - } - - password, ok := cmd.Params["password"].(string) - if !ok { - return fmt.Errorf("password not provided") - } - - // Encrypt password using RSA - encryptedPassword, err := EncryptPassword(password) - if err != nil { - return fmt.Errorf("failed to encrypt password: %w", err) - } - - payload := map[string]interface{}{ - "new_password": encryptedPassword, - } - - resp, err := c.HTTPClient.Request("PUT", fmt.Sprintf("/admin/users/%s/password", userName), true, "admin", nil, payload) - if err != nil { - return fmt.Errorf("failed to change user password: %w", err) - } - - if resp.StatusCode != 200 { - return fmt.Errorf("failed to change user password: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return fmt.Errorf("failed to change user password: %s", msg) - } - - fmt.Printf("Password changed for user: %s\n", userName) - return nil +func (r *CommonResponse) Type() string { + return "common" } -// DropUser deletes a user (admin mode only) -func (c *RAGFlowClient) DropUser(cmd *Command) error { - if c.ServerType != "admin" { - return fmt.Errorf("this command is only allowed in ADMIN mode") - } - - userName, ok := cmd.Params["user_name"].(string) - if !ok { - return fmt.Errorf("user_name not provided") - } - - resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/admin/users/%s", userName), true, "admin", nil, nil) - if err != nil { - return fmt.Errorf("failed to delete user: %w", err) - } - - if resp.StatusCode != 200 { - return fmt.Errorf("failed to delete user: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) - } - - resJSON, err := resp.JSON() - if err != nil { - return fmt.Errorf("invalid JSON response: %w", err) - } - - code, ok := resJSON["code"].(float64) - if !ok || code != 0 { - msg, _ := resJSON["message"].(string) - return fmt.Errorf("failed to delete user: %s", msg) - } - - fmt.Printf("User deleted: %s\n", userName) - return nil +func (r *CommonResponse) TimeCost() float64 { + return r.Duration +} + +func (r *CommonResponse) PrintOut() { + if r.Code == 0 { + PrintTableSimple(r.Data) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type CommonDataResponse struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 +} + +func (r *CommonDataResponse) Type() string { + return "show" +} + +func (r *CommonDataResponse) TimeCost() float64 { + return r.Duration +} + +func (r *CommonDataResponse) PrintOut() { + if r.Code == 0 { + table := make([]map[string]interface{}, 0) + table = append(table, r.Data) + PrintTableSimple(table) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type SimpleResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Duration float64 +} + +func (r *SimpleResponse) Type() string { + return "simple" +} + +func (r *SimpleResponse) TimeCost() float64 { + return r.Duration +} + +func (r *SimpleResponse) PrintOut() { + if r.Code == 0 { + fmt.Println("SUCCESS") + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type RegisterResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Duration float64 +} + +func (r *RegisterResponse) Type() string { + return "register" +} + +func (r *RegisterResponse) TimeCost() float64 { + return r.Duration +} + +func (r *RegisterResponse) PrintOut() { + if r.Code == 0 { + fmt.Println("Register successfully") + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type BenchmarkResponse struct { + Code int `json:"code"` + Duration float64 `json:"duration"` + SuccessCount int `json:"success_count"` + FailureCount int `json:"failure_count"` + Concurrency int +} + +func (r *BenchmarkResponse) Type() string { + return "benchmark" +} + +func (r *BenchmarkResponse) PrintOut() { + if r.Code != 0 { + fmt.Printf("ERROR, Code: %d\n", r.Code) + return + } + + iterations := r.SuccessCount + r.FailureCount + if r.Concurrency == 1 { + if iterations == 1 { + fmt.Printf("Latency: %fs\n", r.Duration) + } else { + fmt.Printf("Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount) + } + } else { + fmt.Printf("Concurrency: %d, Latency: %fs, QPS: %.1f, SUCCESS: %d, FAILURE: %d\n", r.Concurrency, r.Duration, float64(iterations)/r.Duration, r.SuccessCount, r.FailureCount) + } +} + +func (r *BenchmarkResponse) TimeCost() float64 { + return r.Duration +} + +type KeyValueResponse struct { + Code int `json:"code"` + Key string `json:"key"` + Value string `json:"data"` + Duration float64 +} + +func (r *KeyValueResponse) Type() string { + return "data" +} + +func (r *KeyValueResponse) TimeCost() float64 { + return r.Duration +} + +func (r *KeyValueResponse) PrintOut() { + if r.Code == 0 { + table := make([]map[string]interface{}, 0) + // insert r.key and r.value into table + table = append(table, map[string]interface{}{ + "key": r.Key, + "value": r.Value, + }) + PrintTableSimple(table) + } else { + fmt.Println("ERROR") + fmt.Printf("%d\n", r.Code) + } } diff --git a/internal/cli/http_client.go b/internal/cli/http_client.go index eb08b4ff6..0b8b78afd 100644 --- a/internal/cli/http_client.go +++ b/internal/cli/http_client.go @@ -31,12 +31,13 @@ type HTTPClient struct { Host string Port int APIVersion string - APIKey string + APIToken string LoginToken string ConnectTimeout time.Duration ReadTimeout time.Duration VerifySSL bool client *http.Client + useAPIToken bool } // NewHTTPClient creates a new HTTP client @@ -85,8 +86,8 @@ func (c *HTTPClient) Headers(authKind string, extra map[string]string) map[strin headers := make(map[string]string) switch authKind { case "api": - if c.APIKey != "" { - headers["Authorization"] = fmt.Sprintf("Bearer %s", c.APIKey) + if c.APIToken != "" { + headers["Authorization"] = fmt.Sprintf("Bearer %s", c.APIToken) } case "web", "admin": if c.LoginToken != "" { @@ -104,6 +105,7 @@ type Response struct { StatusCode int Body []byte Headers http.Header + Duration float64 } // JSON parses the response body as JSON @@ -142,11 +144,14 @@ func (c *HTTPClient) Request(method, path string, useAPIBase bool, authKind stri req.Header.Set(k, v) } - resp, err := c.client.Do(req) + var resp *http.Response + startTime := time.Now() + resp, err = c.client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + duration := time.Since(startTime).Seconds() respBody, err := io.ReadAll(resp.Body) if err != nil { @@ -157,21 +162,93 @@ func (c *HTTPClient) Request(method, path string, useAPIBase bool, authKind stri StatusCode: resp.StatusCode, Body: respBody, Headers: resp.Header.Clone(), + Duration: duration, + }, nil +} + +// Request makes an HTTP request +func (c *HTTPClient) RequestWith2URL(method, webPath string, apiPath string, headers map[string]string, jsonBody map[string]interface{}) (*Response, error) { + var path string + var useAPIBase bool + var authKind string + if c.useAPIToken { + path = apiPath + useAPIBase = true + authKind = "api" + } else { + path = webPath + useAPIBase = false + authKind = "web" + } + + url := c.BuildURL(path, useAPIBase) + mergedHeaders := c.Headers(authKind, headers) + + var body io.Reader + if jsonBody != nil { + jsonData, err := json.Marshal(jsonBody) + if err != nil { + return nil, err + } + body = bytes.NewReader(jsonData) + if mergedHeaders == nil { + mergedHeaders = make(map[string]string) + } + mergedHeaders["Content-Type"] = "application/json" + } + + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + for k, v := range mergedHeaders { + req.Header.Set(k, v) + } + + var resp *http.Response + startTime := time.Now() + resp, err = c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + duration := time.Since(startTime).Seconds() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return &Response{ + StatusCode: resp.StatusCode, + Body: respBody, + Headers: resp.Header.Clone(), + Duration: duration, }, nil } // RequestWithIterations makes multiple HTTP requests for benchmarking // Returns a map with "duration" (total time in seconds) and "response_list" -func (c *HTTPClient) RequestWithIterations(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}, iterations int) (map[string]interface{}, error) { +func (c *HTTPClient) RequestWithIterations(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}, iterations int) (*BenchmarkResponse, error) { + response := new(BenchmarkResponse) + if iterations <= 1 { + start := time.Now() resp, err := c.Request(method, path, useAPIBase, authKind, headers, jsonBody) + totalDuration := time.Since(start).Seconds() if err != nil { return nil, err } - return map[string]interface{}{ - "duration": 0.0, - "response_list": []*Response{resp}, - }, nil + + response.Code = resp.StatusCode + response.Duration = totalDuration + if response.Code == 0 { + response.SuccessCount = 1 + } else { + response.FailureCount = 1 + } + return response, nil } url := c.BuildURL(path, useAPIBase) @@ -232,10 +309,17 @@ func (c *HTTPClient) RequestWithIterations(method, path string, useAPIBase bool, totalDuration += time.Since(start).Seconds() } - return map[string]interface{}{ - "duration": totalDuration, - "response_list": responseList, - }, nil + response.Code = 0 + response.Duration = totalDuration + for _, resp := range responseList { + if resp.StatusCode == 200 { + response.SuccessCount++ + } else { + response.FailureCount++ + } + } + + return response, nil } // RequestJSON makes an HTTP request and returns JSON response diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index 214285b65..1b4ad20dc 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -215,6 +215,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenOn, Value: ident} case "SET": return Token{Type: TokenSet, Value: ident} + case "UNSET": + return Token{Type: TokenUnset, Value: ident} case "RESET": return Token{Type: TokenReset, Value: ident} case "VERSION": @@ -287,6 +289,10 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenBenchmark, Value: ident} case "PING": return Token{Type: TokenPing, Value: ident} + case "TOKEN": + return Token{Type: TokenToken, Value: ident} + case "TOKENS": + return Token{Type: TokenTokens, Value: ident} default: return Token{Type: TokenIdentifier, Value: ident} } diff --git a/internal/cli/parser.go b/internal/cli/parser.go index 7c839d1bb..397cf2d69 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -102,6 +102,8 @@ func (p *Parser) parseSQLCommand() (*Command, error) { return p.parseRevokeCommand() case TokenSet: return p.parseSetCommand() + case TokenUnset: + return p.parseUnsetCommand() case TokenReset: return p.parseResetCommand() case TokenGenerate: @@ -189,18 +191,20 @@ func (p *Parser) parseLoginUser() (*Command, error) { cmd.Params["email"] = email p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } func (p *Parser) parsePingServer() (*Command, error) { - cmd := NewCommand("ping_server") + cmd := NewCommand("ping") p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -208,7 +212,6 @@ func (p *Parser) parsePingServer() (*Command, error) { func (p *Parser) parseRegisterCommand() (*Command, error) { cmd := NewCommand("register_user") - p.nextToken() // consume REGISTER if err := p.expectPeek(TokenUser); err != nil { return nil, err } @@ -245,8 +248,9 @@ func (p *Parser) parseRegisterCommand() (*Command, error) { cmd.Params["password"] = password p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil @@ -258,38 +262,51 @@ func (p *Parser) parseListCommand() (*Command, error) { switch p.curToken.Type { case TokenServices: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_services"), nil case TokenUsers: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_users"), nil case TokenRoles: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_roles"), nil case TokenVars: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_variables"), nil case TokenConfigs: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_configs"), nil + case TokenTokens: + p.nextToken() + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return NewCommand("list_tokens"), nil case TokenEnvs: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_environments"), nil case TokenDatasets: @@ -304,8 +321,9 @@ func (p *Parser) parseListCommand() (*Command, error) { return p.parseListDefaultModels() case TokenChats: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_user_chats"), nil case TokenFiles: @@ -334,8 +352,9 @@ func (p *Parser) parseListDatasets() (*Command, error) { p.nextToken() } - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -361,8 +380,9 @@ func (p *Parser) parseListAgents() (*Command, error) { cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -379,12 +399,13 @@ func (p *Parser) parseListKeys() (*Command, error) { return nil, err } - cmd := NewCommand("list_keys") + cmd := NewCommand("list_tokens") cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -395,8 +416,9 @@ func (p *Parser) parseListModelProviders() (*Command, error) { return nil, fmt.Errorf("expected PROVIDERS") } p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_user_model_providers"), nil } @@ -407,8 +429,9 @@ func (p *Parser) parseListDefaultModels() (*Command, error) { return nil, fmt.Errorf("expected MODELS") } p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("list_user_default_models"), nil } @@ -433,8 +456,9 @@ func (p *Parser) parseListFiles() (*Command, error) { cmd.Params["dataset_name"] = datasetName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -445,18 +469,27 @@ func (p *Parser) parseShowCommand() (*Command, error) { switch p.curToken.Type { case TokenVersion: p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("show_version"), nil + case TokenToken: + p.nextToken() + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return NewCommand("show_token"), nil case TokenCurrent: p.nextToken() if p.curToken.Type != TokenUser { return nil, fmt.Errorf("expected USER after CURRENT") } p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return NewCommand("show_current_user"), nil case TokenUser: @@ -485,8 +518,9 @@ func (p *Parser) parseShowUser() (*Command, error) { cmd := NewCommand("show_user_permission") cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -500,8 +534,9 @@ func (p *Parser) parseShowUser() (*Command, error) { cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -517,8 +552,9 @@ func (p *Parser) parseShowRole() (*Command, error) { cmd.Params["role_name"] = roleName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -534,8 +570,9 @@ func (p *Parser) parseShowVariable() (*Command, error) { cmd.Params["var_name"] = varName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -551,8 +588,9 @@ func (p *Parser) parseShowService() (*Command, error) { cmd.Params["number"] = serviceNum p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -571,11 +609,24 @@ func (p *Parser) parseCreateCommand() (*Command, error) { return p.parseCreateDataset() case TokenChat: return p.parseCreateChat() + case TokenToken: + return p.parseCreateToken() default: return nil, fmt.Errorf("unknown CREATE target: %s", p.curToken.Value) } } +func (p *Parser) parseCreateToken() (*Command, error) { + p.nextToken() // consume TOKEN + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return NewCommand("create_token"), nil +} + func (p *Parser) parseCreateUser() (*Command, error) { p.nextToken() // consume USER userName, err := p.parseQuotedString() @@ -595,8 +646,9 @@ func (p *Parser) parseCreateUser() (*Command, error) { cmd.Params["role"] = "user" p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -622,8 +674,9 @@ func (p *Parser) parseCreateRole() (*Command, error) { p.nextToken() } - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -651,8 +704,9 @@ func (p *Parser) parseCreateModelProvider() (*Command, error) { cmd.Params["provider_key"] = providerKey p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -704,8 +758,9 @@ func (p *Parser) parseCreateDataset() (*Command, error) { return nil, fmt.Errorf("expected PARSER or PIPELINE") } - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -721,8 +776,9 @@ func (p *Parser) parseCreateChat() (*Command, error) { cmd.Params["chat_name"] = chatName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -743,11 +799,32 @@ func (p *Parser) parseDropCommand() (*Command, error) { return p.parseDropChat() case TokenKey: return p.parseDropKey() + case TokenToken: + return p.parseDropToken() default: return nil, fmt.Errorf("unknown DROP target: %s", p.curToken.Value) } } +func (p *Parser) parseDropToken() (*Command, error) { + p.nextToken() // consume TOKEN + + tokenValue, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("drop_token") + cmd.Params["token"] = tokenValue + + p.nextToken() + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + func (p *Parser) parseDropUser() (*Command, error) { p.nextToken() // consume USER userName, err := p.parseQuotedString() @@ -759,8 +836,9 @@ func (p *Parser) parseDropUser() (*Command, error) { cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -776,8 +854,9 @@ func (p *Parser) parseDropRole() (*Command, error) { cmd.Params["role_name"] = roleName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -798,8 +877,9 @@ func (p *Parser) parseDropModelProvider() (*Command, error) { cmd.Params["provider_name"] = providerName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -815,8 +895,9 @@ func (p *Parser) parseDropDataset() (*Command, error) { cmd.Params["dataset_name"] = datasetName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -832,15 +913,16 @@ func (p *Parser) parseDropChat() (*Command, error) { cmd.Params["chat_name"] = chatName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } func (p *Parser) parseDropKey() (*Command, error) { p.nextToken() // consume KEY - key, err := p.parseQuotedString() + token, err := p.parseQuotedString() if err != nil { return nil, err } @@ -856,13 +938,14 @@ func (p *Parser) parseDropKey() (*Command, error) { return nil, err } - cmd := NewCommand("drop_key") - cmd.Params["key"] = key + cmd := NewCommand("drop_token") + cmd.Params["token"] = token cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -905,8 +988,9 @@ func (p *Parser) parseAlterUser() (*Command, error) { cmd.Params["password"] = password p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for SHOW TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -936,8 +1020,9 @@ func (p *Parser) parseAlterUser() (*Command, error) { cmd.Params["role_name"] = roleName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -954,16 +1039,17 @@ func (p *Parser) parseActivateUser() (*Command, error) { status := p.curToken.Value if status != "on" && status != "off" { return nil, fmt.Errorf("expected 'on' or 'off', got %s", p.curToken.Value) - } + } cmd := NewCommand("activate_user") cmd.Params["user_name"] = userName cmd.Params["activate_status"] = status p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err - } + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } return cmd, nil } @@ -994,8 +1080,9 @@ func (p *Parser) parseAlterRole() (*Command, error) { cmd.Params["description"] = description p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1021,8 +1108,9 @@ func (p *Parser) parseGrantAdmin() (*Command, error) { cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1064,8 +1152,9 @@ func (p *Parser) parseGrantPermission() (*Command, error) { cmd.Params["role_name"] = roleName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1091,8 +1180,9 @@ func (p *Parser) parseRevokeAdmin() (*Command, error) { cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1134,8 +1224,9 @@ func (p *Parser) parseRevokePermission() (*Command, error) { cmd.Params["role_name"] = roleName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1172,6 +1263,9 @@ func (p *Parser) parseSetCommand() (*Command, error) { if p.curToken.Type == TokenDefault { return p.parseSetDefault() } + if p.curToken.Type == TokenToken { + return p.parseSetToken() + } return nil, fmt.Errorf("unknown SET target: %s", p.curToken.Value) } @@ -1194,8 +1288,9 @@ func (p *Parser) parseSetVariable() (*Command, error) { cmd.Params["var_value"] = varValue p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1234,9 +1329,29 @@ func (p *Parser) parseSetDefault() (*Command, error) { cmd.Params["model_id"] = modelID p.nextToken() - if err := p.expectSemicolon(); err != nil { + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + +func (p *Parser) parseSetToken() (*Command, error) { + p.nextToken() // consume TOKEN + + tokenValue, err := p.parseQuotedString() + if err != nil { return nil, err } + + cmd := NewCommand("set_token") + cmd.Params["token"] = tokenValue + + p.nextToken() + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } return cmd, nil } @@ -1270,8 +1385,9 @@ func (p *Parser) parseResetCommand() (*Command, error) { cmd.Params["model_type"] = modelType p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1296,12 +1412,13 @@ func (p *Parser) parseGenerateCommand() (*Command, error) { return nil, err } - cmd := NewCommand("generate_key") + cmd := NewCommand("generate_token") cmd.Params["user_name"] = userName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1333,8 +1450,9 @@ func (p *Parser) parseImportCommand() (*Command, error) { cmd.Params["dataset_name"] = datasetName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1366,8 +1484,9 @@ func (p *Parser) parseSearchCommand() (*Command, error) { cmd.Params["datasets"] = datasets p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1404,8 +1523,9 @@ func (p *Parser) parseParseDataset() (*Command, error) { cmd.Params["method"] = method p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1436,8 +1556,9 @@ func (p *Parser) parseParseDocs() (*Command, error) { cmd.Params["dataset_name"] = datasetName p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1482,6 +1603,8 @@ func (p *Parser) parseUserStatement() (*Command, error) { return p.parseDropCommand() case TokenSet: return p.parseSetCommand() + case TokenUnset: + return p.parseUnsetCommand() case TokenReset: return p.parseResetCommand() case TokenList: @@ -1513,8 +1636,9 @@ func (p *Parser) parseStartupCommand() (*Command, error) { cmd.Params["number"] = serviceNum p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1535,8 +1659,9 @@ func (p *Parser) parseShutdownCommand() (*Command, error) { cmd.Params["number"] = serviceNum p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } @@ -1557,12 +1682,28 @@ func (p *Parser) parseRestartCommand() (*Command, error) { cmd.Params["number"] = serviceNum p.nextToken() - if err := p.expectSemicolon(); err != nil { - return nil, err + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() } return cmd, nil } +func (p *Parser) parseUnsetCommand() (*Command, error) { + p.nextToken() // consume UNSET + + if p.curToken.Type != TokenToken { + return nil, fmt.Errorf("expected TOKEN after UNSET") + } + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return NewCommand("unset_token"), nil +} + func tokenTypeToString(t int) string { // Simplified for error messages return fmt.Sprintf("token(%d)", t) diff --git a/internal/cli/types.go b/internal/cli/types.go index b9d11b8b3..06c646eca 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -95,6 +95,9 @@ const ( TokenSync TokenBenchmark TokenPing + TokenToken + TokenTokens + TokenUnset // Literals TokenIdentifier diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go new file mode 100644 index 000000000..32f39c06e --- /dev/null +++ b/internal/cli/user_command.go @@ -0,0 +1,548 @@ +package cli + +import ( + "encoding/json" + "fmt" + "strings" +) + +// PingServer pings the server to check if it's alive +// Returns benchmark result map if iterations > 1, otherwise prints status +func (c *RAGFlowClient) PingServer(cmd *Command) (ResponseIf, error) { + // Get iterations from command params (for benchmark) + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode: multiple iterations + return c.HTTPClient.RequestWithIterations("GET", "/system/ping", false, "web", nil, nil, iterations) + } + + // Single mode + resp, err := c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Server is down") + return nil, err + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to ping: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + result.Message = string(resp.Body) + result.Code = 0 + return &result, nil +} + +// Show server version to show RAGFlow server version +// Returns benchmark result map if iterations > 1, otherwise prints status +func (c *RAGFlowClient) ShowServerVersion(cmd *Command) (ResponseIf, error) { + // Get iterations from command params (for benchmark) + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode: multiple iterations + return c.HTTPClient.RequestWithIterations("GET", "/system/version", false, "web", nil, nil, iterations) + } + + // Single mode + resp, err := c.HTTPClient.Request("GET", "/system/version", false, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show version: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show version: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result KeyValueResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show version failed: invalid JSON (%w)", err) + } + result.Key = "version" + result.Duration = resp.Duration + + return &result, nil +} + +func (c *RAGFlowClient) RegisterUser(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + // Check for benchmark iterations + var ok bool + _, ok = cmd.Params["iterations"].(int) + if ok { + return nil, fmt.Errorf("failed to register user in benchmark statement") + } + + var email string + email, ok = cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("no email") + } + + var password string + password, ok = cmd.Params["password"].(string) + if !ok { + return nil, fmt.Errorf("no password") + } + + var nickname string + nickname, ok = cmd.Params["nickname"].(string) + if !ok { + return nil, fmt.Errorf("no nickname") + } + + payload := map[string]interface{}{ + "email": email, + "password": password, + "nickname": nickname, + } + + resp, err := c.HTTPClient.Request("POST", "/user/register", false, "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to register user: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to register user: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result RegisterResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("register user failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + +// ListUserDatasets lists datasets for current user (user mode) +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) ListUserDatasets(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", "/datasets", true, "web", nil, nil, iterations) + } + + // Normal mode + resp, err := c.HTTPClient.Request("GET", "/datasets", true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list datasets: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list users failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + +// getDatasetID gets dataset ID by name +func (c *RAGFlowClient) getDatasetID(datasetName string) (string, error) { + resp, err := c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil) + if err != nil { + return "", fmt.Errorf("failed to list datasets: %w", err) + } + + if resp.StatusCode != 200 { + return "", fmt.Errorf("failed to list datasets: HTTP %d", resp.StatusCode) + } + + resJSON, err := resp.JSON() + if err != nil { + return "", fmt.Errorf("invalid JSON response: %w", err) + } + + code, ok := resJSON["code"].(float64) + if !ok || code != 0 { + msg, _ := resJSON["message"].(string) + return "", fmt.Errorf("failed to list datasets: %s", msg) + } + + data, ok := resJSON["data"].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid response format") + } + + kbs, ok := data["kbs"].([]interface{}) + if !ok { + return "", fmt.Errorf("invalid response format: kbs not found") + } + + for _, kb := range kbs { + if kbMap, ok := kb.(map[string]interface{}); ok { + if name, _ := kbMap["name"].(string); name == datasetName { + if id, _ := kbMap["id"].(string); id != "" { + return id, nil + } + } + } + } + + return "", fmt.Errorf("dataset '%s' not found", datasetName) +} + +// formatEmptyArray converts empty arrays to "[]" string +func formatEmptyArray(v interface{}) string { + if v == nil { + return "[]" + } + switch val := v.(type) { + case []interface{}: + if len(val) == 0 { + return "[]" + } + case []string: + if len(val) == 0 { + return "[]" + } + case []int: + if len(val) == 0 { + return "[]" + } + } + return fmt.Sprintf("%v", v) +} + +// SearchOnDatasets searches for chunks in specified datasets +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) SearchOnDatasets(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + question, ok := cmd.Params["question"].(string) + if !ok { + return nil, fmt.Errorf("question not provided") + } + + datasets, ok := cmd.Params["datasets"].(string) + if !ok { + return nil, fmt.Errorf("datasets not provided") + } + + // Parse dataset names (comma-separated) and convert to IDs + datasetNames := strings.Split(datasets, ",") + datasetIDs := make([]string, 0, len(datasetNames)) + for _, name := range datasetNames { + name = strings.TrimSpace(name) + id, err := c.getDatasetID(name) + if err != nil { + return nil, err + } + datasetIDs = append(datasetIDs, id) + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + payload := map[string]interface{}{ + "kb_id": datasetIDs, + "question": question, + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.3, + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("POST", "/chunk/retrieval_test", false, "web", nil, payload, iterations) + } + + // Normal mode + resp, err := c.HTTPClient.Request("POST", "/chunk/retrieval_test", false, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to search on datasets: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to search on datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + resJSON, err := resp.JSON() + if err != nil { + return nil, fmt.Errorf("invalid JSON response: %w", err) + } + + code, ok := resJSON["code"].(float64) + if !ok || code != 0 { + msg, _ := resJSON["message"].(string) + return nil, fmt.Errorf("failed to search on datasets: %s", msg) + } + + data, ok := resJSON["data"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format") + } + + chunks, ok := data["chunks"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format: chunks not found") + } + + // Convert to slice of maps for printing + tableData := make([]map[string]interface{}, 0, len(chunks)) + for _, chunk := range chunks { + if chunkMap, ok := chunk.(map[string]interface{}); ok { + row := map[string]interface{}{ + "id": chunkMap["chunk_id"], + "content": chunkMap["content_with_weight"], + "document_id": chunkMap["doc_id"], + "dataset_id": chunkMap["kb_id"], + "docnm_kwd": chunkMap["docnm_kwd"], + "image_id": chunkMap["image_id"], + "similarity": chunkMap["similarity"], + "term_similarity": chunkMap["term_similarity"], + "vector_similarity": chunkMap["vector_similarity"], + } + // Add optional fields that may be empty arrays + if v, ok := chunkMap["doc_type_kwd"]; ok { + row["doc_type_kwd"] = formatEmptyArray(v) + } + if v, ok := chunkMap["important_kwd"]; ok { + row["important_kwd"] = formatEmptyArray(v) + } + if v, ok := chunkMap["mom_id"]; ok { + row["mom_id"] = formatEmptyArray(v) + } + if v, ok := chunkMap["positions"]; ok { + row["positions"] = formatEmptyArray(v) + } + if v, ok := chunkMap["content_ltks"]; ok { + row["content_ltks"] = v + } + tableData = append(tableData, row) + } + } + + PrintTableSimple(tableData) + return nil, nil +} + +// CreateToken creates a new API token +func (c *RAGFlowClient) CreateToken(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + resp, err := c.HTTPClient.Request("POST", "/tokens", true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to create token: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to create token: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var createResult CommonDataResponse + if err = json.Unmarshal(resp.Body, &createResult); err != nil { + return nil, fmt.Errorf("create token failed: invalid JSON (%w)", err) + } + + if createResult.Code != 0 { + return nil, fmt.Errorf("%s", createResult.Message) + } + + var result SimpleResponse + result.Code = 0 + result.Message = "Token created successfully" + result.Duration = resp.Duration + return &result, nil +} + +// ListTokens lists all API tokens for the current user +func (c *RAGFlowClient) ListTokens(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + resp, err := c.HTTPClient.Request("GET", "/tokens", true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list tokens: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list tokens: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list tokens failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// DropToken deletes an API token +func (c *RAGFlowClient) DropToken(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + token, ok := cmd.Params["token"].(string) + if !ok { + return nil, fmt.Errorf("token not provided") + } + + resp, err := c.HTTPClient.Request("DELETE", fmt.Sprintf("/tokens/%s", token), true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to drop token: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to drop token: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("drop token failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +// SetToken sets the API token after validating it +func (c *RAGFlowClient) SetToken(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + token, ok := cmd.Params["token"].(string) + if !ok { + return nil, fmt.Errorf("token not provided") + } + + // Save current token to restore if validation fails + savedToken := c.HTTPClient.APIToken + savedUseAPIToken := c.HTTPClient.useAPIToken + + // Set the new token temporarily for validation + c.HTTPClient.APIToken = token + c.HTTPClient.useAPIToken = true + + // Validate token by calling list tokens API + resp, err := c.HTTPClient.Request("GET", "/tokens", true, "api", nil, nil) + if err != nil { + // Restore original token on error + c.HTTPClient.APIToken = savedToken + c.HTTPClient.useAPIToken = savedUseAPIToken + return nil, fmt.Errorf("failed to validate token: %w", err) + } + + if resp.StatusCode != 200 { + // Restore original token on error + c.HTTPClient.APIToken = savedToken + c.HTTPClient.useAPIToken = savedUseAPIToken + return nil, fmt.Errorf("token validation failed: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + // Restore original token on error + c.HTTPClient.APIToken = savedToken + c.HTTPClient.useAPIToken = savedUseAPIToken + return nil, fmt.Errorf("token validation failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + // Restore original token on error + c.HTTPClient.APIToken = savedToken + c.HTTPClient.useAPIToken = savedUseAPIToken + return nil, fmt.Errorf("token validation failed: %s", result.Message) + } + + // Token is valid, keep it set + var successResult SimpleResponse + successResult.Code = 0 + successResult.Message = "API token set successfully" + successResult.Duration = resp.Duration + return &successResult, nil +} + +// ShowToken displays the current API token +func (c *RAGFlowClient) ShowToken(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + if c.HTTPClient.APIToken == "" { + return nil, fmt.Errorf("no API token is currently set") + } + + //fmt.Printf("Token: %s\n", c.HTTPClient.APIToken) + + var result CommonResponse + result.Code = 0 + result.Message = "" + result.Data = []map[string]interface{}{ + { + "token": c.HTTPClient.APIToken, + }, + } + result.Duration = 0 + return &result, nil +} + +// UnsetToken removes the current API token +func (c *RAGFlowClient) UnsetToken(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + if c.HTTPClient.APIToken == "" { + return nil, fmt.Errorf("no API token is currently set") + } + + c.HTTPClient.APIToken = "" + c.HTTPClient.useAPIToken = false + + var result SimpleResponse + result.Code = 0 + result.Message = "API token unset successfully" + result.Duration = 0 + return &result, nil +} diff --git a/internal/dao/api_token.go b/internal/dao/api_token.go index 6db2ce76c..152ee327b 100644 --- a/internal/dao/api_token.go +++ b/internal/dao/api_token.go @@ -46,6 +46,16 @@ func (dao *APITokenDAO) DeleteByTenantID(tenantID string) (int64, error) { return result.RowsAffected, result.Error } +// GetByToken gets API token by access key +func (dao *APITokenDAO) GetUserByAPIToken(token string) (*model.APIToken, error) { + var apiToken model.APIToken + err := DB.Where("token = ?", token).First(&apiToken).Error + if err != nil { + return nil, err + } + return &apiToken, nil +} + // DeleteByDialogIDs deletes API tokens by dialog IDs (hard delete) func (dao *APITokenDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error) { if len(dialogIDs) == 0 { @@ -55,6 +65,12 @@ func (dao *APITokenDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error) { return result.RowsAffected, result.Error } +// DeleteByTenantIDAndToken deletes a specific API token by tenant ID and token value +func (dao *APITokenDAO) DeleteByTenantIDAndToken(tenantID, token string) (int64, error) { + result := DB.Unscoped().Where("tenant_id = ? AND token = ?", tenantID, token).Delete(&model.APIToken{}) + return result.RowsAffected, result.Error +} + // API4ConversationDAO API for conversation data access object type API4ConversationDAO struct{} diff --git a/internal/dao/user.go b/internal/dao/user.go index 325b9c647..d36b4f9be 100644 --- a/internal/dao/user.go +++ b/internal/dao/user.go @@ -43,6 +43,15 @@ func (dao *UserDAO) GetByID(id uint) (*model.User, error) { return &user, nil } +func (dao *UserDAO) GetByTenantID(tenantID string) (*model.User, error) { + var user model.User + err := DB.Where("id = ?", tenantID).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + // GetByUsername get user by username func (dao *UserDAO) GetByUsername(username string) (*model.User, error) { var user model.User diff --git a/internal/handler/api_token.go b/internal/handler/api_token.go new file mode 100644 index 000000000..c974a5985 --- /dev/null +++ b/internal/handler/api_token.go @@ -0,0 +1,224 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "ragflow/internal/dao" + "ragflow/internal/model" + "ragflow/internal/service" + + "github.com/gin-gonic/gin" +) + +// ListTokens list all API tokens for the current user's tenant +// @Summary List API Tokens +// @Description List all API tokens for the current user's tenant +// @Tags system +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Success 200 {object} map[string]interface{} +// @Router /v1/system/token_list [get] +func (h *SystemHandler) ListTokens(c *gin.Context) { + // Get current user from context + user, exists := c.Get("user") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Unauthorized", + }) + return + } + + userModel, ok := user.(*model.User) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Invalid user data", + }) + return + } + + // Get user's tenant with owner role + userTenantDAO := dao.NewUserTenantDAO() + tenants, err := userTenantDAO.GetByUserIDAndRole(userModel.ID, "owner") + if err != nil || len(tenants) == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Tenant not found", + }) + return + } + + tenantID := tenants[0].TenantID + + // Get tokens for the tenant + tokens, err := h.systemService.ListAPITokens(tenantID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Failed to list tokens", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": tokens, + }) +} + +// CreateToken creates a new API token for the current user's tenant +// @Summary Create API Token +// @Description Generate a new API token for the current user's tenant +// @Tags system +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param name query string false "Name of the token" +// @Success 200 {object} map[string]interface{} +// @Router /v1/system/new_token [post] +func (h *SystemHandler) CreateToken(c *gin.Context) { + // Get current user from context + user, exists := c.Get("user") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Unauthorized", + }) + return + } + + userModel, ok := user.(*model.User) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Invalid user data", + }) + return + } + + // Get user's tenant with owner role + userTenantDAO := dao.NewUserTenantDAO() + tenants, err := userTenantDAO.GetByUserIDAndRole(userModel.ID, "owner") + if err != nil || len(tenants) == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Tenant not found", + }) + return + } + + tenantID := tenants[0].TenantID + + // Parse request + var req service.CreateAPITokenRequest + if err := c.ShouldBind(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Invalid request", + }) + return + } + + // Create token + token, err := h.systemService.CreateAPIToken(tenantID, &req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Failed to create token", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": token, + }) +} + +// DeleteToken deletes an API token +// @Summary Delete API Token +// @Description Remove an API token for the current user's tenant +// @Tags system +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param token path string true "The API token to remove" +// @Success 200 {object} map[string]interface{} +// @Router /v1/system/token/{token} [delete] +func (h *SystemHandler) DeleteToken(c *gin.Context) { + // Get current user from context + user, exists := c.Get("user") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Unauthorized", + }) + return + } + + userModel, ok := user.(*model.User) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Invalid user data", + }) + return + } + + // Get user's tenant with owner role + userTenantDAO := dao.NewUserTenantDAO() + tenants, err := userTenantDAO.GetByUserIDAndRole(userModel.ID, "owner") + if err != nil || len(tenants) == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Tenant not found", + }) + return + } + + tenantID := tenants[0].TenantID + + // Get token from path parameter + token := c.Param("token") + if token == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Token is required", + }) + return + } + + // Delete token + if err := h.systemService.DeleteAPIToken(tenantID, token); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Failed to delete token", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": true, + }) +} diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 3c336b2ec..a983e9b40 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -56,12 +56,15 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc { // Get user by access token user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": "Invalid access token", - }) - c.Abort() - return + user, code, err = h.userService.GetUserByAPIToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": code, + "message": "Invalid access token", + }) + c.Abort() + return + } } if *user.IsSuperuser { diff --git a/internal/handler/system.go b/internal/handler/system.go index da7fe52f6..781634f21 100644 --- a/internal/handler/system.go +++ b/internal/handler/system.go @@ -19,10 +19,9 @@ package handler import ( "net/http" "ragflow/internal/server" + "ragflow/internal/service" "github.com/gin-gonic/gin" - - "ragflow/internal/service" ) // SystemHandler system handler diff --git a/internal/model/api.go b/internal/model/api_token.go similarity index 100% rename from internal/model/api.go rename to internal/model/api_token.go diff --git a/internal/model/user.go b/internal/model/user.go index 05f563351..41845fc73 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -36,6 +36,7 @@ type User struct { LoginChannel *string `gorm:"column:login_channel;index" json:"login_channel,omitempty"` Status *string `gorm:"column:status;size:1;default:1;index" json:"status"` IsSuperuser *bool `gorm:"column:is_superuser;index" json:"is_superuser,omitempty"` + RoleID int64 `gorm:"column:role_id;index;default:1;not null;" json:"role_id,omitempty"` BaseModel } diff --git a/internal/router/router.go b/internal/router/router.go index 1810cb639..8e6226fb7 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -116,6 +116,11 @@ func (r *Router) Setup(engine *gin.Engine) { // User set tenant info endpoint authorized.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo) + // System token endpoints (requires authentication) + authorized.GET("/v1/system/token_list", r.systemHandler.ListTokens) + authorized.POST("/v1/system/new_token", r.systemHandler.CreateToken) + authorized.DELETE("/v1/system/token/:token", r.systemHandler.DeleteToken) + // API v1 route group v1 := authorized.Group("/api/v1") { @@ -128,6 +133,13 @@ func (r *Router) Setup(engine *gin.Engine) { // users.GET("/:id", r.userHandler.GetUserByID) //} + apiTokens := v1.Group("/tokens") + { + apiTokens.POST("", r.systemHandler.CreateToken) + apiTokens.GET("", r.systemHandler.ListTokens) + apiTokens.DELETE("/:token", r.systemHandler.DeleteToken) + } + // Document routes documents := v1.Group("/documents") { diff --git a/internal/service/api_token.go b/internal/service/api_token.go new file mode 100644 index 000000000..931b9d55c --- /dev/null +++ b/internal/service/api_token.go @@ -0,0 +1,107 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "ragflow/internal/dao" + "ragflow/internal/model" + "ragflow/internal/utility" + "time" +) + +// TokenResponse token response +type TokenResponse struct { + TenantID string `json:"tenant_id"` + Token string `json:"token"` + DialogID *string `json:"dialog_id,omitempty"` + Source *string `json:"source,omitempty"` + Beta *string `json:"beta,omitempty"` + CreateTime *int64 `json:"create_time,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` +} + +// ListAPITokens list all API tokens for a tenant +func (s *SystemService) ListAPITokens(tenantID string) ([]*TokenResponse, error) { + APITokenDAO := dao.NewAPITokenDAO() + tokens, err := APITokenDAO.GetByTenantID(tenantID) + if err != nil { + return nil, err + } + + responses := make([]*TokenResponse, len(tokens)) + for i, token := range tokens { + responses[i] = &TokenResponse{ + TenantID: token.TenantID, + Token: token.Token, + DialogID: token.DialogID, + Source: token.Source, + Beta: token.Beta, + CreateTime: token.CreateTime, + UpdateTime: token.UpdateTime, + } + } + + return responses, nil +} + +// CreateAPITokenRequest create token request +type CreateAPITokenRequest struct { + Name string `json:"name" form:"name"` +} + +// CreateAPIToken creates a new API token for a tenant +func (s *SystemService) CreateAPIToken(tenantID string, req *CreateAPITokenRequest) (*TokenResponse, error) { + APITokenDAO := dao.NewAPITokenDAO() + + now := time.Now().Unix() + nowDate := time.Now() + + // Generate token and beta values + // token: "ragflow-" + secrets.token_urlsafe(32) + APIToken := utility.GenerateAPIToken() + // beta: generate_confirmation_token().replace("ragflow-", "")[:32] + betaAPIKey := utility.GenerateBetaAPIToken(APIToken) + + APITokenData := &model.APIToken{ + TenantID: tenantID, + Token: APIToken, + Beta: &betaAPIKey, + } + APITokenData.CreateDate = &nowDate + APITokenData.CreateTime = &now + + if err := APITokenDAO.Create(APITokenData); err != nil { + return nil, err + } + + return &TokenResponse{ + TenantID: APITokenData.TenantID, + Token: APITokenData.Token, + DialogID: APITokenData.DialogID, + Source: APITokenData.Source, + Beta: APITokenData.Beta, + CreateTime: APITokenData.CreateTime, + UpdateTime: APITokenData.UpdateTime, + }, nil +} + +// DeleteAPIToken deletes an API token by tenant ID and token value +func (s *SystemService) DeleteAPIToken(tenantID, token string) error { + APITokenDAO := dao.NewAPITokenDAO() + _, err := APITokenDAO.DeleteByTenantIDAndToken(tenantID, token) + return err +} diff --git a/internal/service/user.go b/internal/service/user.go index a0a263fa7..96544bee5 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -59,7 +59,7 @@ func NewUserService() *UserService { // RegisterRequest registration request type RegisterRequest struct { Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` + Password string `json:"password" binding:"required,min=1"` Nickname string `json:"nickname"` } @@ -122,7 +122,8 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password") } - hashedPassword, err := s.HashPassword(decryptedPassword) + var hashedPassword string + hashedPassword, err = s.HashPassword(decryptedPassword) if err != nil { return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err) } @@ -227,20 +228,20 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC userTenantDAO := dao.NewUserTenantDAO() fileDAO := dao.NewFileDAO() - if err := s.userDAO.Create(user); err != nil { + if err = s.userDAO.Create(user); err != nil { return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err) } - if err := tenantDAO.Create(tenant); err != nil { - err := s.userDAO.DeleteByID(userID) + if err = tenantDAO.Create(tenant); err != nil { + err = s.userDAO.DeleteByID(userID) if err != nil { return nil, 0, err } return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err) } - if err := userTenantDAO.Create(userTenant); err != nil { - err := s.userDAO.DeleteByID(userID) + if err = userTenantDAO.Create(userTenant); err != nil { + err = s.userDAO.DeleteByID(userID) if err != nil { return nil, 0, err } @@ -251,8 +252,8 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err) } - if err := fileDAO.Create(rootFile); err != nil { - err := s.userDAO.DeleteByID(userID) + if err = fileDAO.Create(rootFile); err != nil { + err = s.userDAO.DeleteByID(userID) if err != nil { return nil, 0, err } @@ -922,3 +923,44 @@ func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) er return nil } + +// GetUserByAPIToken gets user by access key from Authorization header +// This is used for API token authentication +// The authorization parameter should be in format: "Bearer " or just "" +func (s *UserService) GetUserByAPIToken(authorization string) (*model.User, common.ErrorCode, error) { + if authorization == "" { + return nil, common.CodeUnauthorized, fmt.Errorf("authorization header is empty") + } + + // Split authorization header to get the token + // Expected format: "Bearer " or "" + parts := strings.Split(authorization, " ") + var token string + if len(parts) == 2 { + token = parts[1] + } else if len(parts) == 1 { + token = parts[0] + } else { + return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization format") + } + + // Query API token from database + apiTokenDAO := dao.NewAPITokenDAO() + userToken, err := apiTokenDAO.GetUserByAPIToken(token) + if err != nil { + return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token") + } + + // Get user by tenant_id from API token + user, err := s.userDAO.GetByTenantID(userToken.TenantID) + if err != nil { + return nil, common.CodeUnauthorized, fmt.Errorf("user not found for this access token") + } + + // Check if user's access_token is empty + if user.AccessToken == nil || *user.AccessToken == "" { + return nil, common.CodeUnauthorized, fmt.Errorf("user has empty access_token in database") + } + + return user, common.CodeSuccess, nil +} diff --git a/internal/utility/token.go b/internal/utility/token.go index 3c7b97fc7..af5c0deb5 100644 --- a/internal/utility/token.go +++ b/internal/utility/token.go @@ -138,3 +138,29 @@ func GenerateSecretKey() (string, error) { func GenerateToken() string { return strings.ReplaceAll(uuid.New().String(), "-", "") } + +// GenerateAPIToken generates a secure random access key +// Equivalent to Python's generate_confirmation_token(): +// return "ragflow-" + secrets.token_urlsafe(32) +func GenerateAPIToken() string { + // Generate 32 random bytes + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + // Fallback to UUID if random generation fails + return "ragflow-" + strings.ReplaceAll(uuid.New().String(), "-", "") + } + // Use URL-safe base64 encoding (same as Python's token_urlsafe) + return "ragflow-" + base64.RawURLEncoding.EncodeToString(bytes) +} + +// GenerateBetaAPIToken generates a beta access key +// Equivalent to Python's: generate_confirmation_token().replace("ragflow-", "")[:32] +func GenerateBetaAPIToken(accessKey string) string { + // Remove "ragflow-" prefix + withoutPrefix := strings.TrimPrefix(accessKey, "ragflow-") + // Take first 32 characters + if len(withoutPrefix) > 32 { + return withoutPrefix[:32] + } + return withoutPrefix +} diff --git a/web/vite.config.ts b/web/vite.config.ts index babafe497..741806ef9 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -49,6 +49,21 @@ export default defineConfig(({ mode }) => { }, }, hybrid: { + '/v1/system/token_list': { + target: 'http://127.0.0.1:9384/', + changeOrigin: true, + ws: true, + }, + '/v1/system/new_token': { + target: 'http://127.0.0.1:9384/', + changeOrigin: true, + ws: true, + }, + '/v1/system/token': { + target: 'http://127.0.0.1:9384/', + changeOrigin: true, + ws: true, + }, '/v1/system/config': { target: 'http://127.0.0.1:9384/', changeOrigin: true,