From 6e309f9d0a64f9fa44e3421e5d96baed836f2dfd Mon Sep 17 00:00:00 2001 From: Yingfeng Date: Thu, 26 Mar 2026 21:07:06 +0800 Subject: [PATCH] Feat: Initialize context engine CLI (#13776) ### What problem does this PR solve? - Add multiple output format to ragflow_cli - Initialize contextengine to Go module - ls datasets/ls files - cat file - search -d dir -q query issue: #13714 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- internal/cli/README.md | 109 ++- internal/cli/cli.go | 778 +++++++++++++++-- internal/cli/client.go | 296 ++++++- internal/cli/contextengine/README.md | 49 ++ .../cli/contextengine/dataset_provider.go | 781 ++++++++++++++++++ internal/cli/contextengine/engine.go | 312 +++++++ internal/cli/contextengine/file_provider.go | 594 +++++++++++++ internal/cli/contextengine/provider.go | 180 ++++ internal/cli/contextengine/types.go | 116 +++ internal/cli/contextengine/utils.go | 304 +++++++ internal/cli/parser.go | 104 +++ internal/cli/table.go | 270 ++++-- internal/cli/user_command.go | 10 +- internal/cli/user_parser.go | 5 +- 14 files changed, 3753 insertions(+), 155 deletions(-) create mode 100644 internal/cli/contextengine/README.md create mode 100644 internal/cli/contextengine/dataset_provider.go create mode 100644 internal/cli/contextengine/engine.go create mode 100644 internal/cli/contextengine/file_provider.go create mode 100644 internal/cli/contextengine/provider.go create mode 100644 internal/cli/contextengine/types.go create mode 100644 internal/cli/contextengine/utils.go diff --git a/internal/cli/README.md b/internal/cli/README.md index 4f71a37de..c626b57f0 100644 --- a/internal/cli/README.md +++ b/internal/cli/README.md @@ -4,20 +4,21 @@ This is the Go implementation of the RAGFlow command-line interface, compatible ## Features -- Interactive mode only +- Interactive mode and single command execution - Full compatibility with Python CLI syntax - Recursive descent parser for SQL-like commands +- Context 
Engine (Virtual Filesystem) for intuitive resource management - Support for all major commands: - User management: LOGIN, REGISTER, CREATE USER, DROP USER, LIST USERS, etc. - Service management: LIST SERVICES, SHOW SERVICE, STARTUP/SHUTDOWN/RESTART SERVICE - Role management: CREATE ROLE, DROP ROLE, LIST ROLES, GRANT/REVOKE PERMISSION - - Dataset management: CREATE DATASET, DROP DATASET, LIST DATASETS + - Dataset management via Context Engine: `ls`, `search`, `mkdir`, `cat`, `rm` - Model management: SET/RESET DEFAULT LLM/VLM/EMBEDDING/etc. - And more... ## Usage -Build and run: +### Build and run ```bash go build -o ragflow_cli ./cmd/ragflow_cli.go @@ -28,11 +29,94 @@ go build -o ragflow_cli ./cmd/ragflow_cli.go ``` internal/cli/ -├── cli.go # Main CLI loop and interaction -├── parser/ # Command parser package -│ ├── types.go # Token and Command types -│ ├── lexer.go # Lexical analyzer -│ └── parser.go # Recursive descent parser +├── cli.go # Main CLI loop and interaction +├── client.go # RAGFlowClient with Context Engine integration +├── http_client.go # HTTP client for API communication +├── parser/ # Command parser package +│ ├── types.go # Token and Command types +│ ├── lexer.go # Lexical analyzer +│ └── parser.go # Recursive descent parser +└── contextengine/ # Context Engine (Virtual Filesystem) + ├── engine.go # Core engine: path resolution, command routing + ├── types.go # Node, Command, Result types + ├── provider.go # Provider interface definition + ├── dataset_provider.go # Dataset provider implementation + ├── file_provider.go # File manager provider implementation + └── utils.go # Helper functions +``` + +## Context Engine + +The Context Engine provides a unified virtual filesystem interface over RAGFlow's RESTful APIs. + +### Design Principles + +1. **No Server-Side Changes**: All logic implemented client-side using existing APIs +2. **Provider Pattern**: Modular providers for different resource types (datasets, files, etc.) +3. 
**Unified Interface**: Common `ls`, `search`, `mkdir` commands across all providers +4. **Path-Based Navigation**: Virtual paths like `/datasets`, `/datasets/{name}/files` + +### Supported Paths + +| Path | Description | +|------|-------------| +| `/datasets` | List all datasets | +| `/datasets/{name}` | List documents in dataset (default behavior) | +| `/datasets/{name}/{doc}` | Get document info | + +### Commands + +#### `ls [path] [options]` - List nodes at path + +List contents of a path in the context filesystem. + +**Arguments:** +- `[path]` - Path to list (default: "datasets") + +**Options:** +- `-n, --limit ` - Maximum number of items to display (default: 10) +- `-h, --help` - Show ls help message + +**Examples:** +```bash +ls # List all datasets (default 10) +ls -n 20 # List 20 datasets +ls datasets/kb1 # List files in kb1 dataset +ls datasets/kb1 -n 50 # List 50 files in kb1 dataset +``` + +#### `search [options]` - Search for content + +Semantic search in datasets. + +**Options:** +- `-d, --dir ` - Directory to search in (can be specified multiple times) +- `-q, --query ` - Search query (required) +- `-k, --top-k ` - Number of top results to return (default: 10) +- `-t, --threshold ` - Similarity threshold, 0.0-1.0 (default: 0.2) +- `-h, --help` - Show search help message + +**Output Formats:** +- Default: JSON format +- `--output plain` - Plain text format +- `--output table` - Table format with borders + +**Examples:** +```bash +search -q "machine learning" # Search all datasets (JSON output) +search -d datasets/kb1 -q "neural networks" # Search in kb1 +search -d datasets/kb1 -q "AI" --output plain # Plain text output +search -q "RAG" -k 20 -t 0.5 # Return 20 results with threshold 0.5 +``` + +#### `cat ` - Display content + +Display document content (if available). 
+ +**Examples:** +```bash +cat myskills/doc.md # Show content of doc.md file +cat datasets/kb1/document.pdf # Error: cannot display binary file content ``` ## Command Examples @@ -71,6 +155,15 @@ SET DEFAULT LLM 'gpt-4'; SET DEFAULT EMBEDDING 'text-embedding-ada-002'; RESET DEFAULT LLM; +-- Context Engine (Virtual Filesystem) +ls; -- List all datasets (default 10) +ls -n 20; -- List 20 datasets +ls datasets/my_dataset; -- List documents in dataset +ls datasets/my_dataset -n 50; -- List 50 documents +ls datasets/my_dataset/info; -- Show dataset info +search -q "test"; -- Search all datasets (JSON output) +search -d datasets/my_dataset -q "test"; -- Search in specific dataset + -- Meta commands \? -- Show help \q -- Quit diff --git a/internal/cli/cli.go b/internal/cli/cli.go index b0d7a3848..e87e70f97 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -17,6 +17,8 @@ package cli import ( + "context" + "encoding/json" "errors" "fmt" "os" @@ -24,10 +26,13 @@ import ( "strconv" "strings" "syscall" + "unicode/utf8" "github.com/peterh/liner" "golang.org/x/term" "gopkg.in/yaml.v3" + + "ragflow/internal/cli/contextengine" ) // ConfigFile represents the rf.yml configuration file structure @@ -38,16 +43,28 @@ type ConfigFile struct { Password string `yaml:"password"` } +// OutputFormat represents the output format type +type OutputFormat string + +const ( + OutputFormatTable OutputFormat = "table" // Table format with borders + OutputFormatPlain OutputFormat = "plain" // Plain text, space-separated (no borders) + OutputFormatJSON OutputFormat = "json" // JSON format (reserved for future use) +) + // ConnectionArgs holds the parsed command line arguments type ConnectionArgs struct { - Host string - Port int - Password string - APIToken string - UserName string - Command string - ShowHelp bool - AdminMode bool + Host string + Port int + Password string + APIToken string + UserName string + Command string // Original command string (for SQL mode) + CommandArgs []string 
// Split command arguments (for ContextEngine mode) + IsSQLMode bool // true=SQL mode (quoted), false=ContextEngine mode (unquoted) + ShowHelp bool + AdminMode bool + OutputFormat OutputFormat // Output format: table, plain, json } // LoadDefaultConfigFile reads the rf.yml file from current directory if it exists @@ -111,13 +128,25 @@ func ParseConnectionArgs(args []string) (*ConnectionArgs, error) { // First, scan args to check for help, config file, and admin mode var configFilePath string var adminMode bool = false + foundCommand := false for i := 0; i < len(args); i++ { arg := args[i] - if arg == "--help" || arg == "-help" { + // If we found a command (non-flag arg), stop processing global flags + // This allows subcommands like "search --help" to handle their own help + if !strings.HasPrefix(arg, "-") { + foundCommand = true + continue + } + // Only process --help as global help if it's before any command + if !foundCommand && (arg == "--help" || arg == "-help") { return &ConnectionArgs{ShowHelp: true}, nil } else if (arg == "-f" || arg == "--config") && i+1 < len(args) { configFilePath = args[i+1] i++ + } else if (arg == "-o" || arg == "--output") && i+1 < len(args) { + // -o/--output is allowed with config file, skip it and its value + i++ + continue } else if arg == "--admin" { adminMode = true } @@ -130,7 +159,6 @@ func ParseConnectionArgs(args []string) (*ConnectionArgs, error) { // Parse arguments manually to support both short and long forms // and to handle priority: command line > config file > defaults - // Build result from config file first (if exists), then override with command line flags result := &ConnectionArgs{} if !adminMode { @@ -171,8 +199,18 @@ func ParseConnectionArgs(args []string) (*ConnectionArgs, error) { // Override with command line flags (higher priority) // Handle both short and long forms manually + // Once we encounter a non-flag argument (command), stop parsing global flags + // Remaining args belong to the subcommand + 
foundCommand = false for i := 0; i < len(args); i++ { arg := args[i] + + // If we've found the command, collect remaining args as subcommand args + if foundCommand { + nonFlagArgs = append(nonFlagArgs, arg) + continue + } + switch arg { case "-h", "--host": if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { @@ -205,6 +243,20 @@ func ParseConnectionArgs(args []string) (*ConnectionArgs, error) { if i+1 < len(args) { i++ } + case "-o", "--output": + // Parse output format + if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + format := args[i+1] + switch format { + case "plain": + result.OutputFormat = OutputFormatPlain + case "json": + result.OutputFormat = OutputFormatJSON + default: + result.OutputFormat = OutputFormatTable + } + i++ + } case "--admin", "-admin": result.AdminMode = true case "--help", "-help": @@ -214,6 +266,7 @@ func ParseConnectionArgs(args []string) (*ConnectionArgs, error) { // Non-flag argument (command) if !strings.HasPrefix(arg, "-") { nonFlagArgs = append(nonFlagArgs, arg) + foundCommand = true } } } @@ -253,12 +306,42 @@ func ParseConnectionArgs(args []string) (*ConnectionArgs, error) { // Get command from remaining args (non-flag arguments) if len(nonFlagArgs) > 0 { - result.Command = strings.Join(nonFlagArgs, " ") + // Check if this is SQL mode or ContextEngine mode + // SQL mode: single argument that looks like SQL (e.g., "LIST DATASETS") + // ContextEngine mode: multiple arguments (e.g., "ls", "datasets") + if len(nonFlagArgs) == 1 && looksLikeSQL(nonFlagArgs[0]) { + // SQL mode: single argument that looks like SQL + result.IsSQLMode = true + result.Command = nonFlagArgs[0] + } else { + // ContextEngine mode: multiple arguments + result.IsSQLMode = false + result.CommandArgs = nonFlagArgs + // Also store joined version for backward compatibility + result.Command = strings.Join(nonFlagArgs, " ") + } } return result, nil } +// looksLikeSQL checks if a string looks like a SQL command +func looksLikeSQL(s string) bool { + s 
= strings.ToUpper(strings.TrimSpace(s)) + sqlPrefixes := []string{ + "LIST ", "SHOW ", "CREATE ", "DROP ", "ALTER ", + "LOGIN ", "REGISTER ", "PING", "GRANT ", "REVOKE ", + "SET ", "UNSET ", "UPDATE ", "DELETE ", "INSERT ", + "SELECT ", "DESCRIBE ", "EXPLAIN ", + } + for _, prefix := range sqlPrefixes { + if strings.HasPrefix(s, prefix) { + return true + } + } + return false +} + // PrintUsage prints the CLI usage information func PrintUsage() { fmt.Println(`RAGFlow CLI Client @@ -271,6 +354,7 @@ Options: -u, --user string Username for authentication -p, --password string Password for authentication -f, --config string Path to config file (YAML format) + -o, --output string Output format: table, plain, json (search defaults to json) --admin, -admin Run in admin mode --help Show this help message @@ -298,7 +382,8 @@ Configuration File: Note: api_token and user_name/password are mutually exclusive in config file. Commands: - SQL commands like: LOGIN USER 'email'; LIST USERS; etc. + SQL commands (use quotes): "LIST USERS", "CREATE USER 'email' 'password'", etc. + Context Engine commands (no quotes): ls datasets, search "keyword", cat path, etc. If no command is provided, CLI runs in interactive mode. 
`) } @@ -312,11 +397,13 @@ const historyFileName = ".ragflow_cli_history" // CLI represents the command line interface type CLI struct { - client *RAGFlowClient - prompt string - running bool - line *liner.State - args *ConnectionArgs + client *RAGFlowClient + contextEngine *contextengine.Engine + prompt string + running bool + line *liner.State + args *ConnectionArgs + outputFormat OutputFormat // Output format } // NewCLI creates a new CLI instance @@ -352,17 +439,41 @@ func NewCLIWithArgs(args *ConnectionArgs) (*CLI, error) { } } + // Apply API token if provided (from config file) + if args.APIToken != "" { + client.HTTPClient.APIToken = args.APIToken + client.HTTPClient.useAPIToken = true + } + + // Set output format + client.OutputFormat = args.OutputFormat + + // Auto-login if user and password are provided (from config file) + if args.UserName != "" && args.Password != "" && args.APIToken == "" { + if err := client.LoginUserInteractive(args.UserName, args.Password); err != nil { + line.Close() + return nil, fmt.Errorf("auto-login failed: %w", err) + } + } + // Set prompt based on server type prompt := "RAGFlow(user)> " if serverType == "admin" { prompt = "RAGFlow(admin)> " } + // Create context engine and register providers + engine := contextengine.NewEngine() + engine.RegisterProvider(contextengine.NewDatasetProvider(&httpClientAdapter{client: client.HTTPClient})) + engine.RegisterProvider(contextengine.NewFileProvider(&httpClientAdapter{client: client.HTTPClient})) + return &CLI{ - prompt: prompt, - client: client, - line: line, - args: args, + prompt: prompt, + client: client, + contextEngine: engine, + line: line, + args: args, + outputFormat: args.OutputFormat, }, nil } @@ -464,28 +575,360 @@ func (c *CLI) Run() error { } func (c *CLI) execute(input string) error { - p := NewParser(input) - cmd, err := p.Parse(c.args.AdminMode) + // Determine execution mode based on input and args + input = strings.TrimSpace(input) + + // Handle meta commands (start 
with \) + if strings.HasPrefix(input, "\\") { + p := NewParser(input) + cmd, err := p.Parse(c.args.AdminMode) + if err != nil { + return err + } + if cmd != nil && cmd.Type == "meta" { + return c.handleMetaCommand(cmd) + } + } + + // Check if we should use SQL mode or ContextEngine mode + isSQLMode := false + if c.args != nil && len(c.args.CommandArgs) > 0 { + // Non-interactive mode: use pre-determined mode from args + isSQLMode = c.args.IsSQLMode + } else { + // Interactive mode: determine based on input + isSQLMode = looksLikeSQL(input) + } + + if isSQLMode { + // SQL mode: use parser + p := NewParser(input) + cmd, err := p.Parse(c.args.AdminMode) + if err != nil { + return err + } + if cmd == nil { + return nil + } + // Execute SQL command using the client + var result ResponseIf + result, err = c.client.ExecuteCommand(cmd) + if result != nil { + result.SetOutputFormat(c.outputFormat) + result.PrintOut() + } + return err + } + + // ContextEngine mode: execute context engine command + return c.executeContextEngine(input) +} + +// executeContextEngine executes a Context Engine command +func (c *CLI) executeContextEngine(input string) error { + // Parse input into arguments + var args []string + if c.args != nil && len(c.args.CommandArgs) > 0 { + // Non-interactive mode: use pre-parsed args + args = c.args.CommandArgs + } else { + // Interactive mode: parse input + args = parseContextEngineArgs(input) + } + + if len(args) == 0 { + return fmt.Errorf("no command provided") + } + + // Check if we have a context engine + if c.contextEngine == nil { + return fmt.Errorf("context engine not available") + } + + cmdType := args[0] + cmdArgs := args[1:] + + // Build context engine command + var ceCmd *contextengine.Command + + switch cmdType { + case "ls", "list": + // Parse list command arguments + listOpts, err := parseListCommandArgs(cmdArgs) + if err != nil { + return err + } + if listOpts == nil { + // Help was printed + return nil + } + ceCmd = &contextengine.Command{ 
+ Type: contextengine.CommandList, + Path: listOpts.Path, + Params: map[string]interface{}{ + "limit": listOpts.Limit, + }, + } + case "search": + // Parse search command arguments + searchOpts, err := parseSearchCommandArgs(cmdArgs) + if err != nil { + return err + } + if searchOpts == nil { + // Help was printed + return nil + } + // Determine the path for provider resolution + // Use first dir if specified, otherwise default to "datasets" + searchPath := "datasets" + if len(searchOpts.Dirs) > 0 { + searchPath = searchOpts.Dirs[0] + } + ceCmd = &contextengine.Command{ + Type: contextengine.CommandSearch, + Path: searchPath, + Params: map[string]interface{}{ + "query": searchOpts.Query, + "top_k": searchOpts.TopK, + "threshold": searchOpts.Threshold, + "dirs": searchOpts.Dirs, + }, + } + case "cat": + if len(cmdArgs) == 0 { + return fmt.Errorf("cat requires a path argument") + } + // Handle cat command directly since it returns []byte, not *Result + content, err := c.contextEngine.Cat(context.Background(), cmdArgs[0]) + if err != nil { + return err + } + if content == nil || len(content) == 0 { + fmt.Println("(empty file)") + } else if isBinaryContent(content) { + return fmt.Errorf("cannot display binary file content") + } else { + fmt.Println(string(content)) + } + return nil + default: + return fmt.Errorf("unknown context engine command: %s", cmdType) + } + + // Execute the command + result, err := c.contextEngine.Execute(context.Background(), ceCmd) if err != nil { return err } - if cmd == nil { - return nil + // Print result + // For search command, default to JSON format if not explicitly set to plain/table + format := c.outputFormat + if ceCmd.Type == contextengine.CommandSearch && format != OutputFormatPlain && format != OutputFormatTable { + format = OutputFormatJSON + } + // Get limit for list command + limit := 0 + if ceCmd.Type == contextengine.CommandList { + if l, ok := ceCmd.Params["limit"].(int); ok { + limit = l + } + } + 
c.printContextEngineResult(result, ceCmd.Type, format, limit) + return nil +} + +// parseContextEngineArgs parses Context Engine command arguments +// Supports simple space-separated args and quoted strings +func parseContextEngineArgs(input string) []string { + var args []string + var current strings.Builder + inQuote := false + var quoteChar rune + + for _, ch := range input { + switch ch { + case '"', '\'': + if !inQuote { + inQuote = true + quoteChar = ch + if current.Len() > 0 { + args = append(args, current.String()) + current.Reset() + } + } else if ch == quoteChar { + inQuote = false + args = append(args, current.String()) + current.Reset() + } else { + current.WriteRune(ch) + } + case ' ', '\t': + if inQuote { + current.WriteRune(ch) + } else if current.Len() > 0 { + args = append(args, current.String()) + current.Reset() + } + default: + current.WriteRune(ch) + } } - // Handle meta commands - if cmd.Type == "meta" { - return c.handleMetaCommand(cmd) + if current.Len() > 0 { + args = append(args, current.String()) } - // Execute the command using the client - var result ResponseIf - result, err = c.client.ExecuteCommand(cmd) - if result != nil { - result.PrintOut() + return args +} + +// printContextEngineResult prints the result of a context engine command +func (c *CLI) printContextEngineResult(result *contextengine.Result, cmdType contextengine.CommandType, format OutputFormat, limit int) { + if result == nil { + return + } + + switch cmdType { + case contextengine.CommandList: + if len(result.Nodes) == 0 { + fmt.Println("(empty)") + return + } + displayCount := len(result.Nodes) + if limit > 0 && displayCount > limit { + displayCount = limit + } + if format == OutputFormatPlain { + // Plain format: simple space-separated, no headers + for i := 0; i < displayCount; i++ { + node := result.Nodes[i] + fmt.Printf("%s %s %s %s\n", node.Name, node.Type, node.Path, node.CreatedAt.Format("2006-01-02 15:04")) + } + } else { + // Table format: with headers and 
aligned columns + fmt.Printf("%-30s %-12s %-50s %-20s\n", "NAME", "TYPE", "PATH", "CREATED") + fmt.Println(strings.Repeat("-", 112)) + for i := 0; i < displayCount; i++ { + node := result.Nodes[i] + created := node.CreatedAt.Format("2006-01-02 15:04") + if node.CreatedAt.IsZero() { + created = "-" + } + // Remove leading "/" from path for display + displayPath := node.Path + if strings.HasPrefix(displayPath, "/") { + displayPath = displayPath[1:] + } + fmt.Printf("%-30s %-12s %-50s %-20s\n", node.Name, node.Type, displayPath, created) + } + } + if limit > 0 && result.Total > limit { + fmt.Printf("\n... and %d more (use -n to show more)\n", result.Total-limit) + } + fmt.Printf("Total: %d\n", result.Total) + case contextengine.CommandSearch: + if len(result.Nodes) == 0 { + if format == OutputFormatJSON { + fmt.Println("[]") + } else { + fmt.Println("No results found") + } + return + } + // Build data for output (same fields for all formats: content, path, score) + type searchResult struct { + Content string `json:"content"` + Path string `json:"path"` + Score float64 `json:"score,omitempty"` + } + results := make([]searchResult, 0, len(result.Nodes)) + for _, node := range result.Nodes { + content := node.Name + if content == "" { + content = "(empty)" + } + displayPath := node.Path + if strings.HasPrefix(displayPath, "/") { + displayPath = displayPath[1:] + } + var score float64 + if s, ok := node.Metadata["similarity"].(float64); ok { + score = s + } else if s, ok := node.Metadata["_score"].(float64); ok { + score = s + } + results = append(results, searchResult{ + Content: content, + Path: displayPath, + Score: score, + }) + } + // Output based on format + if format == OutputFormatJSON { + jsonData, err := json.MarshalIndent(results, "", " ") + if err != nil { + fmt.Printf("Error marshaling JSON: %v\n", err) + return + } + fmt.Println(string(jsonData)) + } else if format == OutputFormatPlain { + // Plain format: simple space-separated, no borders + 
fmt.Printf("%-70s %-50s %-10s\n", "CONTENT", "PATH", "SCORE") + for i, sr := range results { + content := strings.Join(strings.Fields(sr.Content), " ") + if len(content) > 70 { + content = content[:67] + "..." + } + displayPath := sr.Path + if len(displayPath) > 50 { + displayPath = displayPath[:47] + "..." + } + scoreStr := "-" + if sr.Score > 0 { + scoreStr = fmt.Sprintf("%.4f", sr.Score) + } + fmt.Printf("%-70s %-50s %-10s\n", content, displayPath, scoreStr) + if i >= 99 { + fmt.Printf("\n... and %d more results\n", result.Total-i-1) + break + } + } + fmt.Printf("\nTotal: %d\n", result.Total) + } else { + // Table format: with borders + col1Width, col2Width, col3Width := 70, 50, 10 + sep := "+" + strings.Repeat("-", col1Width+2) + "+" + strings.Repeat("-", col2Width+2) + "+" + strings.Repeat("-", col3Width+2) + "+" + fmt.Println(sep) + fmt.Printf("| %-70s | %-50s | %-10s |\n", "CONTENT", "PATH", "SCORE") + fmt.Println(sep) + for i, sr := range results { + content := strings.Join(strings.Fields(sr.Content), " ") + if len(content) > 70 { + content = content[:67] + "..." + } + displayPath := sr.Path + if len(displayPath) > 50 { + displayPath = displayPath[:47] + "..." + } + scoreStr := "-" + if sr.Score > 0 { + scoreStr = fmt.Sprintf("%.4f", sr.Score) + } + fmt.Printf("| %-70s | %-50s | %-10s |\n", content, displayPath, scoreStr) + if i >= 99 { + fmt.Printf("\n... 
and %d more results\n", result.Total-i-1) + break + } + } + fmt.Println(sep) + fmt.Printf("Total: %d\n", result.Total) + } +case contextengine.CommandCat: + // Cat output is handled differently - it returns []byte, not *Result + // This case should not be reached in normal flow since Cat returns []byte directly + fmt.Println("Content retrieved") } - return err } func (c *CLI) handleMetaCommand(cmd *Command) error { @@ -574,24 +1017,25 @@ Commands (User Mode): CREATE INDEX DOC_META; - Create doc meta index DROP INDEX DOC_META; - Drop doc meta index -Commands (Admin Mode): - LIST USERS; - List all users - SHOW USER 'email'; - Show user details - CREATE USER 'email' 'password'; - Create new user - DROP USER 'email'; - Delete user - ALTER USER PASSWORD 'email' 'new_password'; - Change user password - ALTER USER ACTIVE 'email' on/off; - Activate/deactivate user - GRANT ADMIN 'email'; - Grant admin role - REVOKE ADMIN 'email'; - Revoke admin role - LIST SERVICES; - List services - SHOW SERVICE ; - Show service details - PING; - Ping server - ... and many more +Context Engine Commands (no quotes): + ls [path] - List resources + e.g., ls - List root (providers and folders) + e.g., ls datasets - List all datasets + e.g., ls datasets/kb1 - Show dataset info + e.g., ls myfolder - List files in 'myfolder' (file_manager) + list [path] - Same as ls + search [options] - Search resources in datasets + Use 'search -h' for detailed options + cat - Show file content + e.g., cat files/docs/file.txt - Show file content + Note: cat datasets or cat datasets/kb1 will error -Meta Commands: - \? 
or \h - Show this help - \q or \quit - Exit CLI - \c or \clear - Clear screen +Examples: + ragflow_cli -f rf.yml "LIST USERS" # SQL mode (with quotes) + ragflow_cli -f rf.yml ls datasets # Context Engine mode (no quotes) + ragflow_cli -f rf.yml ls files # List files in root + ragflow_cli -f rf.yml cat datasets # Error: datasets is a directory + ragflow_cli -f rf.yml ls files/myfolder # List folder contents For more information, see documentation. ` @@ -600,7 +1044,10 @@ For more information, see documentation. // Cleanup performs cleanup before exit func (c *CLI) Cleanup() { - fmt.Println("\nCleaning up...") + // Close liner to restore terminal settings + if c.line != nil { + c.line.Close() + } } // RunInteractive runs the CLI in interactive mode @@ -624,6 +1071,9 @@ func RunInteractive() error { // RunSingleCommand executes a single command and exits func (c *CLI) RunSingleCommand(command string) error { + // Ensure cleanup is called on exit to restore terminal settings + defer c.Cleanup() + // Execute the command if err := c.execute(command); err != nil { return err @@ -659,3 +1109,227 @@ func (c *CLI) VerifyAuth() error { _, err := c.client.ExecuteCommand(cmd) return err } + +// isBinaryContent checks if content is binary (contains null bytes or invalid UTF-8) +func isBinaryContent(content []byte) bool { + // Check for null bytes (binary file indicator) + for _, b := range content { + if b == 0 { + return true + } + } + // Check valid UTF-8 + return !utf8.Valid(content) +} + +// SearchCommandOptions holds parsed search command options +type SearchCommandOptions struct { + Query string + TopK int + Threshold float64 + Dirs []string +} + +// ListCommandOptions holds parsed list command options +type ListCommandOptions struct { + Path string + Limit int +} + +// parseSearchCommandArgs parses search command arguments +// Format: search [-d dir1] [-d dir2] ... 
-q query [-k top_k] [-t threshold] +// search -h|--help (shows help) +func parseSearchCommandArgs(args []string) (*SearchCommandOptions, error) { + opts := &SearchCommandOptions{ + TopK: 10, + Threshold: 0.2, + Dirs: []string{}, + } + + // Check for help flag + for _, arg := range args { + if arg == "-h" || arg == "--help" { + printSearchHelp() + return nil, nil + } + } + + // Parse arguments + i := 0 + for i < len(args) { + arg := args[i] + + switch arg { + case "-d", "--dir": + if i+1 >= len(args) { + return nil, fmt.Errorf("missing value for %s flag", arg) + } + opts.Dirs = append(opts.Dirs, args[i+1]) + i += 2 + case "-q", "--query": + if i+1 >= len(args) { + return nil, fmt.Errorf("missing value for %s flag", arg) + } + opts.Query = args[i+1] + i += 2 + case "-k", "--top-k": + if i+1 >= len(args) { + return nil, fmt.Errorf("missing value for %s flag", arg) + } + topK, err := strconv.Atoi(args[i+1]) + if err != nil { + return nil, fmt.Errorf("invalid top-k value: %s", args[i+1]) + } + opts.TopK = topK + i += 2 + case "-t", "--threshold": + if i+1 >= len(args) { + return nil, fmt.Errorf("missing value for %s flag", arg) + } + threshold, err := strconv.ParseFloat(args[i+1], 64) + if err != nil { + return nil, fmt.Errorf("invalid threshold value: %s", args[i+1]) + } + opts.Threshold = threshold + i += 2 + default: + // If it doesn't start with -, it might be a positional argument + if !strings.HasPrefix(arg, "-") { + // For backwards compatibility: if no -q flag and this is the last arg, treat as query + if opts.Query == "" && i == len(args)-1 { + opts.Query = arg + } else if opts.Query == "" && len(args) > 0 && i < len(args)-1 { + // Old format: search [path] query + // Treat first non-flag as path, rest as query + opts.Dirs = append(opts.Dirs, arg) + // Join remaining args as query + remainingArgs := args[i+1:] + queryParts := []string{} + for _, part := range remainingArgs { + if !strings.HasPrefix(part, "-") { + queryParts = append(queryParts, part) + } + } + 
opts.Query = strings.Join(queryParts, " ") + break + } + } else { + return nil, fmt.Errorf("unknown flag: %s", arg) + } + i++ + } + } + + // Validate required parameters + if opts.Query == "" { + return nil, fmt.Errorf("query is required (use -q or --query)") + } + + // If no directories specified, search in all datasets (empty path means all) + if len(opts.Dirs) == 0 { + opts.Dirs = []string{"datasets"} + } + + return opts, nil +} + +// printSearchHelp prints help for the search command +func printSearchHelp() { + help := `Search command usage: search [options] + +Search for content in datasets. Currently only supports searching in datasets. + +Options: + -d, --dir Directory to search in (can be specified multiple times) + Currently only supports paths under 'datasets/' + Example: -d datasets/kb1 -d datasets/kb2 + -q, --query Search query (required) + Example: -q "machine learning" + -k, --top-k Number of top results to return (default: 10) + Example: -k 20 + -t, --threshold Similarity threshold, 0.0-1.0 (default: 0.2) + Example: -t 0.5 + -h, --help Show this help message + +Output: + Default output format is JSON. Use --output plain or --output table for other formats. + +Examples: + search -d datasets/kb1 -q "neural networks" # Search in kb1 (JSON output) + search -d datasets/kb1 -q "AI" --output plain # Search with plain text output + search -q "data mining" # Search all datasets + search -q "RAG" -k 20 -t 0.5 # Return 20 results with threshold 0.5 +` + fmt.Println(help) +} + +// printListHelp prints help for the list/ls command +func printListHelp() { + help := `List command usage: ls [path] [options] + +List contents of a path in the context filesystem. 
+ +Arguments: + [path] Path to list (default: root - shows all providers and folders) + Examples: datasets, datasets/kb1, myfolder + +Options: + -n, --limit Maximum number of items to display (default: 10) + Example: -n 20 + -h, --help Show this help message + +Examples: + ls # List root (all providers and file_manager folders) + ls datasets # List all datasets + ls datasets/kb1 # List files in kb1 dataset (default 10 items) + ls myfolder # List files in file_manager folder 'myfolder' + ls -n 5 # List 5 items at root +` + fmt.Println(help) +} + +// parseListCommandArgs parses list/ls command arguments +// Format: ls [path] [-n limit] [-h|--help] +func parseListCommandArgs(args []string) (*ListCommandOptions, error) { + opts := &ListCommandOptions{ + Path: "", // Empty path means list root (all providers and file_manager folders) + Limit: 10, + } + + // Check for help flag + for _, arg := range args { + if arg == "-h" || arg == "--help" { + printListHelp() + return nil, nil + } + } + + // Parse arguments + i := 0 + for i < len(args) { + arg := args[i] + + switch arg { + case "-n", "--limit": + if i+1 >= len(args) { + return nil, fmt.Errorf("missing value for %s flag", arg) + } + limit, err := strconv.Atoi(args[i+1]) + if err != nil { + return nil, fmt.Errorf("invalid limit value: %s", args[i+1]) + } + opts.Limit = limit + i += 2 + default: + // If it doesn't start with -, treat as path + if !strings.HasPrefix(arg, "-") { + opts.Path = arg + } else { + return nil, fmt.Errorf("unknown flag: %s", arg) + } + i++ + } + } + + return opts, nil +} diff --git a/internal/cli/client.go b/internal/cli/client.go index 0445adc36..11145211f 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -18,13 +18,17 @@ package cli import ( "bufio" + "context" "encoding/json" "fmt" "os" "os/exec" "strings" "syscall" + "time" "unsafe" + + ce "ragflow/internal/cli/contextengine" ) // PasswordPromptFunc is a function type for password input @@ -35,6 +39,8 @@ type RAGFlowClient 
struct { HTTPClient *HTTPClient ServerType string // "admin" or "user" PasswordPrompt PasswordPromptFunc // Function for password input + OutputFormat OutputFormat // Output format: table, plain, json + ContextEngine *ce.Engine // Context Engine for virtual filesystem } // NewRAGFlowClient creates a new RAGFlow client @@ -47,10 +53,54 @@ func NewRAGFlowClient(serverType string) *RAGFlowClient { httpClient.Port = 9380 } - return &RAGFlowClient{ + client := &RAGFlowClient{ HTTPClient: httpClient, ServerType: serverType, } + + // Initialize Context Engine + client.initContextEngine() + + return client +} + +// initContextEngine initializes the Context Engine with all providers +func (c *RAGFlowClient) initContextEngine() { + engine := ce.NewEngine() + + // Register providers + engine.RegisterProvider(ce.NewDatasetProvider(&httpClientAdapter{c.HTTPClient})) + + c.ContextEngine = engine +} + +// httpClientAdapter adapts HTTPClient to ce.HTTPClientInterface +type httpClientAdapter struct { + client *HTTPClient +} + +func (a *httpClientAdapter) Request(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}) (*ce.HTTPResponse, error) { + // Auto-detect auth kind based on available tokens + // If authKind is "auto" or empty, determine based on token availability + if authKind == "auto" || authKind == "" { + if a.client.useAPIToken && a.client.APIToken != "" { + authKind = "api" + } else if a.client.LoginToken != "" { + authKind = "web" + } else { + authKind = "web" // default + } + } + resp, err := a.client.Request(method, path, useAPIBase, authKind, headers, jsonBody) + if err != nil { + return nil, err + } + return &ce.HTTPResponse{ + StatusCode: resp.StatusCode, + Body: resp.Body, + Headers: resp.Headers, + Duration: resp.Duration, + }, nil } // LoginUserInteractive performs interactive login with username and password @@ -413,6 +463,11 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, 
error) { return c.CreateDocMetaIndex(cmd) case "drop_doc_meta_index": return c.DropDocMetaIndex(cmd) + // ContextEngine commands + case "ce_ls": + return c.CEList(cmd) + case "ce_search": + return c.CESearch(cmd) // TODO: Implement other commands default: return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type) @@ -432,13 +487,15 @@ type ResponseIf interface { Type() string PrintOut() TimeCost() float64 + SetOutputFormat(format OutputFormat) } type CommonResponse struct { - Code int `json:"code"` - Data []map[string]interface{} `json:"data"` - Message string `json:"message"` - Duration float64 + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat } func (r *CommonResponse) Type() string { @@ -449,9 +506,13 @@ func (r *CommonResponse) TimeCost() float64 { return r.Duration } +func (r *CommonResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + func (r *CommonResponse) PrintOut() { if r.Code == 0 { - PrintTableSimple(r.Data) + PrintTableSimpleByFormat(r.Data, r.outputFormat) } else { fmt.Println("ERROR") fmt.Printf("%d, %s\n", r.Code, r.Message) @@ -459,10 +520,11 @@ func (r *CommonResponse) PrintOut() { } type CommonDataResponse struct { - Code int `json:"code"` - Data map[string]interface{} `json:"data"` - Message string `json:"message"` - Duration float64 + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat } func (r *CommonDataResponse) Type() string { @@ -473,11 +535,15 @@ func (r *CommonDataResponse) TimeCost() float64 { return r.Duration } +func (r *CommonDataResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + func (r *CommonDataResponse) PrintOut() { if r.Code == 0 { table := make([]map[string]interface{}, 0) table = append(table, r.Data) - PrintTableSimple(table) + 
PrintTableSimpleByFormat(table, r.outputFormat) } else { fmt.Println("ERROR") fmt.Printf("%d, %s\n", r.Code, r.Message) @@ -485,9 +551,10 @@ func (r *CommonDataResponse) PrintOut() { } type SimpleResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Duration float64 + Code int `json:"code"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat } func (r *SimpleResponse) Type() string { @@ -498,6 +565,10 @@ func (r *SimpleResponse) TimeCost() float64 { return r.Duration } +func (r *SimpleResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + func (r *SimpleResponse) PrintOut() { if r.Code == 0 { fmt.Println("SUCCESS") @@ -508,9 +579,10 @@ func (r *SimpleResponse) PrintOut() { } type RegisterResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Duration float64 + Code int `json:"code"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat } func (r *RegisterResponse) Type() string { @@ -521,6 +593,10 @@ func (r *RegisterResponse) TimeCost() float64 { return r.Duration } +func (r *RegisterResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + func (r *RegisterResponse) PrintOut() { if r.Code == 0 { fmt.Println("Register successfully") @@ -536,12 +612,17 @@ type BenchmarkResponse struct { SuccessCount int `json:"success_count"` FailureCount int `json:"failure_count"` Concurrency int + outputFormat OutputFormat } func (r *BenchmarkResponse) Type() string { return "benchmark" } +func (r *BenchmarkResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + func (r *BenchmarkResponse) PrintOut() { if r.Code != 0 { fmt.Printf("ERROR, Code: %d\n", r.Code) @@ -565,10 +646,11 @@ func (r *BenchmarkResponse) TimeCost() float64 { } type KeyValueResponse struct { - Code int `json:"code"` - Key string `json:"key"` - Value string `json:"data"` - Duration float64 + Code int `json:"code"` + Key string `json:"key"` 
+ Value string `json:"data"` + Duration float64 + outputFormat OutputFormat } func (r *KeyValueResponse) Type() string { @@ -579,6 +661,10 @@ func (r *KeyValueResponse) TimeCost() float64 { return r.Duration } +func (r *KeyValueResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + func (r *KeyValueResponse) PrintOut() { if r.Code == 0 { table := make([]map[string]interface{}, 0) @@ -587,9 +673,175 @@ func (r *KeyValueResponse) PrintOut() { "key": r.Key, "value": r.Value, }) - PrintTableSimple(table) + PrintTableSimpleByFormat(table, r.outputFormat) } else { fmt.Println("ERROR") fmt.Printf("%d\n", r.Code) } } + +// ==================== ContextEngine Commands ==================== + +// CEListResponse represents the response for ls command +type CEListResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *CEListResponse) Type() string { return "ce_ls" } +func (r *CEListResponse) TimeCost() float64 { return r.Duration } +func (r *CEListResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format } +func (r *CEListResponse) PrintOut() { + if r.Code == 0 { + PrintTableSimpleByFormat(r.Data, r.outputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +// CEList handles the ls command - lists nodes using Context Engine +func (c *RAGFlowClient) CEList(cmd *Command) (ResponseIf, error) { + // Get path from command params, default to "datasets" + path, _ := cmd.Params["path"].(string) + if path == "" { + path = "datasets" + } + + // Parse options + opts := &ce.ListOptions{} + if recursive, ok := cmd.Params["recursive"].(bool); ok { + opts.Recursive = recursive + } + if limit, ok := cmd.Params["limit"].(int); ok { + opts.Limit = limit + } + if offset, ok := cmd.Params["offset"].(int); ok { + opts.Offset = offset + } + + // Execute list command through Context Engine + 
ctx := context.Background() + result, err := c.ContextEngine.List(ctx, path, opts) + if err != nil { + return nil, err + } + + // Convert to response + var response CEListResponse + response.outputFormat = c.OutputFormat + response.Code = 0 + response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat)) + + return &response, nil +} + +// getStringValue safely converts interface{} to string +func getStringValue(v interface{}) string { + if v == nil { + return "" + } + if s, ok := v.(string); ok { + return s + } + return fmt.Sprintf("%v", v) +} + +// formatTimeValue converts a timestamp (milliseconds or string) to readable format +func formatTimeValue(v interface{}) string { + if v == nil { + return "" + } + + var ts int64 + switch val := v.(type) { + case float64: + ts = int64(val) + case int64: + ts = val + case int: + ts = int64(val) + case string: + // Try to parse as number + if _, err := fmt.Sscanf(val, "%d", &ts); err != nil { + // If it's already a formatted date string, return as is + return val + } + default: + return fmt.Sprintf("%v", v) + } + + // Convert milliseconds to seconds if timestamp is in milliseconds (13 digits) + if ts > 1e12 { + ts = ts / 1000 + } + + t := time.Unix(ts, 0) + return t.Format("2006-01-02 15:04:05") +} + +// CESearchResponse represents the response for search command +type CESearchResponse struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Total int `json:"total"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *CESearchResponse) Type() string { return "ce_search" } +func (r *CESearchResponse) TimeCost() float64 { return r.Duration } +func (r *CESearchResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format } +func (r *CESearchResponse) PrintOut() { + if r.Code == 0 { + fmt.Printf("Found %d results:\n", r.Total) + PrintTableSimpleByFormat(r.Data, r.outputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, 
r.Message) + } +} + +// CESearch handles the search command using Context Engine +func (c *RAGFlowClient) CESearch(cmd *Command) (ResponseIf, error) { + // Get path and query from command params + path, _ := cmd.Params["path"].(string) + if path == "" { + path = "datasets" + } + query, _ := cmd.Params["query"].(string) + + // Parse options + opts := &ce.SearchOptions{ + Query: query, + } + if limit, ok := cmd.Params["limit"].(int); ok { + opts.Limit = limit + } + if offset, ok := cmd.Params["offset"].(int); ok { + opts.Offset = offset + } + if recursive, ok := cmd.Params["recursive"].(bool); ok { + opts.Recursive = recursive + } + + // Execute search command through Context Engine + ctx := context.Background() + result, err := c.ContextEngine.Search(ctx, path, opts) + if err != nil { + return nil, err + } + + // Convert to response + var response CESearchResponse + response.outputFormat = c.OutputFormat + response.Code = 0 + response.Total = result.Total + response.Data = ce.FormatNodes(result.Nodes, string(c.OutputFormat)) + + return &response, nil +} diff --git a/internal/cli/contextengine/README.md b/internal/cli/contextengine/README.md new file mode 100644 index 000000000..26548823a --- /dev/null +++ b/internal/cli/contextengine/README.md @@ -0,0 +1,49 @@ +# ContextFS - Context Engine File System + +ContextFS is a context engine interface for RAGFlow, providing users with a Unix-like file system interface to manage datasets, tools, skills, and memories. + +## Directory Structure + +``` +user_id/ +├── datasets/ +│ └── my_dataset/ +│ └── ... +├── tools/ +│ ├── registry.json +│ └── tool_name/ +│ ├── DOC.md +│ └── ... +├── skills/ +│ ├── registry.json +│ └── skill_name/ +│ ├── SKILL.md +│ └── ... 
+└── memories/ + └── memory_id/ + ├── sessions/ + │ ├── messages/ + │ ├── summaries/ + │ │ └── session_id/ + │ │ └── summary-{datetime}.md + │ └── tools/ + │ └── session_id/ + │ └── {tool_name}.md # User level of memory on Tools usage + ├── users/ + │ ├── profile.md + │ ├── preferences/ + │ └── entities/ + └── agents/ + └── agent_space/ + ├── tools/ + │ └── {tool_name}.md # Agent level of memory on Tools usage + └── skills/ + └── {skill_name}.md # Agent level of memory on Skills usage +``` + + +## Supported Commands + +- `ls [path]` - List directory contents +- `cat ` - Display file contents(only for text files) +- `search ` - Search content diff --git a/internal/cli/contextengine/dataset_provider.go b/internal/cli/contextengine/dataset_provider.go new file mode 100644 index 000000000..daf3e41e4 --- /dev/null +++ b/internal/cli/contextengine/dataset_provider.go @@ -0,0 +1,781 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package contextengine + +import ( + stdctx "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" +) + +// HTTPResponse represents an HTTP response +type HTTPResponse struct { + StatusCode int + Body []byte + Headers map[string][]string + Duration float64 +} + +// HTTPClientInterface defines the interface needed from HTTPClient +type HTTPClientInterface interface { + Request(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}) (*HTTPResponse, error) +} + +// DatasetProvider handles datasets and their documents +// Path structure: +// - datasets/ -> List all datasets +// - datasets/{name} -> List documents in dataset +// - datasets/{name}/{doc_name} -> Get document info +type DatasetProvider struct { + BaseProvider + httpClient HTTPClientInterface +} + +// NewDatasetProvider creates a new DatasetProvider +func NewDatasetProvider(httpClient HTTPClientInterface) *DatasetProvider { + return &DatasetProvider{ + BaseProvider: BaseProvider{ + name: "datasets", + description: "Dataset management provider", + rootPath: "datasets", + }, + httpClient: httpClient, + } +} + +// Supports returns true if this provider can handle the given path +func (p *DatasetProvider) Supports(path string) bool { + normalized := normalizePath(path) + return normalized == "datasets" || strings.HasPrefix(normalized, "datasets/") +} + +// List lists nodes at the given path +func (p *DatasetProvider) List(ctx stdctx.Context, subPath string, opts *ListOptions) (*Result, error) { + // subPath is the path relative to "datasets/" + // Empty subPath means list all datasets + // "{name}/files" means list documents in a dataset + + // Check if trying to access hidden .knowledgebase + if subPath == ".knowledgebase" || strings.HasPrefix(subPath, ".knowledgebase/") { + return nil, fmt.Errorf("invalid path: .knowledgebase is not accessible") + } + + if subPath == "" { + return p.listDatasets(ctx, opts) + } + + parts := 
SplitPath(subPath) + if len(parts) == 1 { + // datasets/{name} - list documents in the dataset (default behavior) + return p.listDocuments(ctx, parts[0], opts) + } + + if len(parts) == 2 { + // datasets/{name}/{doc_name} - get document info + return p.getDocumentNode(ctx, parts[0], parts[1]) + } + + return nil, fmt.Errorf("invalid path: %s", subPath) +} + +// Search searches for datasets or documents +func (p *DatasetProvider) Search(ctx stdctx.Context, subPath string, opts *SearchOptions) (*Result, error) { + if opts.Query == "" { + return p.List(ctx, subPath, &ListOptions{ + Limit: opts.Limit, + Offset: opts.Offset, + }) + } + + // If searching under a specific dataset's files + parts := SplitPath(subPath) + if len(parts) >= 2 && parts[1] == "files" { + datasetName := parts[0] + return p.searchDocuments(ctx, datasetName, opts) + } + + // Otherwise search datasets + return p.searchDatasets(ctx, opts) +} + +// Cat retrieves document content +// For datasets: +// - cat datasets -> Error: datasets is a directory, not a file +// - cat datasets/kb_name -> Error: kb_name is a directory, not a file +// - cat datasets/kb_name/doc_name -> Would retrieve document content (if implemented) +func (p *DatasetProvider) Cat(ctx stdctx.Context, subPath string) ([]byte, error) { + if subPath == "" { + return nil, fmt.Errorf("'datasets' is a directory, not a file") + } + + parts := SplitPath(subPath) + if len(parts) == 1 { + // datasets/{name} - this is a dataset (directory) + return nil, fmt.Errorf("'%s' is a directory, not a file", parts[0]) + } + + if len(parts) == 2 { + // datasets/{name}/{doc_name} - this could be a document + // For now, document content retrieval is not implemented + return nil, fmt.Errorf("document content retrieval not yet implemented for '%s'", parts[1]) + } + + return nil, fmt.Errorf("invalid path for cat: %s", subPath) +} + +// ==================== Dataset Operations ==================== + +func (p *DatasetProvider) listDatasets(ctx stdctx.Context, opts 
*ListOptions) (*Result, error) { + resp, err := p.httpClient.Request("GET", "/datasets", true, "auto", nil, nil) + if err != nil { + return nil, err + } + + var apiResp struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return nil, err + } + + if apiResp.Code != 0 { + return nil, fmt.Errorf("API error: %s", apiResp.Message) + } + + nodes := make([]*Node, 0, len(apiResp.Data)) + for _, ds := range apiResp.Data { + node := p.datasetToNode(ds) + // Skip hidden .knowledgebase dataset (trim whitespace for safety) + if strings.TrimSpace(node.Name) == ".knowledgebase" { + continue + } + nodes = append(nodes, node) + } + + total := len(nodes) + + // Apply limit if specified + if opts != nil && opts.Limit > 0 && opts.Limit < len(nodes) { + nodes = nodes[:opts.Limit] + } + + return &Result{ + Nodes: nodes, + Total: total, + }, nil +} + +func (p *DatasetProvider) getDataset(ctx stdctx.Context, name string) (*Node, error) { + // Check if trying to access hidden .knowledgebase + if name == ".knowledgebase" { + return nil, fmt.Errorf("invalid path: .knowledgebase is not accessible") + } + + // First list all datasets to find the one with matching name + resp, err := p.httpClient.Request("GET", "/datasets", true, "auto", nil, nil) + if err != nil { + return nil, err + } + + var apiResp struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return nil, err + } + + if apiResp.Code != 0 { + return nil, fmt.Errorf("API error: %s", apiResp.Message) + } + + for _, ds := range apiResp.Data { + if getString(ds["name"]) == name { + return p.datasetToNode(ds), nil + } + } + + return nil, fmt.Errorf("%s: dataset '%s'", ErrNotFound, name) +} + +func (p *DatasetProvider) searchDatasets(ctx stdctx.Context, opts *SearchOptions) (*Result, 
error) { + // If no query is provided, just list datasets + if opts.Query == "" { + return p.listDatasets(ctx, &ListOptions{ + Limit: opts.Limit, + Offset: opts.Offset, + }) + } + + // Use retrieval API for semantic search + return p.searchWithRetrieval(ctx, opts) +} + +// searchWithRetrieval performs semantic search using the retrieval API +func (p *DatasetProvider) searchWithRetrieval(ctx stdctx.Context, opts *SearchOptions) (*Result, error) { + // Determine kb_ids to search in + var kbIDs []string + var datasetsToSearch []*Node + + if len(opts.Dirs) > 0 && opts.Dirs[0] != "datasets" { + // Search in specific datasets + for _, dir := range opts.Dirs { + // Extract dataset name from path (e.g., "datasets/kb1" -> "kb1") + datasetName := dir + if strings.HasPrefix(dir, "datasets/") { + datasetName = dir[len("datasets/"):] + } + ds, err := p.getDataset(ctx, datasetName) + if err != nil { + // Try case-insensitive match + allResult, listErr := p.listDatasets(ctx, nil) + if listErr == nil { + for _, d := range allResult.Nodes { + if strings.EqualFold(d.Name, datasetName) { + ds = d + err = nil + break + } + } + } + if err != nil { + return nil, fmt.Errorf("dataset not found: %s", datasetName) + } + } + datasetsToSearch = append(datasetsToSearch, ds) + kbID := getString(ds.Metadata["id"]) + if kbID != "" { + kbIDs = append(kbIDs, kbID) + } + } + } else { + // Search in all datasets + allResult, err := p.listDatasets(ctx, nil) + if err != nil { + return nil, err + } + datasetsToSearch = allResult.Nodes + for _, ds := range datasetsToSearch { + kbID := getString(ds.Metadata["id"]) + if kbID != "" { + kbIDs = append(kbIDs, kbID) + } + } + } + + if len(kbIDs) == 0 { + return &Result{ + Nodes: []*Node{}, + Total: 0, + }, nil + } + + // Build kb_id -> dataset name mapping + kbIDToName := make(map[string]string) + for _, ds := range datasetsToSearch { + kbID := getString(ds.Metadata["id"]) + if kbID != "" && ds.Name != "" { + kbIDToName[kbID] = ds.Name + } + } + + // Build 
retrieval request + payload := map[string]interface{}{ + "kb_id": kbIDs, + "question": opts.Query, + } + + // Set top_k (default to 10 if not specified) + topK := opts.TopK + if topK <= 0 { + topK = 10 + } + payload["top_k"] = topK + + // Set similarity threshold (default to 0.2 if not specified to match UI behavior) + threshold := opts.Threshold + if threshold <= 0 { + threshold = 0.2 + } + payload["similarity_threshold"] = threshold + + // Call retrieval API (useAPIBase=false because the route is /v1/chunk/retrieval_test, not /api/v1/...) + resp, err := p.httpClient.Request("POST", "/chunk/retrieval_test", false, "auto", nil, payload) + if err != nil { + return nil, fmt.Errorf("retrieval request failed: %w", err) + } + + var apiResp struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return nil, err + } + + if apiResp.Code != 0 { + return nil, fmt.Errorf("API error: %s", apiResp.Message) + } + + // Parse chunks from response + var nodes []*Node + if chunksData, ok := apiResp.Data["chunks"].([]interface{}); ok { + for _, chunk := range chunksData { + if chunkMap, ok := chunk.(map[string]interface{}); ok { + node := p.chunkToNodeWithKBMapping(chunkMap, kbIDToName) + nodes = append(nodes, node) + } + } + } + + // Apply top_k limit if specified (API may return more results) + if topK > 0 && len(nodes) > topK { + nodes = nodes[:topK] + } + + return &Result{ + Nodes: nodes, + Total: len(nodes), + }, nil +} + +// chunkToNodeWithKBMapping converts a chunk map to a Node with kb_id -> name mapping +func (p *DatasetProvider) chunkToNodeWithKBMapping(chunk map[string]interface{}, kbIDToName map[string]string) *Node { + // Extract chunk content - try multiple field names + content := "" + if v, ok := chunk["content_with_weight"].(string); ok && v != "" { + content = v + } else if v, ok := chunk["content"].(string); ok && v != "" { + content = v + } 
else if v, ok := chunk["content_ltks"].(string); ok && v != "" { + content = v + } else if v, ok := chunk["text"].(string); ok && v != "" { + content = v + } + + // Get chunk_id for URI + chunkID := "" + if v, ok := chunk["chunk_id"].(string); ok { + chunkID = v + } else if v, ok := chunk["id"].(string); ok { + chunkID = v + } + + // Get document name and ID + docName := "" + if v, ok := chunk["docnm_kwd"].(string); ok && v != "" { + docName = v + } else if v, ok := chunk["docnm"].(string); ok && v != "" { + docName = v + } else if v, ok := chunk["doc_name"].(string); ok && v != "" { + docName = v + } + + docID := "" + if v, ok := chunk["doc_id"].(string); ok && v != "" { + docID = v + } + + // Get dataset/kb name from mapping or chunk data + datasetName := "" + datasetID := "" + + // First try to get kb_id from chunk (could be string or array) + if v, ok := chunk["kb_id"].(string); ok && v != "" { + datasetID = v + } else if v, ok := chunk["kb_id"].([]interface{}); ok && len(v) > 0 { + if s, ok := v[0].(string); ok { + datasetID = s + } + } + + // Look up dataset name from mapping using kb_id + if datasetID != "" && kbIDToName != nil { + if name, ok := kbIDToName[datasetID]; ok && name != "" { + datasetName = name + } + } + + // Fallback to kb_name from chunk if mapping doesn't have it + if datasetName == "" { + if v, ok := chunk["kb_name"].(string); ok && v != "" { + datasetName = v + } + } + + // Build URI path: prefer names over IDs for readability + // Format: datasets/{dataset_name}/{doc_name} + path := "/datasets" + if datasetName != "" { + path += "/" + datasetName + } else if datasetID != "" { + path += "/" + datasetID + } + if docName != "" { + path += "/" + docName + } else if docID != "" { + path += "/" + docID + } + + // Use doc_name or chunk_id as the name if content is empty + name := content + if name == "" { + if docName != "" { + name = docName + } else if chunkID != "" { + name = "chunk:" + chunkID[:min(len(chunkID), 16)] + } else { + name = 
"(empty)" + } + } + + node := &Node{ + Name: name, + Path: path, + Type: NodeTypeDocument, + Metadata: chunk, + } + + // Parse timestamps if available + if createTime, ok := chunk["create_time"]; ok { + node.CreatedAt = parseTime(createTime) + } + if updateTime, ok := chunk["update_time"]; ok { + node.UpdatedAt = parseTime(updateTime) + } + + return node +} + +// chunkToNode converts a chunk map to a Node (legacy, uses chunk data only) +func (p *DatasetProvider) chunkToNode(chunk map[string]interface{}) *Node { + return p.chunkToNodeWithKBMapping(chunk, nil) +} + +// ==================== Document Operations ==================== + +func (p *DatasetProvider) listDocuments(ctx stdctx.Context, datasetName string, opts *ListOptions) (*Result, error) { + // First get the dataset ID + ds, err := p.getDataset(ctx, datasetName) + if err != nil { + return nil, err + } + + datasetID := getString(ds.Metadata["id"]) + if datasetID == "" { + return nil, fmt.Errorf("dataset ID not found") + } + + // Build query parameters + params := make(map[string]string) + if opts != nil { + if opts.Limit > 0 { + params["page_size"] = fmt.Sprintf("%d", opts.Limit) + } + if opts.Offset > 0 { + params["page"] = fmt.Sprintf("%d", opts.Offset/opts.Limit+1) + } + } + + path := fmt.Sprintf("/datasets/%s/documents", datasetID) + resp, err := p.httpClient.Request("GET", path, true, "auto", params, nil) + if err != nil { + return nil, err + } + + var apiResp struct { + Code int `json:"code"` + Data struct { + Docs []map[string]interface{} `json:"docs"` + } `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return nil, err + } + + if apiResp.Code != 0 { + return nil, fmt.Errorf("API error: %s", apiResp.Message) + } + + nodes := make([]*Node, 0, len(apiResp.Data.Docs)) + for _, doc := range apiResp.Data.Docs { + node := p.documentToNode(doc, datasetName) + nodes = append(nodes, node) + } + + return &Result{ + Nodes: nodes, + Total: 
len(nodes), + }, nil +} + +func (p *DatasetProvider) getDocumentNode(ctx stdctx.Context, datasetName, docName string) (*Result, error) { + node, err := p.getDocument(ctx, datasetName, docName) + if err != nil { + return nil, err + } + return &Result{ + Nodes: []*Node{node}, + Total: 1, + }, nil +} + +func (p *DatasetProvider) getDocument(ctx stdctx.Context, datasetName, docName string) (*Node, error) { + // List all documents and find the matching one + result, err := p.listDocuments(ctx, datasetName, nil) + if err != nil { + return nil, err + } + + for _, node := range result.Nodes { + if node.Name == docName { + return node, nil + } + } + + return nil, fmt.Errorf("%s: document '%s' in dataset '%s'", ErrNotFound, docName, datasetName) +} + +func (p *DatasetProvider) searchDocuments(ctx stdctx.Context, datasetName string, opts *SearchOptions) (*Result, error) { + // If no query is provided, just list documents + if opts.Query == "" { + return p.listDocuments(ctx, datasetName, &ListOptions{ + Limit: opts.Limit, + Offset: opts.Offset, + }) + } + + // Use retrieval API for semantic search in specific dataset + ds, err := p.getDataset(ctx, datasetName) + if err != nil { + return nil, err + } + + kbID := getString(ds.Metadata["id"]) + if kbID == "" { + return nil, fmt.Errorf("dataset ID not found for '%s'", datasetName) + } + + // Build kb_id -> dataset name mapping + kbIDToName := map[string]string{kbID: datasetName} + + // Build retrieval request for specific dataset + payload := map[string]interface{}{ + "kb_id": []string{kbID}, + "question": opts.Query, + } + + // Set top_k (default to 10 if not specified) + topK := opts.TopK + if topK <= 0 { + topK = 10 + } + payload["top_k"] = topK + + // Set similarity threshold (default to 0.2 if not specified to match UI behavior) + threshold := opts.Threshold + if threshold <= 0 { + threshold = 0.2 + } + payload["similarity_threshold"] = threshold + + // Call retrieval API (useAPIBase=false because the route is 
/v1/chunk/retrieval_test, not /api/v1/...) + resp, err := p.httpClient.Request("POST", "/chunk/retrieval_test", false, "auto", nil, payload) + if err != nil { + return nil, fmt.Errorf("retrieval request failed: %w", err) + } + + var apiResp struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return nil, err + } + + if apiResp.Code != 0 { + return nil, fmt.Errorf("API error: %s", apiResp.Message) + } + + // Parse chunks from response + var nodes []*Node + if chunksData, ok := apiResp.Data["chunks"].([]interface{}); ok { + for _, chunk := range chunksData { + if chunkMap, ok := chunk.(map[string]interface{}); ok { + node := p.chunkToNodeWithKBMapping(chunkMap, kbIDToName) + nodes = append(nodes, node) + } + } + } + + // Apply top_k limit if specified (API may return more results) + if topK > 0 && len(nodes) > topK { + nodes = nodes[:topK] + } + + return &Result{ + Nodes: nodes, + Total: len(nodes), + }, nil +} + +// ==================== Helper Functions ==================== + +func (p *DatasetProvider) datasetToNode(ds map[string]interface{}) *Node { + name := getString(ds["name"]) + node := &Node{ + Name: name, + Path: "/datasets/" + name, + Type: NodeTypeDirectory, + Metadata: ds, + } + + // Parse timestamps - try multiple field names + if createTime, ok := ds["create_time"]; ok && createTime != nil { + node.CreatedAt = parseTime(createTime) + } else if createDate, ok := ds["create_date"]; ok && createDate != nil { + node.CreatedAt = parseTime(createDate) + } + + if updateTime, ok := ds["update_time"]; ok && updateTime != nil { + node.UpdatedAt = parseTime(updateTime) + } else if updateDate, ok := ds["update_date"]; ok && updateDate != nil { + node.UpdatedAt = parseTime(updateDate) + } + + return node +} + +func (p *DatasetProvider) documentToNode(doc map[string]interface{}, datasetName string) *Node { + name := getString(doc["name"]) + 
node := &Node{ + Name: name, + Path: "datasets/" + datasetName + "/" + name, + Type: NodeTypeDocument, + Metadata: doc, + } + + // Parse size + if size, ok := doc["size"]; ok { + node.Size = int64(getFloat(size)) + } + + // Parse timestamps + if createTime, ok := doc["create_time"]; ok { + node.CreatedAt = parseTime(createTime) + } + if updateTime, ok := doc["update_time"]; ok { + node.UpdatedAt = parseTime(updateTime) + } + + return node +} + +func getString(v interface{}) string { + if v == nil { + return "" + } + if s, ok := v.(string); ok { + return s + } + return fmt.Sprintf("%v", v) +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func getFloat(v interface{}) float64 { + if v == nil { + return 0 + } + switch val := v.(type) { + case float64: + return val + case float32: + return float64(val) + case int: + return float64(val) + case int64: + return float64(val) + default: + return 0 + } +} + +func parseTime(v interface{}) time.Time { + if v == nil { + return time.Time{} + } + + var ts int64 + switch val := v.(type) { + case float64: + ts = int64(val) + case int64: + ts = val + case int: + ts = int64(val) + case string: + // Trim quotes if present + val = strings.Trim(val, `"`) + // Try to parse as number (timestamp) + if parsed, err := strconv.ParseInt(val, 10, 64); err == nil { + ts = parsed + } else { + // If it's already a formatted date string, try parsing it + formats := []string{ + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + "2006-01-02T15:04:05Z", + "2006-01-02", + } + for _, format := range formats { + if t, err := time.Parse(format, val); err == nil { + return t + } + } + return time.Time{} + } + default: + return time.Time{} + } + + // Convert milliseconds to seconds if timestamp is in milliseconds (13 digits) + if ts > 1e12 { + ts = ts / 1000 + } + + return time.Unix(ts, 0) +} diff --git a/internal/cli/contextengine/engine.go b/internal/cli/contextengine/engine.go new file mode 100644 index 000000000..9f34aa920 --- 
/dev/null +++ b/internal/cli/contextengine/engine.go @@ -0,0 +1,312 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package contextengine + +import ( + stdctx "context" + "fmt" + "strings" + "time" +) + +// Engine is the core of the Context Engine +// It manages providers and routes commands to the appropriate provider +type Engine struct { + providers []Provider +} + +// NewEngine creates a new Context Engine +func NewEngine() *Engine { + return &Engine{ + providers: make([]Provider, 0), + } +} + +// RegisterProvider registers a provider with the engine +func (e *Engine) RegisterProvider(provider Provider) { + e.providers = append(e.providers, provider) +} + +// GetProviders returns all registered providers +func (e *Engine) GetProviders() []ProviderInfo { + infos := make([]ProviderInfo, 0, len(e.providers)) + for _, p := range e.providers { + infos = append(infos, ProviderInfo{ + Name: p.Name(), + Description: p.Description(), + }) + } + return infos +} + +// Execute executes a command and returns the result +func (e *Engine) Execute(ctx stdctx.Context, cmd *Command) (*Result, error) { + switch cmd.Type { + case CommandList: + return e.List(ctx, cmd.Path, parseListOptions(cmd.Params)) + case CommandSearch: + return e.Search(ctx, cmd.Path, parseSearchOptions(cmd.Params)) + case CommandCat: + _, err := e.Cat(ctx, cmd.Path) + return nil, err + default: + return nil, 
fmt.Errorf("unknown command type: %s", cmd.Type) + } +} + +// resolveProvider finds the provider for a given path +func (e *Engine) resolveProvider(path string) (Provider, string, error) { + path = normalizePath(path) + + for _, provider := range e.providers { + if provider.Supports(path) { + // Parse the subpath relative to the provider root + // Get provider name to calculate subPath + providerName := provider.Name() + var subPath string + if path == providerName { + subPath = "" + } else if strings.HasPrefix(path, providerName+"/") { + subPath = path[len(providerName)+1:] + } else { + subPath = path + } + return provider, subPath, nil + } + } + + // If no provider supports this path, check if FileProvider can handle it as a fallback + // This allows paths like "myskills" to be treated as "files/myskills" + if fileProvider := e.getFileProvider(); fileProvider != nil { + // Check if the path looks like a file manager path (single component, not matching other providers) + parts := SplitPath(path) + if len(parts) > 0 && parts[0] != "datasets" { + return fileProvider, path, nil + } + } + + return nil, "", fmt.Errorf("%s: %s", ErrProviderNotFound, path) +} + +// List lists nodes at the given path +// If path is empty, returns: +// 1. Built-in providers (e.g., datasets) +// 2. 
Top-level directories from files provider (if any) +func (e *Engine) List(ctx stdctx.Context, path string, opts *ListOptions) (*Result, error) { + // Normalize path + path = normalizePath(path) + + // If path is empty, return list of providers and files root directories + if path == "" || path == "/" { + return e.listRoot(ctx, opts) + } + + provider, subPath, err := e.resolveProvider(path) + if err != nil { + // If not found, try to find in files provider as a fallback + // This allows "ls myfolder" to work as "ls files/myfolder" + if fileProvider := e.getFileProvider(); fileProvider != nil { + result, ferr := fileProvider.List(ctx, path, opts) + if ferr == nil { + return result, nil + } + } + return nil, err + } + + return provider.List(ctx, subPath, opts) +} + +// listRoot returns the root listing: +// 1. Built-in providers (datasets, etc.) +// 2. Top-level folders from files provider (file_manager) +func (e *Engine) listRoot(ctx stdctx.Context, opts *ListOptions) (*Result, error) { + nodes := make([]*Node, 0) + + // Add built-in providers first (like datasets) + for _, p := range e.providers { + // Skip files provider from this list - we'll add its children instead + if p.Name() == "files" { + continue + } + nodes = append(nodes, &Node{ + Name: p.Name(), + Path: "/" + p.Name(), + Type: NodeTypeDirectory, + CreatedAt: time.Now(), + Metadata: map[string]interface{}{ + "description": p.Description(), + }, + }) + } + + // Add top-level folders from files provider (file_manager) + if fileProvider := e.getFileProvider(); fileProvider != nil { + filesResult, err := fileProvider.List(ctx, "", opts) + if err == nil { + for _, node := range filesResult.Nodes { + // Only add folders (directories), not files + if node.Type == NodeTypeDirectory { + // Ensure path doesn't have /files/ prefix for display + node.Path = strings.TrimPrefix(node.Path, "files/") + node.Path = strings.TrimPrefix(node.Path, "/") + nodes = append(nodes, node) + } + } + } + } + + return &Result{ + 
Nodes: nodes, + Total: len(nodes), + }, nil +} + +// getFileProvider returns the files provider if registered +func (e *Engine) getFileProvider() Provider { + for _, p := range e.providers { + if p.Name() == "files" { + return p + } + } + return nil +} + +// Search searches for nodes matching the query +func (e *Engine) Search(ctx stdctx.Context, path string, opts *SearchOptions) (*Result, error) { + provider, subPath, err := e.resolveProvider(path) + if err != nil { + return nil, err + } + + return provider.Search(ctx, subPath, opts) +} + +// Cat retrieves the content of a file/document +func (e *Engine) Cat(ctx stdctx.Context, path string) ([]byte, error) { + provider, subPath, err := e.resolveProvider(path) + if err != nil { + // If not found, try to find in files provider as a fallback + // This allows "cat myfolder/file.txt" to work as "cat files/myfolder/file.txt" + if fileProvider := e.getFileProvider(); fileProvider != nil { + return fileProvider.Cat(ctx, path) + } + return nil, err + } + + return provider.Cat(ctx, subPath) +} + +// ParsePath parses a path and returns path information +func (e *Engine) ParsePath(path string) (*PathInfo, error) { + path = normalizePath(path) + components := SplitPath(path) + + if len(components) == 0 { + return nil, fmt.Errorf("empty path") + } + + providerName := components[0] + isRoot := len(components) == 1 + + // Find the provider + var provider Provider + for _, p := range e.providers { + if p.Name() == providerName || strings.HasPrefix(path, p.Name()) { + provider = p + break + } + } + + if provider == nil { + return nil, fmt.Errorf("%s: %s", ErrProviderNotFound, path) + } + + info := &PathInfo{ + Provider: providerName, + Path: path, + Components: components, + IsRoot: isRoot, + } + + // Extract resource ID or name if available + if len(components) >= 2 { + info.ResourceName = components[1] + } + + return info, nil +} + +// parseListOptions parses command params into ListOptions +func parseListOptions(params 
map[string]interface{}) *ListOptions { + opts := &ListOptions{} + + if params == nil { + return opts + } + + if recursive, ok := params["recursive"].(bool); ok { + opts.Recursive = recursive + } + if limit, ok := params["limit"].(int); ok { + opts.Limit = limit + } + if offset, ok := params["offset"].(int); ok { + opts.Offset = offset + } + if sortBy, ok := params["sort_by"].(string); ok { + opts.SortBy = sortBy + } + if sortOrder, ok := params["sort_order"].(string); ok { + opts.SortOrder = sortOrder + } + + return opts +} + +// parseSearchOptions parses command params into SearchOptions +func parseSearchOptions(params map[string]interface{}) *SearchOptions { + opts := &SearchOptions{} + + if params == nil { + return opts + } + + if query, ok := params["query"].(string); ok { + opts.Query = query + } + if limit, ok := params["limit"].(int); ok { + opts.Limit = limit + } + if offset, ok := params["offset"].(int); ok { + opts.Offset = offset + } + if recursive, ok := params["recursive"].(bool); ok { + opts.Recursive = recursive + } + if topK, ok := params["top_k"].(int); ok { + opts.TopK = topK + } + if threshold, ok := params["threshold"].(float64); ok { + opts.Threshold = threshold + } + if dirs, ok := params["dirs"].([]string); ok { + opts.Dirs = dirs + } + + return opts +} diff --git a/internal/cli/contextengine/file_provider.go b/internal/cli/contextengine/file_provider.go new file mode 100644 index 000000000..b813cbac5 --- /dev/null +++ b/internal/cli/contextengine/file_provider.go @@ -0,0 +1,594 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package contextengine + +import ( + stdctx "context" + "encoding/json" + "fmt" + "strings" +) + +// FileProvider handles file operations using Python backend /files API +// Path structure: +// - files/ -> List root folder contents +// - files/{folder_name}/ -> List folder contents +// - files/{folder_name}/{file_name} -> Get file info/content +// +// Note: Uses Python backend API (useAPIBase=true): +// - GET /files?parent_id={id} -> List files/folders in parent +// - GET /files/{file_id} -> Get file info +// - POST /files -> Create folder or upload file +// - DELETE /files -> Delete files +// - GET /files/{file_id}/parent -> Get parent folder +// - GET /files/{file_id}/ancestors -> Get ancestor folders + +type FileProvider struct { + BaseProvider + httpClient HTTPClientInterface + folderCache map[string]string // path -> folder ID cache + rootID string // root folder ID +} + +// NewFileProvider creates a new FileProvider +func NewFileProvider(httpClient HTTPClientInterface) *FileProvider { + return &FileProvider{ + BaseProvider: BaseProvider{ + name: "files", + description: "File manager provider (Python server)", + rootPath: "files", + }, + httpClient: httpClient, + folderCache: make(map[string]string), + } +} + +// Supports returns true if this provider can handle the given path +func (p *FileProvider) Supports(path string) bool { + normalized := normalizePath(path) + return normalized == "files" || strings.HasPrefix(normalized, "files/") +} + +// List lists nodes at the given path +// Path structure: files/ or files/{folder_name}/ or 
files/{folder_name}/{sub_path}/... +func (p *FileProvider) List(ctx stdctx.Context, subPath string, opts *ListOptions) (*Result, error) { + // subPath is the path relative to "files/" + // Empty subPath means list root folder + + if subPath == "" { + return p.listRootFolder(ctx, opts) + } + + parts := SplitPath(subPath) + if len(parts) == 1 { + // files/{folder_name} - list contents of this folder + return p.listFolderByName(ctx, parts[0], opts) + } + + // For multi-level paths like myskills/skill-name/dir1, recursively traverse + return p.listPathRecursive(ctx, parts, opts) +} + +// listPathRecursive recursively traverses the path and lists the final component +func (p *FileProvider) listPathRecursive(ctx stdctx.Context, parts []string, opts *ListOptions) (*Result, error) { + if len(parts) == 0 { + return nil, fmt.Errorf("empty path") + } + + // Start from root to find the first folder + currentFolderID, err := p.getFolderIDByName(ctx, parts[0]) + if err != nil { + return nil, err + } + currentPath := parts[0] + + // Traverse through intermediate directories + for i := 1; i < len(parts); i++ { + partName := parts[i] + + // List contents of current folder to find the next part + result, err := p.listFilesByParentID(ctx, currentFolderID, currentPath, nil) + if err != nil { + return nil, err + } + + // Find the next component + found := false + for _, node := range result.Nodes { + if node.Name == partName { + if i == len(parts)-1 { + // This is the last component - if it's a directory, list its contents + if node.Type == NodeTypeDirectory { + childID := getString(node.Metadata["id"]) + if childID == "" { + return nil, fmt.Errorf("folder ID not found for '%s'", partName) + } + newPath := currentPath + "/" + partName + p.folderCache[newPath] = childID + return p.listFilesByParentID(ctx, childID, newPath, opts) + } + // It's a file - return the file node + return &Result{ + Nodes: []*Node{node}, + Total: 1, + }, nil + } + // Not the last component - must be a directory 
+ if node.Type != NodeTypeDirectory { + return nil, fmt.Errorf("'%s' is not a directory", partName) + } + childID := getString(node.Metadata["id"]) + if childID == "" { + return nil, fmt.Errorf("folder ID not found for '%s'", partName) + } + currentFolderID = childID + currentPath = currentPath + "/" + partName + p.folderCache[currentPath] = currentFolderID + found = true + break + } + } + + if !found { + return nil, fmt.Errorf("%s: '%s' in '%s'", ErrNotFound, partName, currentPath) + } + } + + // Should have returned in the loop, but just in case + return p.listFilesByParentID(ctx, currentFolderID, currentPath, opts) +} + +// Search searches for files/folders +func (p *FileProvider) Search(ctx stdctx.Context, subPath string, opts *SearchOptions) (*Result, error) { + if opts.Query == "" { + return p.List(ctx, subPath, &ListOptions{ + Limit: opts.Limit, + Offset: opts.Offset, + }) + } + + // For now, search is not implemented - just list and filter by name + result, err := p.List(ctx, subPath, &ListOptions{ + Limit: opts.Limit, + Offset: opts.Offset, + }) + if err != nil { + return nil, err + } + + // Simple name filtering + var filtered []*Node + query := strings.ToLower(opts.Query) + for _, node := range result.Nodes { + if strings.Contains(strings.ToLower(node.Name), query) { + filtered = append(filtered, node) + } + } + + return &Result{ + Nodes: filtered, + Total: len(filtered), + }, nil +} + +// Cat retrieves file content +func (p *FileProvider) Cat(ctx stdctx.Context, subPath string) ([]byte, error) { + if subPath == "" { + return nil, fmt.Errorf("cat requires a file path: files/{folder}/{file}") + } + + parts := SplitPath(subPath) + if len(parts) < 2 { + return nil, fmt.Errorf("invalid path format, expected: files/{folder}/{file}") + } + + // Find the file by recursively traversing the path + node, err := p.findNodeByPath(ctx, parts) + if err != nil { + return nil, err + } + + if node.Type == NodeTypeDirectory { + return nil, fmt.Errorf("'%s' is a directory, 
not a file", subPath) + } + + fileID := getString(node.Metadata["id"]) + if fileID == "" { + return nil, fmt.Errorf("file ID not found") + } + + // Download file content + return p.downloadFile(ctx, fileID) +} + +// findNodeByPath recursively traverses the path to find the target node +func (p *FileProvider) findNodeByPath(ctx stdctx.Context, parts []string) (*Node, error) { + if len(parts) == 0 { + return nil, fmt.Errorf("empty path") + } + + // Start from root to find the first folder + currentFolderID, err := p.getFolderIDByName(ctx, parts[0]) + if err != nil { + return nil, err + } + currentPath := parts[0] + + // Traverse through intermediate directories + for i := 1; i < len(parts); i++ { + partName := parts[i] + + // List contents of current folder to find the next part + result, err := p.listFilesByParentID(ctx, currentFolderID, currentPath, nil) + if err != nil { + return nil, err + } + + // Find the next component + found := false + for _, node := range result.Nodes { + if node.Name == partName { + if i == len(parts)-1 { + // This is the last component - return it + return node, nil + } + // Not the last component - must be a directory + if node.Type != NodeTypeDirectory { + return nil, fmt.Errorf("'%s' is not a directory", partName) + } + childID := getString(node.Metadata["id"]) + if childID == "" { + return nil, fmt.Errorf("folder ID not found for '%s'", partName) + } + currentFolderID = childID + currentPath = currentPath + "/" + partName + p.folderCache[currentPath] = currentFolderID + found = true + break + } + } + + if !found { + return nil, fmt.Errorf("%s: '%s' in '%s'", ErrNotFound, partName, currentPath) + } + } + + return nil, fmt.Errorf("%s: '%s'", ErrNotFound, strings.Join(parts, "/")) +} + +// ==================== Python Server API Methods ==================== + +// getRootID gets or caches the root folder ID +func (p *FileProvider) getRootID(ctx stdctx.Context) (string, error) { + if p.rootID != "" { + return p.rootID, nil + } + + // List 
files without parent_id to get root folder + resp, err := p.httpClient.Request("GET", "/files", true, "auto", nil, nil) + if err != nil { + return "", err + } + + var apiResp struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return "", err + } + + if apiResp.Code != 0 { + return "", fmt.Errorf("API error: %s", apiResp.Message) + } + + // Try to find root folder ID from response + if rootID, ok := apiResp.Data["root_id"].(string); ok && rootID != "" { + p.rootID = rootID + return rootID, nil + } + + // If no explicit root_id, use empty parent_id for root listing + return "", nil +} + +// listRootFolder lists the contents of root folder +func (p *FileProvider) listRootFolder(ctx stdctx.Context, opts *ListOptions) (*Result, error) { + // Get root folder ID first + rootID, err := p.getRootID(ctx) + if err != nil { + return nil, err + } + // List files using root folder ID as parent + return p.listFilesByParentID(ctx, rootID, "", opts) +} + +// listFilesByParentID lists files/folders by parent ID +func (p *FileProvider) listFilesByParentID(ctx stdctx.Context, parentID string, parentPath string, opts *ListOptions) (*Result, error) { + // Build query parameters + queryParams := make([]string, 0) + if parentID != "" { + queryParams = append(queryParams, fmt.Sprintf("parent_id=%s", parentID)) + } + // Always set page=1 and page_size to ensure we get results + pageSize := 100 + if opts != nil && opts.Limit > 0 { + pageSize = opts.Limit + } + queryParams = append(queryParams, fmt.Sprintf("page_size=%d", pageSize)) + queryParams = append(queryParams, "page=1") + + // Build URL with query string + path := "/files" + if len(queryParams) > 0 { + path = path + "?" 
+ strings.Join(queryParams, "&") + } + + resp, err := p.httpClient.Request("GET", path, true, "auto", nil, nil) + if err != nil { + return nil, err + } + + var apiResp struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return nil, err + } + + if apiResp.Code != 0 { + return nil, fmt.Errorf("API error: %s", apiResp.Message) + } + + // Extract files list from data - API returns {"total": N, "files": [...], "parent_folder": {...}} + var files []map[string]interface{} + if fileList, ok := apiResp.Data["files"].([]interface{}); ok { + for _, f := range fileList { + if fileMap, ok := f.(map[string]interface{}); ok { + files = append(files, fileMap) + } + } + } + + nodes := make([]*Node, 0, len(files)) + for _, f := range files { + name := getString(f["name"]) + // Skip hidden .knowledgebase folder + if strings.TrimSpace(name) == ".knowledgebase" { + continue + } + + node := p.fileToNode(f, parentPath) + nodes = append(nodes, node) + + // Cache folder ID + if node.Type == NodeTypeDirectory || getString(f["type"]) == "folder" { + if id := getString(f["id"]); id != "" { + cacheKey := node.Name + if parentPath != "" { + cacheKey = parentPath + "/" + node.Name + } + p.folderCache[cacheKey] = id + } + } + } + + return &Result{ + Nodes: nodes, + Total: len(nodes), + }, nil +} + +// listFolderByName lists contents of a folder by its name +func (p *FileProvider) listFolderByName(ctx stdctx.Context, folderName string, opts *ListOptions) (*Result, error) { + folderID, err := p.getFolderIDByName(ctx, folderName) + if err != nil { + return nil, err + } + + // List files in the folder using folder ID as parent_id + return p.listFilesByParentID(ctx, folderID, folderName, opts) +} + +// getFolderIDByName finds folder ID by its name in root +func (p *FileProvider) getFolderIDByName(ctx stdctx.Context, folderName string) (string, error) { + // Check cache 
first + if id, ok := p.folderCache[folderName]; ok { + return id, nil + } + + // List root folder to find the folder + rootID, _ := p.getRootID(ctx) + queryParams := make([]string, 0) + if rootID != "" { + queryParams = append(queryParams, fmt.Sprintf("parent_id=%s", rootID)) + } + queryParams = append(queryParams, "page_size=100", "page=1") + + path := "/files" + if len(queryParams) > 0 { + path = path + "?" + strings.Join(queryParams, "&") + } + + resp, err := p.httpClient.Request("GET", path, true, "auto", nil, nil) + if err != nil { + return "", err + } + + var apiResp struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + } + + if err := json.Unmarshal(resp.Body, &apiResp); err != nil { + return "", err + } + + if apiResp.Code != 0 { + return "", fmt.Errorf("API error: %s", apiResp.Message) + } + + // Search for folder by name + var files []map[string]interface{} + if fileList, ok := apiResp.Data["files"].([]interface{}); ok { + for _, f := range fileList { + if fileMap, ok := f.(map[string]interface{}); ok { + files = append(files, fileMap) + } + } + } else if fileList, ok := apiResp.Data["docs"].([]interface{}); ok { + for _, f := range fileList { + if fileMap, ok := f.(map[string]interface{}); ok { + files = append(files, fileMap) + } + } + } + + for _, f := range files { + name := getString(f["name"]) + fileType := getString(f["type"]) + id := getString(f["id"]) + // Match by name and ensure it's a folder + if name == folderName && fileType == "folder" && id != "" { + p.folderCache[folderName] = id + return id, nil + } + } + + return "", fmt.Errorf("%s: folder '%s'", ErrNotFound, folderName) +} + +// getFileNode gets a file node by folder and file name +// If fileName is a directory, returns the directory contents instead of the directory node +func (p *FileProvider) getFileNode(ctx stdctx.Context, folderName, fileName string) (*Result, error) { + folderID, err := p.getFolderIDByName(ctx, 
folderName) + if err != nil { + return nil, err + } + + // List files in folder to find the file + result, err := p.listFilesByParentID(ctx, folderID, folderName, nil) + if err != nil { + return nil, err + } + + // Find the specific file + for _, node := range result.Nodes { + if node.Name == fileName { + // If it's a directory, list its contents instead of returning the node itself + if node.Type == NodeTypeDirectory { + childFolderID := getString(node.Metadata["id"]) + if childFolderID == "" { + return nil, fmt.Errorf("folder ID not found for '%s'", fileName) + } + // Cache the folder ID + cacheKey := folderName + "/" + fileName + p.folderCache[cacheKey] = childFolderID + // Return directory contents + return p.listFilesByParentID(ctx, childFolderID, cacheKey, nil) + } + // Return file node + return &Result{ + Nodes: []*Node{node}, + Total: 1, + }, nil + } + } + + return nil, fmt.Errorf("%s: file '%s' in folder '%s'", ErrNotFound, fileName, folderName) +} + +// downloadFile downloads file content +func (p *FileProvider) downloadFile(ctx stdctx.Context, fileID string) ([]byte, error) { + path := fmt.Sprintf("/files/%s", fileID) + resp, err := p.httpClient.Request("GET", path, true, "auto", nil, nil) + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + // Try to parse error response + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(resp.Body, &apiResp); err == nil && apiResp.Code != 0 { + return nil, fmt.Errorf("%s", apiResp.Message) + } + return nil, fmt.Errorf("HTTP error %d", resp.StatusCode) + } + + // Return raw file content + return resp.Body, nil +} + +// ==================== Conversion Functions ==================== + +// fileToNode converts a file map to a Node +func (p *FileProvider) fileToNode(f map[string]interface{}, parentPath string) *Node { + name := getString(f["name"]) + fileType := getString(f["type"]) + fileID := getString(f["id"]) + + // Determine node type + 
nodeType := NodeTypeFile + if fileType == "folder" { + nodeType = NodeTypeDirectory + } + + // Build path + path := name + if parentPath != "" { + path = parentPath + "/" + name + } + + node := &Node{ + Name: name, + Path: path, + Type: nodeType, + Metadata: f, + } + + // Parse size + if size, ok := f["size"]; ok { + node.Size = int64(getFloat(size)) + } + + // Parse timestamps + if createTime, ok := f["create_time"]; ok && createTime != nil { + node.CreatedAt = parseTime(createTime) + } + if updateTime, ok := f["update_time"]; ok && updateTime != nil { + node.UpdatedAt = parseTime(updateTime) + } + + // Store ID for later use + if fileID != "" { + if node.Metadata == nil { + node.Metadata = make(map[string]interface{}) + } + node.Metadata["id"] = fileID + } + + return node +} diff --git a/internal/cli/contextengine/provider.go b/internal/cli/contextengine/provider.go new file mode 100644 index 000000000..605a39b89 --- /dev/null +++ b/internal/cli/contextengine/provider.go @@ -0,0 +1,180 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package contextengine + +import ( + stdctx "context" +) + +// Provider is the interface for all context providers +// Each provider handles a specific resource type (datasets, chats, agents, etc.) 
+type Provider interface { + // Name returns the provider name (e.g., "datasets", "chats") + Name() string + + // Description returns a human-readable description of the provider + Description() string + + // Supports returns true if this provider can handle the given path + Supports(path string) bool + + // List lists nodes at the given path + List(ctx stdctx.Context, path string, opts *ListOptions) (*Result, error) + + // Search searches for nodes matching the query under the given path + Search(ctx stdctx.Context, path string, opts *SearchOptions) (*Result, error) + + // Cat retrieves the content of a file/document at the given path + Cat(ctx stdctx.Context, path string) ([]byte, error) +} + +// BaseProvider provides common functionality for all providers +type BaseProvider struct { + name string + description string + rootPath string +} + +// Name returns the provider name +func (p *BaseProvider) Name() string { + return p.name +} + +// Description returns the provider description +func (p *BaseProvider) Description() string { + return p.description +} + +// GetRootPath returns the root path for this provider +func (p *BaseProvider) GetRootPath() string { + return p.rootPath +} + +// IsRootPath checks if the given path is the root path for this provider +func (p *BaseProvider) IsRootPath(path string) bool { + return normalizePath(path) == normalizePath(p.rootPath) +} + +// ParsePath parses a path and returns the subpath relative to the provider root +func (p *BaseProvider) ParsePath(path string) string { + normalized := normalizePath(path) + rootNormalized := normalizePath(p.rootPath) + + if normalized == rootNormalized { + return "" + } + + if len(normalized) > len(rootNormalized) && normalized[:len(rootNormalized)+1] == rootNormalized+"/" { + return normalized[len(rootNormalized)+1:] + } + + return normalized +} + +// SplitPath splits a path into components +func SplitPath(path string) []string { + path = normalizePath(path) + if path == "" { + return 
// normalizePath canonicalizes a path: it strips surrounding whitespace and
// slashes, drops empty and "." segments, and resolves ".." against the
// segments seen so far (never climbing above the root).
func normalizePath(path string) string {
	path = trimSpace(path)

	segments := splitString(path, '/')
	stack := make([]string, 0, len(segments))
	for _, seg := range segments {
		switch seg {
		case "", ".":
			// Empty segments (from doubled/leading/trailing slashes) and
			// current-directory markers contribute nothing.
		case "..":
			if n := len(stack); n > 0 {
				stack = stack[:n-1]
			}
		default:
			stack = append(stack, seg)
		}
	}

	return joinStrings(stack, "/")
}

// Hand-rolled string helpers keep this file free of any imports.

// trimSpace strips ASCII whitespace (space, tab, CR, LF) from both ends.
func trimSpace(s string) string {
	isSpace := func(c byte) bool {
		return c == ' ' || c == '\t' || c == '\n' || c == '\r'
	}
	lo, hi := 0, len(s)
	for lo < hi && isSpace(s[lo]) {
		lo++
	}
	for hi > lo && isSpace(s[hi-1]) {
		hi--
	}
	return s[lo:hi]
}

// splitString splits s on sep; like strings.Split, it always yields at least
// one element and keeps empty fields.
func splitString(s string, sep byte) []string {
	var parts []string
	prev := 0
	for i := 0; i <= len(s); i++ {
		if i == len(s) || s[i] == sep {
			parts = append(parts, s[prev:i])
			prev = i + 1
		}
	}
	return parts
}

// joinStrings concatenates strs with sep between consecutive elements.
func joinStrings(strs []string, sep string) string {
	switch len(strs) {
	case 0:
		return ""
	case 1:
		return strs[0]
	}
	out := strs[0]
	for _, s := range strs[1:] {
		out += sep + s
	}
	return out
}
// NodeType classifies a node in the virtual context filesystem.
type NodeType string

// Known node types.
const (
	NodeTypeDirectory NodeType = "directory"
	NodeTypeFile      NodeType = "file"
	NodeTypeDataset   NodeType = "dataset"
	NodeTypeDocument  NodeType = "document"
	NodeTypeChat      NodeType = "chat"
	NodeTypeAgent     NodeType = "agent"
	NodeTypeUnknown   NodeType = "unknown"
)

// Node is the unified output record shared by every provider: one entry in
// the virtual context filesystem.
type Node struct {
	Name      string                 `json:"name"`
	Path      string                 `json:"path"`
	Type      NodeType               `json:"type"`
	Size      int64                  `json:"size,omitempty"`
	CreatedAt time.Time              `json:"created_at,omitempty"`
	UpdatedAt time.Time              `json:"updated_at,omitempty"`
	Metadata  map[string]interface{} `json:"metadata,omitempty"`
}

// CommandType identifies which engine operation a Command requests.
type CommandType string

// Supported commands.
const (
	CommandList   CommandType = "ls"
	CommandSearch CommandType = "search"
	CommandCat    CommandType = "cat"
)

// Command is a parsed context-engine command: an operation, a target path,
// and free-form parameters.
type Command struct {
	Type   CommandType            `json:"type"`
	Path   string                 `json:"path"`
	Params map[string]interface{} `json:"params,omitempty"`
}

// ListOptions controls list operations.
type ListOptions struct {
	Recursive bool   `json:"recursive,omitempty"`
	Limit     int    `json:"limit,omitempty"`
	Offset    int    `json:"offset,omitempty"`
	SortBy    string `json:"sort_by,omitempty"`
	SortOrder string `json:"sort_order,omitempty"` // "asc" or "desc"
}

// SearchOptions controls search operations.
type SearchOptions struct {
	Query     string   `json:"query"`
	Limit     int      `json:"limit,omitempty"`
	Offset    int      `json:"offset,omitempty"`
	Recursive bool     `json:"recursive,omitempty"`
	TopK      int      `json:"top_k,omitempty"`     // number of top results to return (default: 10)
	Threshold float64  `json:"threshold,omitempty"` // similarity threshold (default: 0.2)
	Dirs      []string `json:"dirs,omitempty"`      // directories to search in
}

// Result is what a command execution produces: the matched nodes plus
// pagination bookkeeping.
type Result struct {
	Nodes      []*Node `json:"nodes"`
	Total      int     `json:"total"`
	HasMore    bool    `json:"has_more"`
	NextOffset int     `json:"next_offset,omitempty"`
	Error      error   `json:"-"`
}

// PathInfo is the parsed form of a context path.
type PathInfo struct {
	Provider     string   // owning provider name (e.g. "datasets", "chats")
	Path         string   // the full normalized path
	Components   []string // path components
	IsRoot       bool     // true when the path is the provider's root
	ResourceID   string   // resource ID, when applicable
	ResourceName string   // resource name, when applicable
}

// ProviderInfo is the provider metadata exposed by the engine.
type ProviderInfo struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	RootPath    string `json:"root_path"`
}

// Common error-message prefixes used across providers.
const (
	ErrInvalidPath      = "invalid path"
	ErrProviderNotFound = "provider not found for path"
	ErrNotSupported     = "operation not supported"
	ErrNotFound         = "resource not found"
	ErrUnauthorized     = "unauthorized"
	ErrInternal         = "internal error"
)
000000000..ca9b7ca98 --- /dev/null +++ b/internal/cli/contextengine/utils.go @@ -0,0 +1,304 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package contextengine + +import ( + "encoding/json" + "fmt" + "time" +) + +// FormatNode formats a node for display +func FormatNode(node *Node, format string) map[string]interface{} { + switch format { + case "json": + return map[string]interface{}{ + "name": node.Name, + "path": node.Path, + "type": string(node.Type), + "size": node.Size, + "created_at": node.CreatedAt.Format(time.RFC3339), + "updated_at": node.UpdatedAt.Format(time.RFC3339), + } + case "table": + return map[string]interface{}{ + "name": node.Name, + "path": node.Path, + "type": string(node.Type), + "size": formatSize(node.Size), + "created_at": formatTime(node.CreatedAt), + "updated_at": formatTime(node.UpdatedAt), + } + default: // "plain" + return map[string]interface{}{ + "name": node.Name, + "path": node.Path, + "type": string(node.Type), + "created_at": formatTime(node.CreatedAt), + "updated_at": formatTime(node.UpdatedAt), + } + } +} + +// FormatNodes formats a list of nodes for display +func FormatNodes(nodes []*Node, format string) []map[string]interface{} { + result := make([]map[string]interface{}, 0, len(nodes)) + for _, node := range nodes { + result = append(result, FormatNode(node, format)) + } + return result +} + +// formatSize formats a size in bytes to 
human-readable format +func formatSize(size int64) string { + if size == 0 { + return "-" + } + + const ( + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + TB = 1024 * GB + ) + + switch { + case size >= TB: + return fmt.Sprintf("%.2f TB", float64(size)/TB) + case size >= GB: + return fmt.Sprintf("%.2f GB", float64(size)/GB) + case size >= MB: + return fmt.Sprintf("%.2f MB", float64(size)/MB) + case size >= KB: + return fmt.Sprintf("%.2f KB", float64(size)/KB) + default: + return fmt.Sprintf("%d B", size) + } +} + +// formatTime formats a time to a readable string +func formatTime(t time.Time) string { + if t.IsZero() { + return "-" + } + return t.Format("2006-01-02 15:04:05") +} + +// ResultToMap converts a Result to a map for JSON serialization +func ResultToMap(result *Result) map[string]interface{} { + if result == nil { + return map[string]interface{}{ + "nodes": []interface{}{}, + "total": 0, + } + } + + nodes := make([]map[string]interface{}, 0, len(result.Nodes)) + for _, node := range result.Nodes { + nodes = append(nodes, nodeToMap(node)) + } + + return map[string]interface{}{ + "nodes": nodes, + "total": result.Total, + "has_more": result.HasMore, + "next_offset": result.NextOffset, + } +} + +// nodeToMap converts a Node to a map +func nodeToMap(node *Node) map[string]interface{} { + m := map[string]interface{}{ + "name": node.Name, + "path": node.Path, + "type": string(node.Type), + } + + if node.Size > 0 { + m["size"] = node.Size + } + + if !node.CreatedAt.IsZero() { + m["created_at"] = node.CreatedAt.Format(time.RFC3339) + } + + if !node.UpdatedAt.IsZero() { + m["updated_at"] = node.UpdatedAt.Format(time.RFC3339) + } + + if len(node.Metadata) > 0 { + m["metadata"] = node.Metadata + } + + return m +} + +// MarshalJSON marshals a Result to JSON bytes +func (r *Result) MarshalJSON() ([]byte, error) { + return json.Marshal(ResultToMap(r)) +} + +// PrintResult prints a result in the specified format +func PrintResult(result *Result, format string) { + if 
result == nil { + fmt.Println("No results") + return + } + + switch format { + case "json": + data, _ := json.MarshalIndent(ResultToMap(result), "", " ") + fmt.Println(string(data)) + case "table": + printTable(result.Nodes) + default: // "plain" + for _, node := range result.Nodes { + fmt.Println(node.Path) + } + } +} + +// printTable prints nodes in a simple table format +func printTable(nodes []*Node) { + if len(nodes) == 0 { + fmt.Println("No results") + return + } + + // Print header + fmt.Printf("%-40s %-12s %-12s %-20s %-20s\n", "NAME", "TYPE", "SIZE", "CREATED", "UPDATED") + fmt.Println(string(make([]byte, 104))) + + // Print rows + for _, node := range nodes { + fmt.Printf("%-40s %-12s %-12s %-20s %-20s\n", + truncateString(node.Name, 40), + node.Type, + formatSize(node.Size), + formatTime(node.CreatedAt), + formatTime(node.UpdatedAt), + ) + } +} + +// truncateString truncates a string to the specified length +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen-3] + "..." 
+} + +// IsValidPath checks if a path is valid +func IsValidPath(path string) bool { + if path == "" { + return false + } + + // Check for invalid characters + invalidChars := []string{"..", "//", "\\", "*", "?", "<", ">", "|", "\x00"} + for _, char := range invalidChars { + if containsString(path, char) { + return false + } + } + + return true +} + +// containsString checks if a string contains a substring +func containsString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// JoinPath joins path components +func JoinPath(components ...string) string { + if len(components) == 0 { + return "" + } + + result := components[0] + for i := 1; i < len(components); i++ { + if result == "" { + result = components[i] + } else if components[i] == "" { + continue + } else { + // Remove trailing slash from result + for len(result) > 0 && result[len(result)-1] == '/' { + result = result[:len(result)-1] + } + // Remove leading slash from component + start := 0 + for start < len(components[i]) && components[i][start] == '/' { + start++ + } + result = result + "/" + components[i][start:] + } + } + + return result +} + +// GetParentPath returns the parent path of a given path +func GetParentPath(path string) string { + path = normalizePath(path) + parts := SplitPath(path) + + if len(parts) <= 1 { + return "" + } + + return joinStrings(parts[:len(parts)-1], "/") +} + +// GetBaseName returns the last component of a path +func GetBaseName(path string) string { + path = normalizePath(path) + parts := SplitPath(path) + + if len(parts) == 0 { + return "" + } + + return parts[len(parts)-1] +} + +// HasPrefix checks if a path has the given prefix +func HasPrefix(path, prefix string) bool { + path = normalizePath(path) + prefix = normalizePath(prefix) + + if prefix == "" { + return true + } + + if path == prefix { + return true + } + + if len(path) > len(prefix) && path[:len(prefix)+1] == 
prefix+"/" { + return true + } + + return false +} diff --git a/internal/cli/parser.go b/internal/cli/parser.go index 0c15345ab..769c0b6c9 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -55,6 +55,11 @@ func (p *Parser) Parse(adminCommand bool) (*Command, error) { return p.parseMetaCommand() } + // Check for ContextEngine commands (ls, cat, search) + if p.curToken.Type == TokenIdentifier && isCECommand(p.curToken.Value) { + return p.parseCECommand() + } + // Parse SQL-like command return p.parseSQLCommand(adminCommand) } @@ -215,6 +220,16 @@ func isKeyword(tokenType int) bool { return tokenType >= TokenLogin && tokenType <= TokenDocMeta } +// isCECommand checks if the given string is a ContextEngine command +func isCECommand(s string) bool { + upper := strings.ToUpper(s) + switch upper { + case "LS", "LIST", "SEARCH": + return true + } + return false +} + // Helper functions for parsing func (p *Parser) parseQuotedString() (string, error) { if p.curToken.Type != TokenQuotedString { @@ -241,3 +256,92 @@ func tokenTypeToString(t int) string { // Simplified for error messages return fmt.Sprintf("token(%d)", t) } + +// parseCECommand parses ContextEngine commands (ls, search) +func (p *Parser) parseCECommand() (*Command, error) { + cmdName := strings.ToUpper(p.curToken.Value) + + switch cmdName { + case "LS", "LIST": + return p.parseCEListCommand() + case "SEARCH": + return p.parseCESearchCommand() + default: + return nil, fmt.Errorf("unknown ContextEngine command: %s", cmdName) + } +} + +// parseCEListCommand parses the ls command +// Syntax: ls [path] or ls datasets +func (p *Parser) parseCEListCommand() (*Command, error) { + p.nextToken() // consume LS/LIST + + cmd := NewCommand("ce_ls") + + // Check if there's a path argument + // Also accept TokenDatasets since "datasets" is a keyword but can be a path + if p.curToken.Type == TokenIdentifier || p.curToken.Type == TokenQuotedString || + p.curToken.Type == TokenDatasets { + path := 
p.curToken.Value + // Remove quotes if present + if p.curToken.Type == TokenQuotedString { + path = strings.Trim(path, "\"'") + } + cmd.Params["path"] = path + p.nextToken() + } else { + // Default to "datasets" root + cmd.Params["path"] = "datasets" + } + + // Optional semicolon + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} + +// parseCESearchCommand parses the search command +// Syntax: search or search in +func (p *Parser) parseCESearchCommand() (*Command, error) { + p.nextToken() // consume SEARCH + + cmd := NewCommand("ce_search") + + if p.curToken.Type != TokenIdentifier && p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expected query after SEARCH") + } + + query := p.curToken.Value + if p.curToken.Type == TokenQuotedString { + query = strings.Trim(query, "\"'") + } + cmd.Params["query"] = query + p.nextToken() + + // Check for optional "in " clause + if p.curToken.Type == TokenIdentifier && strings.ToUpper(p.curToken.Value) == "IN" { + p.nextToken() // consume IN + + if p.curToken.Type != TokenIdentifier && p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expected path after IN") + } + + path := p.curToken.Value + if p.curToken.Type == TokenQuotedString { + path = strings.Trim(path, "\"'") + } + cmd.Params["path"] = path + p.nextToken() + } else { + cmd.Params["path"] = "." 
+ } + + // Optional semicolon + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} diff --git a/internal/cli/table.go b/internal/cli/table.go index 7baef5d5a..18fad5fa1 100644 --- a/internal/cli/table.go +++ b/internal/cli/table.go @@ -17,16 +17,46 @@ package cli import ( + "encoding/json" "fmt" + "strconv" "strings" - "unicode" ) -// PrintTableSimple prints data in a simple table format +const maxColWidth = 256 + +// PrintTableSimple prints data in a simple table format (default: table format with borders) // Similar to Python's _print_table_simple func PrintTableSimple(data []map[string]interface{}) { + PrintTableSimpleByFormat(data, OutputFormatTable) +} + +// PrintTableSimpleByFormat prints data in the specified format +// Supports: table (with borders), plain (no borders, space-separated), json +// - Column names in lowercase +// - Two spaces between columns +// - Numeric columns right-aligned +// - URI/path columns not truncated +func PrintTableSimpleByFormat(data []map[string]interface{}, format OutputFormat) { if len(data) == 0 { - fmt.Println("No data to print") + if format == OutputFormatJSON { + fmt.Println("[]") + } else if format == OutputFormatPlain { + fmt.Println("(empty)") + } else { + fmt.Println("No data to print") + } + return + } + + // JSON format: output as JSON array + if format == OutputFormatJSON { + jsonData, err := json.MarshalIndent(data, "", " ") + if err != nil { + fmt.Printf("Error marshaling JSON: %v\n", err) + return + } + fmt.Println(string(jsonData)) return } @@ -52,60 +82,190 @@ func PrintTableSimple(data []map[string]interface{}) { } } - // Calculate column widths + // Analyze columns: check if numeric and if URI column + colIsNumeric := make(map[string]bool) + colIsURI := make(map[string]bool) + for _, col := range columns { + colLower := strings.ToLower(col) + if colLower == "uri" || colLower == "path" || colLower == "id" { + colIsURI[col] = true + } + // Check if all values are numeric + 
isNumeric := true + for _, item := range data { + if val, ok := item[col]; ok { + if !isNumericValue(val) { + isNumeric = false + break + } + } + } + colIsNumeric[col] = isNumeric + } + + // Calculate column widths (capped at maxColWidth) colWidths := make(map[string]int) for _, col := range columns { - maxWidth := getStringWidth(col) + maxWidth := getStringWidth(strings.ToLower(col)) for _, item := range data { - value := fmt.Sprintf("%v", item[col]) + value := formatValue(item[col]) valueWidth := getStringWidth(value) if valueWidth > maxWidth { maxWidth = valueWidth } } + if maxWidth > maxColWidth { + maxWidth = maxColWidth + } if maxWidth < 2 { maxWidth = 2 } colWidths[col] = maxWidth } - // Generate separator - separatorParts := make([]string, 0, len(columns)) - for _, col := range columns { - separatorParts = append(separatorParts, strings.Repeat("-", colWidths[col]+2)) - } - separator := "+" + strings.Join(separatorParts, "+") + "+" - - // Print header - fmt.Println(separator) - headerParts := make([]string, 0, len(columns)) - for _, col := range columns { - headerParts = append(headerParts, fmt.Sprintf(" %-*s ", colWidths[col], col)) - } - fmt.Println("|" + strings.Join(headerParts, "|") + "|") - fmt.Println(separator) - - // Print data rows - for _, item := range data { - rowParts := make([]string, 0, len(columns)) + if format == OutputFormatPlain { + // Plain mode: no borders, space-separated (ov CLI compatible) + // Print header (lowercase column names, right-aligned for numeric columns) + headerParts := make([]string, 0, len(columns)) for _, col := range columns { - value := fmt.Sprintf("%v", item[col]) - valueWidth := getStringWidth(value) - // Truncate if too long - if valueWidth > colWidths[col] { - runes := []rune(value) - truncated := truncateString(runes, colWidths[col]) - value = truncated - valueWidth = getStringWidth(value) - } - // Pad to column width - padding := colWidths[col] - valueWidth + len(value) - rowParts = append(rowParts, 
fmt.Sprintf(" %-*s ", padding, value)) + // Header follows the same alignment as data (right-aligned for numeric columns) + headerParts = append(headerParts, padCell(strings.ToLower(col), colWidths[col], colIsNumeric[col])) } - fmt.Println("|" + strings.Join(rowParts, "|") + "|") - } + fmt.Println(strings.Join(headerParts, " ")) - fmt.Println(separator) + // Print data rows + for _, item := range data { + rowParts := make([]string, 0, len(columns)) + for _, col := range columns { + value := formatValue(item[col]) + isURI := colIsURI[col] + isNumeric := colIsNumeric[col] + + // URI columns: never truncate, no padding if too long + if isURI && getStringWidth(value) > colWidths[col] { + rowParts = append(rowParts, value) + } else { + // Normal cell: truncate if too long, then pad + valueWidth := getStringWidth(value) + if valueWidth > colWidths[col] { + runes := []rune(value) + value = truncateStringByWidth(runes, colWidths[col]) + valueWidth = getStringWidth(value) + } + rowParts = append(rowParts, padCell(value, colWidths[col], isNumeric)) + } + } + fmt.Println(strings.Join(rowParts, " ")) + } + } else { + // Normal mode: with borders + // Generate separator + separatorParts := make([]string, 0, len(columns)) + for _, col := range columns { + separatorParts = append(separatorParts, strings.Repeat("-", colWidths[col]+2)) + } + separator := "+" + strings.Join(separatorParts, "+") + "+" + + // Print header + fmt.Println(separator) + headerParts := make([]string, 0, len(columns)) + for _, col := range columns { + headerParts = append(headerParts, fmt.Sprintf(" %-*s ", colWidths[col], col)) + } + fmt.Println("|" + strings.Join(headerParts, "|") + "|") + fmt.Println(separator) + + // Print data rows + for _, item := range data { + rowParts := make([]string, 0, len(columns)) + for _, col := range columns { + value := formatValue(item[col]) + valueWidth := getStringWidth(value) + // Truncate if too long + if valueWidth > colWidths[col] { + runes := []rune(value) + value = 
truncateStringByWidth(runes, colWidths[col])
					valueWidth = getStringWidth(value)
				}
				// Pad to column width. The format width must account for the
				// difference between display width (CJK runes count as 2) and
				// byte length, hence colWidths - valueWidth + len(value).
				padding := colWidths[col] - valueWidth + len(value)
				rowParts = append(rowParts, fmt.Sprintf(" %-*s ", padding, value))
			}
			fmt.Println("|" + strings.Join(rowParts, "|") + "|")
		}

		fmt.Println(separator)
	}
}

// formatValue formats a value for display: nil becomes "", numeric and
// bool types use strconv, everything else falls back to fmt's %v.
func formatValue(v interface{}) string {
	if v == nil {
		return ""
	}
	switch val := v.(type) {
	case string:
		return val
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		// -1 precision: shortest representation that round-trips
		return strconv.FormatFloat(val, 'f', -1, 64)
	case bool:
		return strconv.FormatBool(val)
	default:
		return fmt.Sprintf("%v", v)
	}
}

// isNumericValue checks if a value is numeric. Strings count as numeric
// when they parse as a float; nil is never numeric.
func isNumericValue(v interface{}) bool {
	if v == nil {
		return false
	}
	switch val := v.(type) {
	case int, int8, int16, int32, int64:
		return true
	case uint, uint8, uint16, uint32, uint64:
		return true
	case float32, float64:
		return true
	case string:
		_, err := strconv.ParseFloat(val, 64)
		return err == nil
	default:
		return false
	}
}

// truncateStringByWidth truncates a string to fit within maxWidth display
// width (half-width runes count 1, others 2), appending "..." at the cut.
// NOTE(review): for maxWidth < 3 the returned "..." can exceed maxWidth;
// callers only invoke this when the value is already over-width, so the
// overflow is bounded — confirm if widths below 3 ever occur.
func truncateStringByWidth(runes []rune, maxWidth int) string {
	width := 0
	for i, r := range runes {
		if isHalfWidth(r) {
			width++
		} else {
			width += 2
		}
		// Reserve 3 columns for the ellipsis
		if width > maxWidth-3 {
			return string(runes[:i]) + "..."
		}
	}
	return string(runes)
}

// padCell pads a string to the specified display width for alignment;
// numeric columns pass alignRight=true. Content already at or over the
// width is returned unchanged.
func padCell(content string, width int, alignRight bool) string {
	contentWidth := getStringWidth(content)
	if contentWidth >= width {
		return content
	}
	padding := width - contentWidth
	if alignRight {
		return strings.Repeat(" ", padding) + content
	}
	return content + strings.Repeat(" ", padding)
}
-func isWideChar(r rune) bool { - return unicode.Is(unicode.Han, r) || - unicode.Is(unicode.Hiragana, r) || - unicode.Is(unicode.Katakana, r) || - unicode.Is(unicode.Hangul, r) -} diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index ed108b242..dda574409 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -143,13 +143,19 @@ func (c *RAGFlowClient) ListUserDatasets(cmd *Command) (ResponseIf, error) { iterations = val } + // Determine auth kind based on whether API token is being used + authKind := "web" + if c.HTTPClient.useAPIToken { + authKind = "api" + } + if iterations > 1 { // Benchmark mode - return raw result for benchmark stats - return c.HTTPClient.RequestWithIterations("GET", "/datasets", true, "web", nil, nil, iterations) + return c.HTTPClient.RequestWithIterations("GET", "/datasets", true, authKind, nil, nil, iterations) } // Normal mode - resp, err := c.HTTPClient.Request("GET", "/datasets", true, "web", nil, nil) + resp, err := c.HTTPClient.Request("GET", "/datasets", true, authKind, nil, nil) if err != nil { return nil, fmt.Errorf("failed to list datasets: %w", err) } diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 98d729fff..1f7fd5a05 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -1,6 +1,9 @@ package cli -import "fmt" +import ( + "fmt" + "strconv" +) // Command parsers func (p *Parser) parseLogout() (*Command, error) {