From 8a9bbf3d6d9c6e67a8c2baf56a67a405cc0ab939 Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Fri, 27 Mar 2026 09:49:50 +0800 Subject: [PATCH] Feat: add memory function by go (#13754) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Feat: Add Memory function by go ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Yingfeng --- cmd/server_main.go | 4 +- internal/dao/memory.go | 370 +++++++++++++++ internal/dao/tenant_llm.go | 127 ++++++ internal/handler/memory.go | 687 ++++++++++++++++++++++++++++ internal/model/memory.go | 9 + internal/router/router.go | 26 ++ internal/service/memory.go | 892 +++++++++++++++++++++++++++++++++++++ internal/service/tenant.go | 132 ++++++ internal/service/user.go | 90 ++++ web/.env.development | 3 +- web/vite.config.ts | 72 +-- 11 files changed, 2350 insertions(+), 62 deletions(-) create mode 100644 internal/dao/memory.go create mode 100644 internal/handler/memory.go create mode 100644 internal/service/memory.go diff --git a/cmd/server_main.go b/cmd/server_main.go index 5d0fa5798..fe69e50c7 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -175,6 +175,7 @@ func startServer(config *server.Config) { connectorService := service.NewConnectorService() searchService := service.NewSearchService() fileService := service.NewFileService() + memoryService := service.NewMemoryService() // Initialize handler layer authHandler := handler.NewAuthHandler() @@ -191,9 +192,10 @@ func startServer(config *server.Config) { connectorHandler := handler.NewConnectorHandler(connectorService, userService) searchHandler := handler.NewSearchHandler(searchService, userService) fileHandler := handler.NewFileHandler(fileService, userService) + memoryHandler := handler.NewMemoryHandler(memoryService) // Initialize router - r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, 
datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler) + r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler) // Create Gin engine ginEngine := gin.New() diff --git a/internal/dao/memory.go b/internal/dao/memory.go new file mode 100644 index 000000000..35353a77f --- /dev/null +++ b/internal/dao/memory.go @@ -0,0 +1,370 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +// Package dao implements the data access layer +// This file implements Memory-related database operations +// Consistent with Python memory_service.py +package dao + +import ( + "fmt" + "strings" + + "ragflow/internal/model" +) + +// Memory type bit flag constants, consistent with Python MemoryType enum +const ( + MemoryTypeRaw = 0b0001 // Raw memory (binary: 0001) + MemoryTypeSemantic = 0b0010 // Semantic memory (binary: 0010) + MemoryTypeEpisodic = 0b0100 // Episodic memory (binary: 0100) + MemoryTypeProcedural = 0b1000 // Procedural memory (binary: 1000) +) + +// MemoryTypeMap maps memory type names to bit flags +// Exported for use by service package +var MemoryTypeMap = map[string]int{ + "raw": MemoryTypeRaw, + "semantic": MemoryTypeSemantic, + "episodic": MemoryTypeEpisodic, + "procedural": MemoryTypeProcedural, +} + +// CalculateMemoryType converts memory type names array to bit flags integer +// +// Parameters: +// - memoryTypeNames: Memory type names array +// +// Returns: +// - int64: Bit flags integer +// +// Example: +// +// CalculateMemoryType([]string{"raw", "semantic"}) returns 3 (0b0011) +func CalculateMemoryType(memoryTypeNames []string) int64 { + memoryType := 0 + for _, name := range memoryTypeNames { + lowerName := strings.ToLower(name) + if mt, ok := MemoryTypeMap[lowerName]; ok { + memoryType |= mt + } + } + return int64(memoryType) +} + +// GetMemoryTypeHuman converts memory type bit flags to human-readable names +// +// Parameters: +// - memoryType: Bit flags integer representing memory types +// +// Returns: +// - []string: Array of human-readable memory type names +// +// Example: +// +// GetMemoryTypeHuman(3) returns ["raw", "semantic"] +func GetMemoryTypeHuman(memoryType int64) []string { + var result []string + if memoryType&int64(MemoryTypeRaw) != 0 { + result = append(result, "raw") + } + if memoryType&int64(MemoryTypeSemantic) != 0 { + result = append(result, "semantic") + } + if memoryType&int64(MemoryTypeEpisodic) != 0 { + 
result = append(result, "episodic") + } + if memoryType&int64(MemoryTypeProcedural) != 0 { + result = append(result, "procedural") + } + return result +} + +// MemoryDAO handles all Memory-related database operations +type MemoryDAO struct{} + +// NewMemoryDAO creates a new MemoryDAO instance +// +// Returns: +// - *MemoryDAO: Initialized DAO instance +func NewMemoryDAO() *MemoryDAO { + return &MemoryDAO{} +} + +// Create inserts a new memory record into the database +// +// Parameters: +// - memory: Memory model pointer +// +// Returns: +// - error: Database operation error +func (dao *MemoryDAO) Create(memory *model.Memory) error { + return DB.Create(memory).Error +} + +// GetByID retrieves a memory record by ID from database +// +// Parameters: +// - id: Memory ID +// +// Returns: +// - *model.Memory: Memory model pointer +// - error: Database operation error +func (dao *MemoryDAO) GetByID(id string) (*model.Memory, error) { + var memory model.Memory + err := DB.Where("id = ?", id).First(&memory).Error + if err != nil { + return nil, err + } + return &memory, nil +} + +// GetByTenantID retrieves all memories for a tenant +// +// Parameters: +// - tenantID: Tenant ID +// +// Returns: +// - []*model.Memory: Memory model pointer array +// - error: Database operation error +func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*model.Memory, error) { + var memories []*model.Memory + err := DB.Where("tenant_id = ?", tenantID).Find(&memories).Error + return memories, err +} + +// GetByNameAndTenant checks if memory exists by name and tenant ID +// Used for duplicate name deduplication +// +// Parameters: +// - name: Memory name +// - tenantID: Tenant ID +// +// Returns: +// - []*model.Memory: Matching memory list (for existence check) +// - error: Database operation error +func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*model.Memory, error) { + var memories []*model.Memory + err := DB.Where("name = ? 
AND tenant_id = ?", name, tenantID).Find(&memories).Error + return memories, err +} + +// GetByIDs retrieves memories by multiple IDs +// +// Parameters: +// - ids: Memory ID list +// +// Returns: +// - []*model.Memory: Memory model pointer array +// - error: Database operation error +func (dao *MemoryDAO) GetByIDs(ids []string) ([]*model.Memory, error) { + var memories []*model.Memory + err := DB.Where("id IN ?", ids).Find(&memories).Error + return memories, err +} + +// UpdateByID updates a memory by ID +// Supports partial updates - only updates passed fields +// Automatically handles field type conversions +// +// Parameters: +// - id: Memory ID +// - updates: Fields to update map +// +// Returns: +// - error: Database operation error +// +// Field type handling: +// - memory_type: []string converts to bit flags integer +// - temperature: string converts to float64 +// - name: Uses string value directly +// - permissions, forgetting_policy: Uses string value directly +// +// Example: +// +// updates := map[string]interface{}{"name": "NewName", "memory_type": []string{"semantic"}} +// err := dao.UpdateByID("memory123", updates) +func (dao *MemoryDAO) UpdateByID(id string, updates map[string]interface{}) error { + if updates == nil || len(updates) == 0 { + return nil + } + + for key, value := range updates { + switch key { + case "memory_type": + if types, ok := value.([]string); ok { + updates[key] = CalculateMemoryType(types) + } + case "temperature": + if tempStr, ok := value.(string); ok { + var temp float64 + fmt.Sscanf(tempStr, "%f", &temp) + updates[key] = temp + } + } + } + + return DB.Model(&model.Memory{}).Where("id = ?", id).Updates(updates).Error +} + +// DeleteByID deletes a memory by ID +// +// Parameters: +// - id: Memory ID +// +// Returns: +// - error: Database operation error +// +// Example: +// +// err := dao.DeleteByID("memory123") +func (dao *MemoryDAO) DeleteByID(id string) error { + return DB.Where("id = ?", 
id).Delete(&model.Memory{}).Error +} + +// GetWithOwnerNameByID retrieves a memory with owner name by ID +// Joins with User table to get owner's nickname +// +// Parameters: +// - id: Memory ID +// +// Returns: +// - *model.MemoryListItem: Memory detail with owner name populated +// - error: Database operation error +// +// Example: +// +// memory, err := dao.GetWithOwnerNameByID("memory123") +func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, error) { + querySQL := ` + SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type, + m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id, + m.permissions, m.description, m.memory_size, m.forgetting_policy, + m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date, + m.update_time, m.update_date, + u.nickname as owner_name + FROM memory m + LEFT JOIN user u ON m.tenant_id = u.id + WHERE m.id = ? + ` + + var rawResult struct { + model.Memory + OwnerName *string `gorm:"column:owner_name"` + } + + if err := DB.Raw(querySQL, id).Scan(&rawResult).Error; err != nil { + return nil, err + } + + return &model.MemoryListItem{ + Memory: rawResult.Memory, + OwnerName: rawResult.OwnerName, + }, nil +} + +// GetByFilter retrieves memories with optional filters +// Supports filtering by tenant_id, memory_type, storage_type, and keywords +// Returns paginated results with owner_name from user table JOIN +// +// Parameters: +// - tenantIDs: Array of tenant IDs to filter by (empty means all tenants) +// - memoryTypes: Array of memory type names to filter by (empty means all types) +// - storageType: Storage type to filter by (empty means all types) +// - keywords: Keywords to search in memory names (empty means no keyword filter) +// - page: Page number (1-based) +// - pageSize: Number of items per page +// +// Returns: +// - []*model.MemoryListItem: Memory list items with owner name populated +// - int64: Total count of matching memories +// - error: Database operation error 
+// +// Example: +// +// memories, total, err := dao.GetByFilter([]string{"tenant1"}, []string{"semantic"}, "table", "test", 1, 10) +func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*model.MemoryListItem, int64, error) { + var conditions []string + var args []interface{} + + if len(tenantIDs) > 0 { + conditions = append(conditions, "m.tenant_id IN ?") + args = append(args, tenantIDs) + } + + if len(memoryTypes) > 0 { + memoryTypeInt := CalculateMemoryType(memoryTypes) + conditions = append(conditions, "m.memory_type & ? > 0") + args = append(args, memoryTypeInt) + } + + if storageType != "" { + conditions = append(conditions, "m.storage_type = ?") + args = append(args, storageType) + } + + if keywords != "" { + conditions = append(conditions, "m.name LIKE ?") + args = append(args, "%"+keywords+"%") + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + countSQL := fmt.Sprintf("SELECT COUNT(*) FROM memory m %s", whereClause) + var total int64 + if err := DB.Raw(countSQL, args...).Scan(&total).Error; err != nil { + return nil, 0, err + } + + offset := (page - 1) * pageSize + querySQL := fmt.Sprintf(` + SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type, + m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id, + m.permissions, m.description, m.memory_size, m.forgetting_policy, + m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date, + m.update_time, m.update_date, + u.nickname as owner_name + FROM memory m + LEFT JOIN user u ON m.tenant_id = u.id + %s + ORDER BY m.update_time DESC + LIMIT ? OFFSET ? 
+ `, whereClause) + + queryArgs := append(args, pageSize, offset) + + var rawResults []struct { + model.Memory + OwnerName *string `gorm:"column:owner_name"` + } + + if err := DB.Raw(querySQL, queryArgs...).Scan(&rawResults).Error; err != nil { + return nil, 0, err + } + + memories := make([]*model.MemoryListItem, len(rawResults)) + for i, r := range rawResults { + memories[i] = &model.MemoryListItem{ + Memory: r.Memory, + OwnerName: r.OwnerName, + } + } + + return memories, total, nil +} diff --git a/internal/dao/tenant_llm.go b/internal/dao/tenant_llm.go index fdf0bad69..fab7d1dca 100644 --- a/internal/dao/tenant_llm.go +++ b/internal/dao/tenant_llm.go @@ -141,3 +141,130 @@ func (dao *TenantLLMDAO) DeleteByTenantID(tenantID string) (int64, error) { result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.TenantLLM{}) return result.RowsAffected, result.Error } + +// splitModelNameAndFactory splits model name and factory from combined format +// This matches Python's split_model_name_and_factory logic +// +// Parameters: +// - modelName: The model name which can be in format "ModelName" or "ModelName@Factory" +// +// Returns: +// - string: The model name without factory prefix +// - string: The factory name (empty string if not specified) +// +// Example: +// +// modelName, factory := splitModelNameAndFactory("gpt-4") +// // Returns: "gpt-4", "" +// +// modelName, factory := splitModelNameAndFactory("gpt-4@OpenAI") +// // Returns: "gpt-4", "OpenAI" +func splitModelNameAndFactory(modelName string) (string, string) { + // Split by "@" separator + // Handle cases like "model@factory" or "model@sub@factory" + lastAtIndex := -1 + for i := len(modelName) - 1; i >= 0; i-- { + if modelName[i] == '@' { + lastAtIndex = i + break + } + } + + // No "@" found, return original name + if lastAtIndex == -1 { + return modelName, "" + } + + // Split into model name and potential factory + modelNamePart := modelName[:lastAtIndex] + factory := modelName[lastAtIndex+1:] + 
+ // Validate if factory exists in llm_factories table + // This matches Python's logic of checking against model providers + var factoryCount int64 + DB.Model(&model.LLMFactories{}).Where("name = ?", factory).Count(&factoryCount) + + // If factory doesn't exist in database, treat the whole string as model name + if factoryCount == 0 { + return modelName, "" + } + + return modelNamePart, factory +} + +// GetByTenantIDAndLLMName gets tenant LLM by tenant ID and LLM name +// This is used to resolve tenant_llm_id from llm_id +// It supports both simple model names and factory-suffixed names (e.g., "gpt-4@OpenAI") +// +// Parameters: +// - tenantID: The tenant identifier +// - llmName: The LLM model name (can include a factory suffix like "gpt-4@OpenAI") +// +// Returns: +// - *model.TenantLLM: The tenant LLM record +// - error: Error if not found +// +// Example: +// +// // Simple model name +// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4") +// +// // Model name with factory suffix +// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4@OpenAI") +func (dao *TenantLLMDAO) GetByTenantIDAndLLMName(tenantID string, llmName string) (*model.TenantLLM, error) { + var tenantLLM model.TenantLLM + + // Split model name and factory from the combined format + modelName, factory := splitModelNameAndFactory(llmName) + + // First attempt: try to find with model name only + err := DB.Where("tenant_id = ? AND llm_name = ?", tenantID, modelName).First(&tenantLLM).Error + if err == nil { + return &tenantLLM, nil + } + + // Second attempt: if factory is specified, try with both model name and factory + if factory != "" { + err = DB.Where("tenant_id = ? AND llm_name = ? 
AND llm_factory = ?", tenantID, modelName, factory).First(&tenantLLM).Error + if err == nil { + return &tenantLLM, nil + } + + // Special handling for LocalAI, HuggingFace and OpenAI-API-Compatible (matching Python logic) + // For these factories the stored model name has "___FactoryName" appended + if factory == "LocalAI" || factory == "HuggingFace" || factory == "OpenAI-API-Compatible" { + specialModelName := modelName + "___" + factory + err = DB.Where("tenant_id = ? AND llm_name = ?", tenantID, specialModelName).First(&tenantLLM).Error + if err == nil { + return &tenantLLM, nil + } + } + } + + // Return the error from the last lookup attempt (typically record not found) + return nil, err +} + +// GetByTenantIDLLMNameAndFactory gets tenant LLM by tenant ID, LLM name and factory +// This is used when model name includes factory suffix (e.g., "model@factory") +// +// Parameters: +// - tenantID: The tenant identifier +// - llmName: The LLM model name +// - factory: The LLM factory name +// +// Returns: +// - *model.TenantLLM: The tenant LLM record +// - error: Error if not found +// +// Example: +// +// tenantLLM, err := dao.GetByTenantIDLLMNameAndFactory("tenant123", "gpt-4", "OpenAI") +func (dao *TenantLLMDAO) GetByTenantIDLLMNameAndFactory(tenantID, llmName, factory string) (*model.TenantLLM, error) { + var tenantLLM model.TenantLLM + err := DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, llmName, factory).First(&tenantLLM).Error + if err != nil { + return nil, err + } + return &tenantLLM, nil +} diff --git a/internal/handler/memory.go b/internal/handler/memory.go new file mode 100644 index 000000000..a1d9e5423 --- /dev/null +++ b/internal/handler/memory.go @@ -0,0 +1,687 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package handler contains all HTTP request handlers +// This file implements Memory-related API endpoint handlers +// Each method corresponds to an API endpoint in the Python memory_api.py +package handler + +import ( + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "ragflow/internal/common" + "ragflow/internal/service" +) + +// MemoryHandler handles Memory-related HTTP requests +// Responsible for processing all Memory-related HTTP requests +// Each method corresponds to an API endpoint, implementing the same logic as Python memory_api.py +type MemoryHandler struct { + memoryService *service.MemoryService // Reference to Memory business service layer +} + +// NewMemoryHandler creates a new MemoryHandler instance +// +// Parameters: +// - memoryService: Pointer to MemoryService business service layer +// +// Returns: +// - *MemoryHandler: Initialized handler instance +func NewMemoryHandler(memoryService *service.MemoryService) *MemoryHandler { + return &MemoryHandler{ + memoryService: memoryService, + } +} + +// CreateMemory handles POST request for creating Memory +// API Path: POST /api/v1/memories +// +// Function: +// - Creates a new memory record +// - Supports automatic system_prompt generation +// - Supports name deduplication (if name exists, adds sequence number) +// +// Request Parameters (JSON Body): +// - name (required): Memory name, max 128 characters +// - memory_type (required): Memory type array, supports ["raw", "semantic", "episodic", "procedural"] +// - embd_id (required): 
Embedding model ID +// - llm_id (required): LLM model ID +// - tenant_embd_id (optional): Tenant embedding model ID +// - tenant_llm_id (optional): Tenant LLM model ID +// +// Response Format: +// - code: Status code (0=success, other=error) +// - message: true on success, error message on failure +// - data: Memory object on success +// +// Business Logic (matching Python create_memory): +// 1. Validate user login status +// 2. Parse and validate request parameters +// 3. Call service layer to create memory +// 4. Return creation result +func (h *MemoryHandler) CreateMemory(c *gin.Context) { + // Check if API timing is enabled + // If RAGFLOW_API_TIMING environment variable is set, request processing time will be logged + timingEnabled := os.Getenv("RAGFLOW_API_TIMING") + var tStart time.Time + if timingEnabled != "" { + tStart = time.Now() + } + + // Get current logged-in user information + // GetUser is a context value set by the authentication middleware + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + userID := user.ID + + // Parse JSON request body + var req service.CreateMemoryRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": nil, + }) + return + } + + // Validate required field: name + if req.Name == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "name is required", + "data": nil, + }) + return + } + + // Validate required field: memory_type (must be non-empty array) + if len(req.MemoryType) == 0 { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "memory_type is required and must be a list", + "data": nil, + }) + return + } + + // Validate required field: embd_id + if req.EmbdID == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "embd_id is required", + 
"data": nil, + }) + return + } + + // Validate required field: llm_id + if req.LLMID == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "llm_id is required", + "data": nil, + }) + return + } + + // Record request parsing completion time (for timing) + tParsed := time.Now() + + // Call service layer to create memory + result, err := h.memoryService.CreateMemory(userID, &req) + if err != nil { + // Log error if timing is enabled + if timingEnabled != "" { + totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0 + parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0 + _ = parseMs + _ = totalMs + } + + errMsg := err.Error() + // Determine if it's an argument error and return appropriate error code + if isArgumentError(errMsg) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": errMsg, + "data": nil, + }) + return + } + + // Other errors return server error + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": errMsg, + "data": nil, + }) + return + } + + // Log success if timing is enabled + if timingEnabled != "" { + totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0 + parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0 + validateAndDbMs := totalMs - parseMs + _ = parseMs + _ = validateAndDbMs + _ = totalMs + } + + // Return success response + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": true, + "data": result, + }) +} + +// UpdateMemory handles PUT request for updating Memory +// API Path: PUT /api/v1/memories/:memory_id +// +// Function: +// - Updates configuration information for the specified memory +// - Supports partial updates: only update passed fields +// +// Request Parameters (JSON Body): +// - name (optional): Memory name +// - permissions (optional): Permission setting ["me", "team", "all"] +// - llm_id (optional): LLM model ID +// - embd_id (optional): Embedding model ID +// - tenant_llm_id 
(optional): Tenant LLM model ID +// - tenant_embd_id (optional): Tenant embedding model ID +// - memory_type (optional): Memory type array +// - memory_size (optional): Memory size, range (0, 5242880] +// - forgetting_policy (optional): Forgetting policy, default "FIFO" +// - temperature (optional): Temperature parameter, range [0, 1] +// - avatar (optional): Avatar URL +// - description (optional): Description +// - system_prompt (optional): System prompt +// - user_prompt (optional): User prompt +// +// Business Rules: +// - name length <= 128 characters +// - Cannot update tenant_embd_id, embd_id, memory_type when memory_size > 0 +// - When updating memory_type, system_prompt is automatically regenerated if it's the default +func (h *MemoryHandler) UpdateMemory(c *gin.Context) { + // Get memory_id from URL path + memoryID := c.Param("memory_id") + if memoryID == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "memory_id is required", + "data": nil, + }) + return + } + + // Get current logged-in user information + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + userID := user.ID + + // Parse JSON request body + var req service.UpdateMemoryRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": nil, + }) + return + } + + // Call service layer to update memory + result, err := h.memoryService.UpdateMemory(userID, memoryID, &req) + if err != nil { + errMsg := err.Error() + // Check if it's a "not found" error + if strings.Contains(errMsg, "not found") { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": errMsg, + "data": nil, + }) + return + } + + // Check if it's an argument error + if isArgumentError(errMsg) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": errMsg, + "data": nil, + }) 
+ return + } + + // Other errors return server error + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": errMsg, + "data": nil, + }) + return + } + + // Return success response + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": true, + "data": result, + }) +} + +// DeleteMemory handles DELETE request for deleting Memory +// API Path: DELETE /api/v1/memories/:memory_id +// +// Function: +// - Deletes the specified memory record +// - Also deletes associated message data +// +// Business Logic: +// 1. Check if memory exists +// 2. Delete memory record +// 3. Delete associated message index +func (h *MemoryHandler) DeleteMemory(c *gin.Context) { + // Get memory_id from URL path + memoryID := c.Param("memory_id") + if memoryID == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "memory_id is required", + "data": nil, + }) + return + } + + // Call service layer to delete memory + err := h.memoryService.DeleteMemory(memoryID) + if err != nil { + errMsg := err.Error() + // Check if it's a "not found" error + if strings.Contains(errMsg, "not found") { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": errMsg, + "data": nil, + }) + return + } + + // Other errors return server error + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": errMsg, + "data": nil, + }) + return + } + + // Return success response + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": true, + "data": nil, + }) +} + +// ListMemories handles GET request for listing Memories +// API Path: GET /api/v1/memories +// +// Function: +// - Lists memories accessible to the current user +// - Supports multiple filter conditions +// - Supports pagination and keyword search +// +// Query Parameters: +// - memory_type (optional): Memory type filter, supports comma-separated multiple types +// - tenant_id (optional): Tenant ID filter +// - storage_type (optional): 
Storage type filter +// - keywords (optional): Keyword search (fuzzy match on name) +// - page (optional): Page number, default 1 +// - page_size (optional): Items per page, default 50 +// +// Response Format: +// - code: Status code +// - message: true +// - data.memory_list: Array of Memory objects +// - data.total_count: Total record count +func (h *MemoryHandler) ListMemories(c *gin.Context) { + // Get current logged-in user information + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + // Parse query parameters + memoryTypesParam := c.Query("memory_type") + tenantIDsParam := c.Query("tenant_id") + storageType := c.Query("storage_type") + keywords := c.Query("keywords") + pageStr := c.DefaultQuery("page", "1") + pageSizeStr := c.DefaultQuery("page_size", "50") + + // Convert pagination parameters to integers + page, _ := strconv.Atoi(pageStr) + pageSize, _ := strconv.Atoi(pageSizeStr) + + // Validate pagination parameters + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 50 + } + + // Parse memory_type parameter (supports comma separation) + var memoryTypes []string + if memoryTypesParam != "" { + if strings.Contains(memoryTypesParam, ",") { + memoryTypes = strings.Split(memoryTypesParam, ",") + } else { + memoryTypes = []string{memoryTypesParam} + } + } + + // Parse tenant_id parameter + // If not specified, service will get all tenants associated with the user + var tenantIDs []string + if tenantIDsParam != "" { + if strings.Contains(tenantIDsParam, ",") { + tenantIDs = strings.Split(tenantIDsParam, ",") + } else { + tenantIDs = []string{tenantIDsParam} + } + } + + // Call service layer to get memory list + result, err := h.memoryService.ListMemories(user.ID, tenantIDs, memoryTypes, storageType, keywords, page, pageSize) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": err.Error(), + "data": nil, + }) + return 
+ } + + // Return success response + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": true, + "data": result, + }) +} + +// GetMemoryConfig handles GET request for getting Memory configuration +// API Path: GET /api/v1/memories/:memory_id/config +// +// Function: +// - Gets complete configuration information for the specified memory +// - Includes owner name (obtained via JOIN with user table) +// +// Response Format: +// - code: Status code +// - message: true +// - data: Memory object, including owner_name field +func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) { + // Get memory_id from URL path + memoryID := c.Param("memory_id") + if memoryID == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "memory_id is required", + "data": nil, + }) + return + } + + // Call service layer to get memory configuration + result, err := h.memoryService.GetMemoryConfig(memoryID) + if err != nil { + errMsg := err.Error() + // Check if it's a "not found" error + if strings.Contains(errMsg, "not found") { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": errMsg, + "data": nil, + }) + return + } + + // Other errors return server error + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": errMsg, + "data": nil, + }) + return + } + + // Return success response + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": true, + "data": result, + }) +} + +// GetMemoryMessages handles GET request for getting Memory messages +// API Path: GET /api/v1/memories/:memory_id +// +// Function: +// - Gets message list associated with the specified memory +// - Supports filtering by agent_id +// - Supports keyword search and pagination +// +// Query Parameters: +// - agent_id (optional): Agent ID filter, supports comma-separated multiple +// - keywords (optional): Keyword search +// - page (optional): Page number, default 1 +// - page_size (optional): Items per page, 
default 50 +// +// Response Format: +// - code: Status code +// - message: true +// - data.messages: Array of message objects +// - data.storage_type: Storage type +// +// TODO: Implementation pending - depends on CanvasService and TaskService +func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "GetMemoryMessages not implemented - pending CanvasService and TaskService dependencies", + "data": nil, + }) +} + +// AddMessage handles POST request for adding messages +// API Path: POST /api/v1/messages +// +// Function: +// - Adds messages to one or more memories +// - Messages will be embedded and saved to vector database +// - Creates asynchronous task for processing +// +// Request Parameters (JSON Body): +// - memory_id (required): Memory ID or ID array +// - agent_id (required): Agent ID +// - session_id (required): Session ID +// - user_input (required): User input +// - agent_response (required): Agent response +// - user_id (optional): User ID +// +// TODO: Implementation pending - depends on embedding engine +func (h *MemoryHandler) AddMessage(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "AddMessage not implemented - pending embedding engine dependency", + "data": nil, + }) +} + +// ForgetMessage handles DELETE request for forgetting messages +// API Path: DELETE /api/v1/messages/:memory_id/:message_id +// +// Function: +// - Soft-deletes the specified message (sets forget_at timestamp) +// - Message is not immediately deleted from database, but marked as "forgotten" +// +// Parameter Format: +// - memory_id: Memory ID +// - message_id: Message ID (integer) +// +// TODO: Implementation pending - depends on embedding engine +func (h *MemoryHandler) ForgetMessage(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "ForgetMessage not implemented - pending embedding engine dependency", 
+ "data": nil, + }) +} + +// UpdateMessage handles PUT request for updating message status +// API Path: PUT /api/v1/messages/:memory_id/:message_id +// +// Function: +// - Updates status of the specified message +// - status is a boolean, converted to integer for storage (true=1, false=0) +// +// Request Parameters (JSON Body): +// - status (required): Message status, boolean +// +// TODO: Implementation pending - depends on embedding engine +func (h *MemoryHandler) UpdateMessage(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "UpdateMessage not implemented - pending embedding engine dependency", + "data": nil, + }) +} + +// SearchMessage handles GET request for searching messages +// API Path: GET /api/v1/messages/search +// +// Function: +// - Searches messages across multiple memories +// - Supports vector similarity search and keyword search +// - Fuses results from both search methods +// +// Query Parameters: +// - memory_id (optional): Memory ID list, supports comma separation +// - query (optional): Search query text +// - similarity_threshold (optional): Similarity threshold, default 0.2 +// - keywords_similarity_weight (optional): Keyword weight, default 0.7 +// - top_n (optional): Number of results to return, default 5 +// - agent_id (optional): Agent ID filter +// - session_id (optional): Session ID filter +// - user_id (optional): User ID filter +// +// TODO: Implementation pending - depends on embedding engine +func (h *MemoryHandler) SearchMessage(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "SearchMessage not implemented - pending embedding engine dependency", + "data": nil, + }) +} + +// GetMessages handles GET request for getting message list +// API Path: GET /api/v1/messages +// +// Function: +// - Gets recent messages from specified memories +// - Supports filtering by agent_id and session_id +// +// Query Parameters: +// - memory_id (required): 
Memory ID list, supports comma separation +// - agent_id (optional): Agent ID filter +// - session_id (optional): Session ID filter +// - limit (optional): Number of results to return, default 10 +// +// TODO: Implementation pending - depends on embedding engine +func (h *MemoryHandler) GetMessages(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "GetMessages not implemented - pending embedding engine dependency", + "data": nil, + }) +} + +// GetMessageContent handles GET request for getting message content +// API Path: GET /api/v1/messages/:memory_id/:message_id/content +// +// Function: +// - Gets complete content of the specified message +// - doc_id format: memory_id + "_" + message_id +// +// Parameter Format: +// - memory_id: Memory ID +// - message_id: Message ID (integer) +// +// TODO: Implementation pending - depends on embedding engine +func (h *MemoryHandler) GetMessageContent(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "GetMessageContent not implemented - pending embedding engine dependency", + "data": nil, + }) +} + +// isArgumentError determines if an error message is an argument error +// +// Function: +// - Checks if the error message contains any argument validation-related prefixes +// - Used to distinguish argument errors from server errors +// +// Parameters: +// - msg: Error message string +// +// Returns: +// - bool: true if it's an argument error, false otherwise +func isArgumentError(msg string) bool { + // Define list of argument error prefixes + // Matches Python ArgumentException error messages + argumentErrorPrefixes := []string{ + "memory name cannot be empty", // Memory name cannot be empty + "memory name exceeds limit", // Memory name exceeds limit + "memory type must be a list", // memory_type must be a list + "memory type is not supported", // Unsupported memory_type + } + // Check if error message starts with any prefix + for _, 
// isArgumentError reports whether msg is one of the memory validation errors
// that should be answered as an argument error rather than a server error.
//
// Two of the service's messages embed the offending value between the subject
// and the verdict ("memory name 'x' exceeds limit of 128",
// "memory type 'x' is not supported"), so a plain prefix comparison against
// "memory name exceeds limit" / "memory type is not supported" never matches
// them. Those two cases are therefore recognized by subject prefix plus
// verdict substring; the value-free spellings are still accepted.
//
// Parameters:
//   - msg: Error message string
//
// Returns:
//   - bool: true if msg is an argument error, false otherwise
func isArgumentError(msg string) bool {
	switch {
	case strings.HasPrefix(msg, "memory name cannot be empty"),
		strings.HasPrefix(msg, "memory type must be a list"):
		return true
	case strings.HasPrefix(msg, "memory name"):
		return strings.Contains(msg, "exceeds limit")
	case strings.HasPrefix(msg, "memory type"):
		return strings.Contains(msg, "is not supported")
	}
	return false
}
r.memoryHandler.ListMemories) + memory.GET("/:memory_id/config", r.memoryHandler.GetMemoryConfig) + memory.GET("/:memory_id", r.memoryHandler.GetMemoryMessages) + } + + // TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine + // message := v1.Group("/messages") + // { + // message.POST("", r.memoryHandler.AddMessage) + // message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage) + // message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage) + // message.GET("/search", r.memoryHandler.SearchMessage) + // message.GET("", r.memoryHandler.GetMessages) + // message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent) + // } } // Knowledge base routes @@ -260,6 +285,7 @@ func (r *Router) Setup(engine *gin.Engine) { file.GET("/parent_folder", r.fileHandler.GetParentFolder) file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders) } + } // Handle undefined routes diff --git a/internal/service/memory.go b/internal/service/memory.go new file mode 100644 index 000000000..c89f26a37 --- /dev/null +++ b/internal/service/memory.go @@ -0,0 +1,892 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
const (
	// MemoryNameLimit is the maximum length allowed for memory names.
	// CreateMemory and UpdateMemory compare it against the trimmed name.
	MemoryNameLimit = 128
	// MemorySizeLimit is the maximum memory size in bytes (5 * 1024 * 1024).
	// It is both the default size of a new memory and the upper bound accepted
	// by UpdateMemory.
	MemorySizeLimit = 5242880
)

// Note: MemoryType, MemoryTypeRaw, MemoryTypeSemantic, MemoryTypeEpisodic,
// MemoryTypeProcedural, and CalculateMemoryType are defined in the dao package
// and imported as dao.MemoryType, dao.MemoryTypeRaw, etc.

// TenantPermission defines the access permission levels for memory resources.
// Note: This type is specific to the service layer.
type TenantPermission string

const (
	// TenantPermissionMe restricts access to the owner only.
	TenantPermissionMe TenantPermission = "me"
	// TenantPermissionTeam allows access within the same team.
	TenantPermissionTeam TenantPermission = "team"
	// TenantPermissionAll allows access to all tenants.
	TenantPermissionAll TenantPermission = "all"
)

// validPermissions is the set of permission values accepted by UpdateMemory.
// Keys are lower-case; callers lower-case the input before the lookup.
var validPermissions = map[TenantPermission]bool{
	TenantPermissionMe:   true,
	TenantPermissionTeam: true,
	TenantPermissionAll:  true,
}

// ForgettingPolicy defines the strategy for forgetting old memory entries.
type ForgettingPolicy string

const (
	// ForgettingPolicyFIFO uses First-In-First-Out strategy for forgetting.
	ForgettingPolicyFIFO ForgettingPolicy = "FIFO"
)

// validForgettingPolicies is the set of forgetting policies accepted by
// UpdateMemory. The only key is upper-case ("FIFO"), so a lookup must use the
// same casing.
var validForgettingPolicies = map[ForgettingPolicy]bool{
	ForgettingPolicyFIFO: true,
}

//
// Note: CalculateMemoryType and GetMemoryTypeHuman functions have been moved to dao package
// Use dao.CalculateMemoryType() and dao.GetMemoryTypeHuman() instead
struct{} + +// SYSTEM_BASE_TEMPLATE is the base template for the system prompt used in memory extraction +// It includes placeholders for type-specific instructions, timestamp format, and max items +var SYSTEM_BASE_TEMPLATE = `**Memory Extraction Specialist** +You are an expert at analyzing conversations to extract structured memory. + +{type_specific_instructions} + + +**OUTPUT REQUIREMENTS:** +1. Output MUST be valid JSON +2. Follow the specified output format exactly +3. Each extracted item MUST have: content, valid_at, invalid_at +4. Timestamps in {timestamp_format} format +5. Only extract memory types specified above +6. Maximum {max_items} items per type +` + +// TYPE_INSTRUCTIONS contains specific instructions for each memory type extraction +var TYPE_INSTRUCTIONS = map[string]string{ + "semantic": ` +**EXTRACT SEMANTIC KNOWLEDGE:** +- Universal facts, definitions, concepts, relationships +- Time-invariant, generally true information + +**Timestamp Rules:** +- valid_at: When the fact became true +- invalid_at: When it becomes false or empty if still true +`, + "episodic": ` +**EXTRACT EPISODIC KNOWLEDGE:** +- Specific experiences, events, personal stories +- Time-bound, person-specific, contextual + +**Timestamp Rules:** +- valid_at: Event start/occurrence time +- invalid_at: Event end time or empty if instantaneous +`, + "procedural": ` +**EXTRACT PROCEDURAL KNOWLEDGE:** +- Processes, methods, step-by-step instructions +- Goal-oriented, actionable, often includes conditions + +**Timestamp Rules:** +- valid_at: When procedure becomes valid/effective +- invalid_at: When it expires/becomes obsolete or empty if current +`, +} + +// OUTPUT_TEMPLATES defines the output format for each memory type +var OUTPUT_TEMPLATES = map[string]string{ + "semantic": `"semantic": [{"content": "Clear factual statement", "valid_at": "timestamp or empty", "invalid_at": "timestamp or empty"}]`, + "episodic": `"episodic": [{"content": "Narrative event description", "valid_at": 
"event start timestamp", "invalid_at": "event end timestamp or empty"}]`, + "procedural": `"procedural": [{"content": "Actionable instructions", "valid_at": "procedure effective timestamp", "invalid_at": "procedure expiration timestamp or empty"}]`, +} + +// AssembleSystemPrompt generates a complete system prompt for memory extraction +// +// Parameters: +// - memoryTypes: Array of memory type names to extract (e.g., ["semantic", "episodic"]) +// +// Returns: +// - string: Complete system prompt with type-specific instructions and output format +// +// Example: +// +// AssembleSystemPrompt([]string{"semantic", "episodic"}) returns a prompt with instructions +// for both semantic and episodic memory extraction +func (PromptAssembler) AssembleSystemPrompt(memoryTypes []string) string { + typesToExtract := getTypesToExtract(memoryTypes) + if len(typesToExtract) == 0 { + typesToExtract = []string{"raw"} + } + + typeInstructions := generateTypeInstructions(typesToExtract) + outputFormat := generateOutputFormat(typesToExtract) + + fullPrompt := strings.Replace(SYSTEM_BASE_TEMPLATE, "{type_specific_instructions}", typeInstructions, 1) + fullPrompt = strings.Replace(fullPrompt, "{timestamp_format}", "ISO 8601", 1) + fullPrompt = strings.Replace(fullPrompt, "{max_items}", "5", 1) + + fullPrompt += fmt.Sprintf("\n**REQUIRED OUTPUT FORMAT (JSON):\n```json\n{\n%s\n}\n```\n", outputFormat) + + return fullPrompt +} + +// getTypesToExtract filters out "raw" type and returns valid memory types +// +// Parameters: +// - requestedTypes: Array of requested memory type names +// +// Returns: +// - []string: Filtered array of memory type names (excluding "raw") +func getTypesToExtract(requestedTypes []string) []string { + types := make(map[string]bool) + for _, rt := range requestedTypes { + lowerRT := strings.ToLower(rt) + if lowerRT != "raw" { + if _, ok := dao.MemoryTypeMap[lowerRT]; ok { + types[lowerRT] = true + } + } + } + result := make([]string, 0, len(types)) + for t := range 
types { + result = append(result, t) + } + return result +} + +// generateTypeInstructions concatenates type-specific instructions +// +// Parameters: +// - typesToExtract: Array of memory type names +// +// Returns: +// - string: Concatenated instructions for all specified types +func generateTypeInstructions(typesToExtract []string) string { + var instructions []string + for _, mt := range typesToExtract { + if instr, ok := TYPE_INSTRUCTIONS[mt]; ok { + instructions = append(instructions, instr) + } + } + return strings.Join(instructions, "\n") +} + +// generateOutputFormat concatenates output format templates +// +// Parameters: +// - typesToExtract: Array of memory type names +// +// Returns: +// - string: Concatenated output format templates +func generateOutputFormat(typesToExtract []string) string { + var outputParts []string + for _, mt := range typesToExtract { + if tmpl, ok := OUTPUT_TEMPLATES[mt]; ok { + outputParts = append(outputParts, tmpl) + } + } + return strings.Join(outputParts, ",\n") +} + +// MemoryService handles business logic for memory operations +// It provides methods for creating, updating, deleting, and querying memories +type MemoryService struct { + memoryDAO *dao.MemoryDAO +} + +// NewMemoryService creates a new MemoryService instance +// +// Returns: +// - *MemoryService: Initialized service instance with DAO +func NewMemoryService() *MemoryService { + return &MemoryService{ + memoryDAO: dao.NewMemoryDAO(), + } +} + +// splitNameCounter splits a filename into base name and counter +// Handles names in format "filename(123)" pattern +// +// Parameters: +// - filename: The filename to split +// +// Returns: +// - string: The base name without counter +// - *int: The counter value, or nil if no counter exists +// +// Example: +// +// splitNameCounter("test(5)") returns ("test", 5) +// splitNameCounter("test") returns ("test", nil) +func splitNameCounter(filename string) (string, *int) { + re := regexp.MustCompile(`^(.+)\((\d+)\)$`) + 
// nameCounterPattern matches names of the form "stem(123)". It is compiled
// once at package scope instead of on every splitNameCounter call.
var nameCounterPattern = regexp.MustCompile(`^(.+)\((\d+)\)$`)

// splitNameCounter splits a name of the form "stem(123)" into its stem and
// counter. Trailing spaces are trimmed from the stem.
//
// Parameters:
//   - filename: The name to split
//
// Returns:
//   - string: The stem without counter, or filename unchanged when there is no
//     usable counter
//   - *int: The counter value, or nil if no counter exists (including a digit
//     run too large for int)
//
// Example:
//
//	splitNameCounter("test(5)") returns ("test", 5)
//	splitNameCounter("test") returns ("test", nil)
func splitNameCounter(filename string) (string, *int) {
	matches := nameCounterPattern.FindStringSubmatch(filename)
	if matches == nil {
		return filename, nil
	}
	counter, err := strconv.Atoi(matches[2])
	if err != nil {
		// Only reachable on overflow; treat the digits as part of the name.
		return filename, nil
	}
	return strings.TrimRight(matches[1], " "), &counter
}

// duplicateName returns name if queryFunc reports it as free, otherwise the
// first free variant obtained by inserting or incrementing a "(n)" counter
// before the extension.
//
// The extension is removed before the counter is looked for, so a name that
// already carries a counter is incremented rather than nested:
// "file.txt" -> "file(1).txt" -> "file(2).txt".
//
// It panics after 1000 unsuccessful attempts; the signature has no error
// return, and running out of candidates means queryFunc never reports a free
// name.
//
// Parameters:
//   - queryFunc: Function to check if a name already exists (returns true if exists)
//   - name: The original name
//   - tenantID: The tenant ID passed through to queryFunc
//
// Returns:
//   - string: A unique name (either original or with counter appended)
//
// Example:
//
//	// nothing taken
//	duplicateName(exists, "test", "tenant1") returns "test"
//	// "test" taken
//	duplicateName(exists, "test", "tenant1") returns "test(1)"
func duplicateName(queryFunc func(name string, tenantID string) bool, name string, tenantID string) string {
	const maxRetries = 1000

	originalName := name
	currentName := name

	for retries := 0; retries < maxRetries; retries++ {
		if !queryFunc(currentName, tenantID) {
			return currentName
		}

		ext := path.Ext(currentName)
		stem, counter := splitNameCounter(strings.TrimSuffix(currentName, ext))

		newCounter := 1
		if counter != nil {
			newCounter = *counter + 1
		}
		currentName = fmt.Sprintf("%s(%d)%s", stem, newCounter, ext)
	}

	panic(fmt.Sprintf("Failed to generate unique name within %d attempts. Original: %s", maxRetries, originalName))
}

// CreateMemoryRequest defines the request structure for creating a memory.
type CreateMemoryRequest struct {
	// Name is the memory name (required, max 128 characters)
	Name string `json:"name" binding:"required"`
	// MemoryType is the array of memory type names (required)
	MemoryType []string `json:"memory_type" binding:"required"`
	// EmbdID is the embedding model ID (required)
	EmbdID string `json:"embd_id" binding:"required"`
	// LLMID is the language model ID (required)
	LLMID string `json:"llm_id" binding:"required"`
	// TenantEmbdID is the tenant-specific embedding model ID (optional)
	TenantEmbdID *string `json:"tenant_embd_id"`
	// TenantLLMID is the tenant-specific language model ID (optional)
	TenantLLMID *string `json:"tenant_llm_id"`
}
// UpdateMemoryRequest defines the request structure for updating a memory.
// All fields are optional; a nil pointer (or empty MemoryType) means "leave
// this field unchanged".
type UpdateMemoryRequest struct {
	// Name is the new memory name (optional)
	Name *string `json:"name"`
	// Permissions is the new permission level (optional); UpdateMemory accepts
	// "me", "team" or "all", case-insensitively
	Permissions *string `json:"permissions"`
	// LLMID is the new language model ID (optional)
	LLMID *string `json:"llm_id"`
	// EmbdID is the new embedding model ID (optional)
	EmbdID *string `json:"embd_id"`
	// TenantLLMID is the new tenant-specific language model ID (optional),
	// sent as a decimal string
	TenantLLMID *string `json:"tenant_llm_id"`
	// TenantEmbdID is the new tenant-specific embedding model ID (optional),
	// sent as a decimal string
	TenantEmbdID *string `json:"tenant_embd_id"`
	// MemoryType is the new array of memory type names (optional)
	MemoryType []string `json:"memory_type"`
	// MemorySize is the new memory size in bytes (optional, max 5MB)
	MemorySize *int64 `json:"memory_size"`
	// ForgettingPolicy is the new forgetting policy (optional)
	ForgettingPolicy *string `json:"forgetting_policy"`
	// Temperature is the new temperature value (optional, range [0, 1])
	Temperature *float64 `json:"temperature"`
	// Avatar is the new avatar URL (optional)
	Avatar *string `json:"avatar"`
	// Description is the new description (optional)
	Description *string `json:"description"`
	// SystemPrompt is the new system prompt (optional)
	SystemPrompt *string `json:"system_prompt"`
	// UserPrompt is the new user prompt (optional)
	UserPrompt *string `json:"user_prompt"`
}

// CreateMemoryResponse defines the response structure for memory operations.
// It embeds model.Memory and adds API-specific fields.
type CreateMemoryResponse struct {
	model.Memory
	// OwnerName is the owner's display name; omitted from JSON when nil
	OwnerName *string `json:"owner_name,omitempty"`
	// MemoryType is the list of memory type names
	// NOTE(review): presumably shadows the embedded Memory.MemoryType in the
	// JSON output under the same "memory_type" key - confirm against model.Memory
	MemoryType []string `json:"memory_type"`
}

// ListMemoryResponse defines the response structure for listing memories.
type ListMemoryResponse struct {
	// MemoryList is the array of memory objects
	MemoryList []map[string]interface{} `json:"memory_list"`
	// TotalCount is the total number of memories
	TotalCount int64 `json:"total_count"`
}
+// +// Returns: +// - *CreateMemoryResponse: The created memory details +// - error: Error if validation fails or creation fails +// +// Example: +// +// req := &CreateMemoryRequest{Name: "MyMemory", MemoryType: []string{"semantic"}, EmbdID: "embd1", LLMID: "llm1"} +// resp, err := service.CreateMemory("tenant123", req) +func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest) (*CreateMemoryResponse, error) { + // Ensure tenant model IDs are populated for LLM and embedding model parameters + // This automatically fills tenant_llm_id and tenant_embd_id based on llm_id and embd_id + tenantLLMService := NewTenantLLMService() + params := map[string]interface{}{ + "llm_id": req.LLMID, + "embd_id": req.EmbdID, + } + params = tenantLLMService.EnsureTenantModelIDForParams(tenantID, params) + + // Update request with tenant model IDs from the processed params + if tenantLLMID, ok := params["tenant_llm_id"].(int64); ok { + tenantLLMIDStr := strconv.FormatInt(tenantLLMID, 10) + req.TenantLLMID = &tenantLLMIDStr + } + if tenantEmbdID, ok := params["tenant_embd_id"].(int64); ok { + tenantEmbdIDStr := strconv.FormatInt(tenantEmbdID, 10) + req.TenantEmbdID = &tenantEmbdIDStr + } + + memoryName := strings.TrimSpace(req.Name) + if len(memoryName) == 0 { + return nil, errors.New("memory name cannot be empty or whitespace") + } + if len(memoryName) > MemoryNameLimit { + return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit) + } + + if !isList(req.MemoryType) { + return nil, errors.New("memory type must be a list") + } + + memoryTypeSet := make(map[string]bool) + for _, mt := range req.MemoryType { + lowerMT := strings.ToLower(mt) + if _, ok := dao.MemoryTypeMap[lowerMT]; !ok { + return nil, fmt.Errorf("memory type '%s' is not supported", mt) + } + memoryTypeSet[lowerMT] = true + } + uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet)) + for mt := range memoryTypeSet { + uniqueMemoryTypes = append(uniqueMemoryTypes, 
mt) + } + + memoryName = duplicateName(func(name string, tid string) bool { + existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid) + return len(existing) > 0 + }, memoryName, tenantID) + + if len(memoryName) > MemoryNameLimit { + return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit) + } + + memoryTypeInt := dao.CalculateMemoryType(uniqueMemoryTypes) + timestamp := time.Now().UnixMilli() + + systemPrompt := PromptAssembler{}.AssembleSystemPrompt(uniqueMemoryTypes) + + newID := strings.ReplaceAll(uuid.New().String(), "-", "") + if len(newID) > 32 { + newID = newID[:32] + } + + memory := &model.Memory{ + ID: newID, + Name: memoryName, + TenantID: tenantID, + MemoryType: memoryTypeInt, + StorageType: "table", + EmbdID: req.EmbdID, + LLMID: req.LLMID, + Permissions: "me", + MemorySize: MemorySizeLimit, + ForgettingPolicy: string(ForgettingPolicyFIFO), + Temperature: 0.5, + SystemPrompt: &systemPrompt, + } + + // Convert tenant model IDs from string to int64 for database + if req.TenantEmbdID != nil { + if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil { + memory.TenantEmbdID = &embdID + } + } + if req.TenantLLMID != nil { + if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil { + memory.TenantLLMID = &llmID + } + } + memory.CreateTime = ×tamp + memory.UpdateTime = ×tamp + + if err := s.memoryDAO.Create(memory); err != nil { + return nil, errors.New("could not create new memory") + } + + createdMemory, err := s.memoryDAO.GetByID(newID) + if err != nil { + return nil, errors.New("could not create new memory") + } + + return formatRetDataFromMemory(createdMemory), nil +} + +// UpdateMemory updates an existing memory with the provided fields +// Only the fields specified in the request will be updated (partial update) +// +// Parameters: +// - tenantID: The tenant ID for ownership verification +// - memoryID: The ID of the memory to update +// - req: The update request with optional fields to 
update +// +// Returns: +// - *CreateMemoryResponse: The updated memory details +// - error: Error if validation fails or update fails +// +// Example: +// +// req := &UpdateMemoryRequest{Name: ptr("NewName"), MemorySize: ptr(int64(1000000))} +// resp, err := service.UpdateMemory("tenant123", "memory456", req) +func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *UpdateMemoryRequest) (*CreateMemoryResponse, error) { + updateDict := make(map[string]interface{}) + + if req.Name != nil { + memoryName := strings.TrimSpace(*req.Name) + if len(memoryName) == 0 { + return nil, errors.New("memory name cannot be empty or whitespace") + } + if len(memoryName) > MemoryNameLimit { + return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit) + } + memoryName = duplicateName(func(name string, tid string) bool { + existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid) + return len(existing) > 0 + }, memoryName, tenantID) + if len(memoryName) > MemoryNameLimit { + return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit) + } + updateDict["name"] = memoryName + } + + if req.Permissions != nil { + perm := TenantPermission(strings.ToLower(*req.Permissions)) + if !validPermissions[perm] { + return nil, fmt.Errorf("unknown permission '%s'", *req.Permissions) + } + updateDict["permissions"] = perm + } + + if req.LLMID != nil { + updateDict["llm_id"] = *req.LLMID + } + + if req.EmbdID != nil { + updateDict["embd_id"] = *req.EmbdID + } + + if req.TenantLLMID != nil { + if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil { + updateDict["tenant_llm_id"] = llmID + } + } + + if req.TenantEmbdID != nil { + if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil { + updateDict["tenant_embd_id"] = embdID + } + } + + if req.MemoryType != nil && len(req.MemoryType) > 0 { + memoryTypeSet := make(map[string]bool) + for _, mt := range req.MemoryType { + lowerMT := 
strings.ToLower(mt) + if _, ok := dao.MemoryTypeMap[lowerMT]; !ok { + return nil, fmt.Errorf("memory type '%s' is not supported", mt) + } + memoryTypeSet[lowerMT] = true + } + uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet)) + for mt := range memoryTypeSet { + uniqueMemoryTypes = append(uniqueMemoryTypes, mt) + } + updateDict["memory_type"] = uniqueMemoryTypes + } + + if req.MemorySize != nil { + memorySize := *req.MemorySize + if !(memorySize > 0 && memorySize <= MemorySizeLimit) { + return nil, fmt.Errorf("memory size should be in range (0, %d] Bytes", MemorySizeLimit) + } + updateDict["memory_size"] = memorySize + } + + if req.ForgettingPolicy != nil { + fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy)) + if !validForgettingPolicies[fp] { + return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy) + } + updateDict["forgetting_policy"] = fp + } + + if req.Temperature != nil { + temp := *req.Temperature + if !(temp >= 0 && temp <= 1) { + return nil, errors.New("temperature should be in range [0, 1]") + } + updateDict["temperature"] = temp + } + + for _, field := range []string{"avatar", "description", "system_prompt", "user_prompt"} { + switch field { + case "avatar": + if req.Avatar != nil { + updateDict["avatar"] = *req.Avatar + } + case "description": + if req.Description != nil { + updateDict["description"] = *req.Description + } + case "system_prompt": + if req.SystemPrompt != nil { + updateDict["system_prompt"] = *req.SystemPrompt + } + case "user_prompt": + if req.UserPrompt != nil { + updateDict["user_prompt"] = *req.UserPrompt + } + } + } + + currentMemory, err := s.memoryDAO.GetByID(memoryID) + if err != nil { + return nil, fmt.Errorf("memory '%s' not found", memoryID) + } + + if len(updateDict) == 0 { + return formatRetDataFromMemory(currentMemory), nil + } + + memorySize := currentMemory.MemorySize + notAllowedUpdate := []string{} + for _, f := range []string{"tenant_embd_id", "embd_id", 
"memory_type"} { + if _, ok := updateDict[f]; ok && memorySize > 0 { + notAllowedUpdate = append(notAllowedUpdate, f) + } + } + if len(notAllowedUpdate) > 0 { + return nil, fmt.Errorf("can't update %v when memory isn't empty", notAllowedUpdate) + } + + if _, ok := updateDict["memory_type"]; ok { + if _, ok := updateDict["system_prompt"]; !ok { + memoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType) + if len(memoryTypes) > 0 && currentMemory.SystemPrompt != nil { + defaultPrompt := PromptAssembler{}.AssembleSystemPrompt(memoryTypes) + if *currentMemory.SystemPrompt == defaultPrompt { + if types, ok := updateDict["memory_type"].([]string); ok { + updateDict["system_prompt"] = PromptAssembler{}.AssembleSystemPrompt(types) + } + } + } + } + } + + if err := s.memoryDAO.UpdateByID(memoryID, updateDict); err != nil { + return nil, errors.New("failed to update memory") + } + + updatedMemory, err := s.memoryDAO.GetByID(memoryID) + if err != nil { + return nil, errors.New("failed to get updated memory") + } + + return formatRetDataFromMemory(updatedMemory), nil +} + +// DeleteMemory deletes a memory by ID +// It also deletes associated message indexes before removing the memory record +// +// Parameters: +// - memoryID: The ID of the memory to delete +// +// Returns: +// - error: Error if memory not found or deletion fails +// +// Example: +// +// err := service.DeleteMemory("memory456") +func (s *MemoryService) DeleteMemory(memoryID string) error { + _, err := s.memoryDAO.GetByID(memoryID) + if err != nil { + return fmt.Errorf("memory '%s' not found", memoryID) + } + + // TODO: Delete associated message index - Implementation pending MessageService + // messageService := NewMessageService() + // hasIndex, _ := messageService.HasIndex(memory.TenantID, memoryID) + // if hasIndex { + // messageService.DeleteMessage(nil, memory.TenantID, memoryID) + // } + + // Delete memory record + if err := s.memoryDAO.DeleteByID(memoryID); err != nil { + return errors.New("failed 
to delete memory") + } + + return nil +} + +// ListMemories retrieves a paginated list of memories with optional filters +// When tenantIDs is empty, it retrieves all tenants associated with the user +// +// Parameters: +// - userID: The user ID for tenant filtering when tenantIDs is empty +// - tenantIDs: Array of tenant IDs to filter by (empty means all user's tenants) +// - memoryTypes: Array of memory type names to filter by (empty means all types) +// - storageType: Storage type to filter by (empty means all types) +// - keywords: Keywords to search in memory names (empty means no keyword filter) +// - page: Page number (1-based) +// - pageSize: Number of items per page +// +// Returns: +// - *ListMemoryResponse: Contains memory list and total count +// - error: Error if query fails +// +// Example: +// +// resp, err := service.ListMemories("user123", []string{}, []string{"semantic"}, "table", "test", 1, 10) +func (s *MemoryService) ListMemories(userID string, tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) (*ListMemoryResponse, error) { + // If tenantIDs is empty, get all tenants associated with the user + if len(tenantIDs) == 0 { + userTenantService := NewUserTenantService() + userTenants, err := userTenantService.GetUserTenantRelationByUserID(userID) + if err != nil { + return nil, fmt.Errorf("failed to get user tenants: %w", err) + } + tenantIDs = make([]string, len(userTenants)) + for i, tenant := range userTenants { + tenantIDs[i] = tenant.TenantID + } + } + + memories, total, err := s.memoryDAO.GetByFilter(tenantIDs, memoryTypes, storageType, keywords, page, pageSize) + if err != nil { + return nil, err + } + + memoryList := make([]map[string]interface{}, 0, len(memories)) + for _, m := range memories { + resp := formatRetDataFromMemoryListItem(m) + var createDateStr *string + if resp.CreateTime != nil { + createDateStr = formatDateToString(*resp.CreateTime) + } + memoryMap := 
map[string]interface{}{ + "id": resp.ID, + "name": resp.Name, + "avatar": resp.Avatar, + "tenant_id": resp.TenantID, + "owner_name": resp.OwnerName, + "memory_type": resp.MemoryType, + "storage_type": resp.StorageType, + "permissions": resp.Permissions, + "description": resp.Description, + "create_time": resp.CreateTime, + "create_date": createDateStr, + } + memoryList = append(memoryList, memoryMap) + } + + return &ListMemoryResponse{ + MemoryList: memoryList, + TotalCount: total, + }, nil +} + +// GetMemoryConfig retrieves the full configuration of a memory by ID +// +// Parameters: +// - memoryID: The ID of the memory to retrieve +// +// Returns: +// - *CreateMemoryResponse: The memory configuration details +// - error: Error if memory not found +// +// Example: +// +// resp, err := service.GetMemoryConfig("memory456") +func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse, error) { + memory, err := s.memoryDAO.GetWithOwnerNameByID(memoryID) + if err != nil { + return nil, fmt.Errorf("memory '%s' not found", memoryID) + } + return formatRetDataFromMemoryListItem(memory), nil +} + +// TODO: GetMemoryMessages - Implementation pending - depends on CanvasService and TaskService +// func (s *MemoryService) GetMemoryMessages(memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { ... } + +// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService +// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... } + +// TODO: AddMessage - Implementation pending - depends on embedding engine +// func (s *MemoryService) AddMessage(memoryIDs []string, messageDict map[string]interface{}) (bool, string, error) { ... 
} + +// TODO: ForgetMessage - Implementation pending - depends on embedding engine +// func (s *MemoryService) ForgetMessage(memoryID string, messageID int) (bool, error) { ... } + +// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine +// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... } + +// TODO: SearchMessage - Implementation pending - depends on embedding engine +// func (s *MemoryService) SearchMessage(filterDict map[string]interface{}, params map[string]interface{}) ([]map[string]interface{}, error) { ... } + +// TODO: GetMessages - Implementation pending - depends on embedding engine +// func (s *MemoryService) GetMessages(memoryIDs []string, agentID string, sessionID string, limit int) ([]map[string]interface{}, error) { ... } + +// TODO: GetMessageContent - Implementation pending - depends on embedding engine +// func (s *MemoryService) GetMessageContent(memoryID string, messageID int) (map[string]interface{}, error) { ... 
// isList reports whether v holds one of the two slice types this package
// treats as a list.
//
// Parameters:
//   - v: The value to check
//
// Returns:
//   - bool: true if the dynamic type of v is []interface{} or []string
//     (including a nil slice of either type), false otherwise
//
// Example:
//
//	isList([]string{"a", "b"}) returns true
//	isList("test") returns false
func isList(v interface{}) bool {
	if _, generic := v.([]interface{}); generic {
		return true
	}
	_, stringsOnly := v.([]string)
	return stringsOnly
}

// formatDateToString renders a Unix timestamp as "YYYY-MM-DD HH:MM:SS".
//
// Parameters:
//   - t: Unix timestamp; values greater than 1e10 are treated as milliseconds
//     and truncated to whole seconds, anything else is taken as seconds
//
// Returns:
//   - *string: The formatted time in the process's local time zone, or nil
//     when t is 0
//
// Example:
//
//	formatDateToString(0) returns nil
//	formatDateToString(1700000000123) formats the same instant as formatDateToString(1700000000)
func formatDateToString(t int64) *string {
	if t == 0 {
		return nil
	}

	seconds := t
	// Heuristic unit detection: timestamps above 1e10 are assumed to be in
	// milliseconds and are converted to seconds.
	if seconds > 1e10 {
		seconds /= 1000
	}

	formatted := time.Unix(seconds, 0).Format("2006-01-02 15:04:05")
	return &formatted
}
:= &CreateMemoryResponse{ + Memory: memory.Memory, + OwnerName: memory.OwnerName, + MemoryType: memoryTypes, + } + return resp +} diff --git a/internal/service/tenant.go b/internal/service/tenant.go index e4c385856..83e8e016d 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -17,12 +17,14 @@ package service import ( + "strings" "context" "fmt" "time" "ragflow/internal/common" "ragflow/internal/dao" + "ragflow/internal/model" "ragflow/internal/engine" ) @@ -92,6 +94,136 @@ type TenantListItem struct { DeltaSeconds float64 `json:"delta_seconds"` } +// TenantLLMService tenant LLM service +// This service handles operations related to tenant-specific LLM configurations +type TenantLLMService struct { + tenantLLMDAO *dao.TenantLLMDAO +} + +// NewTenantLLMService creates a new TenantLLMService instance +func NewTenantLLMService() *TenantLLMService { + return &TenantLLMService{ + tenantLLMDAO: dao.NewTenantLLMDAO(), + } +} + +// GetAPIKey retrieves the tenant LLM record by tenant ID and model name +/** + * This method splits the model name into name and factory parts using the "@" separator, + * then queries the database for the matching tenant LLM configuration. 
+ * + * Parameters: + * - tenantID: the unique identifier of the tenant + * - modelName: the model name, optionally including factory suffix (e.g., "gpt-4@OpenAI") + * + * Returns: + * - *model.TenantLLM: the tenant LLM record if found, nil otherwise + * - error: an error if the query fails, nil otherwise + * + * Example: + * + * service := NewTenantLLMService() + * + * // Get API key for model with factory + * tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4@OpenAI") + * if err != nil { + * log.Printf("Error: %v", err) + * } + * + * // Get API key for model without factory + * tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4") + */ +func (s *TenantLLMService) GetAPIKey(tenantID, modelName string) (*model.TenantLLM, error) { + modelName, factory := s.SplitModelNameAndFactory(modelName) + + var tenantLLM *model.TenantLLM + var err error + + if factory == "" { + tenantLLM, err = s.tenantLLMDAO.GetByTenantIDAndLLMName(tenantID, modelName) + } else { + tenantLLM, err = s.tenantLLMDAO.GetByTenantIDLLMNameAndFactory(tenantID, modelName, factory) + } + + if err != nil { + return nil, err + } + + return tenantLLM, nil +} + +// SplitModelNameAndFactory splits a model name into name and factory parts +func (s *TenantLLMService) SplitModelNameAndFactory(modelName string) (string, string) { + arr := strings.Split(modelName, "@") + if len(arr) < 2 { + return modelName, "" + } + if len(arr) > 2 { + return strings.Join(arr[0:len(arr)-1], "@"), arr[len(arr)-1] + } + return arr[0], arr[1] +} + +// EnsureTenantModelIDForParams ensures tenant model IDs are populated for LLM-related parameters +/** + * This method iterates through a predefined list of LLM-related parameter keys (llm_id, embd_id, + * asr_id, img2txt_id, rerank_id, tts_id) and automatically populates the corresponding tenant_* + * fields (tenant_llm_id, tenant_embd_id, etc.) with the tenant LLM record IDs. 
+ * + * If a parameter key exists and its corresponding tenant_* key doesn't exist, this method will: + * 1. Query the tenant LLM record using GetAPIKey + * 2. If found, set the tenant_* key to the record's ID + * 3. If not found, set the tenant_* key to 0 + * + * Parameters: + * - tenantID: the unique identifier of the tenant + * - params: a map of parameters to be updated (will be modified in place) + * + * Returns: + * - map[string]interface{}: the updated parameters map (same as input, modified in place) + * + * Example: + * + * service := NewTenantLLMService() + * params := map[string]interface{}{ + * "llm_id": "gpt-4@OpenAI", + * "embd_id": "text-embedding-3-small@OpenAI", + * } + * result := service.EnsureTenantModelIDForParams("tenant-123", params) + * // result will contain: + * // { + * // "llm_id": "gpt-4@OpenAI", + * // "embd_id": "text-embedding-3-small@OpenAI", + * // "tenant_llm_id": 123, // ID from tenant_llm table + * // "tenant_embd_id": 456, // ID from tenant_llm table + * // } + */ +func (s *TenantLLMService) EnsureTenantModelIDForParams(tenantID string, params map[string]interface{}) map[string]interface{} { + paramKeys := []string{"llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"} + + for _, key := range paramKeys { + tenantKey := "tenant_" + key + + if value, exists := params[key]; exists && value != nil && value != "" { + if _, tenantExists := params[tenantKey]; !tenantExists { + modelName, ok := value.(string) + if !ok || modelName == "" { + continue + } + + tenantLLM, err := s.GetAPIKey(tenantID, modelName) + if err == nil && tenantLLM != nil { + params[tenantKey] = tenantLLM.ID + } else { + params[tenantKey] = int64(0) + } + } + } + } + + return params +} + // GetTenantList get tenant list for a user func (s *TenantService) GetTenantList(userID string) ([]*TenantListItem, error) { tenants, err := s.userTenantDAO.GetTenantsByUserID(userID) diff --git a/internal/service/user.go b/internal/service/user.go index 
96544bee5..838c473f0 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -924,6 +924,95 @@ func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) er return nil } +// UserTenantService user tenant service +// Provides business logic for user-tenant relationship management +type UserTenantService struct { + userTenantDAO *dao.UserTenantDAO +} + +// NewUserTenantService creates a new UserTenantService instance +/** + * Returns: + * - *UserTenantService: a new UserTenantService instance + * + * Example: + * + * service := NewUserTenantService() + * relations, err := service.GetUserTenantRelationByUserID("user123") + */ +func NewUserTenantService() *UserTenantService { + return &UserTenantService{ + userTenantDAO: dao.NewUserTenantDAO(), + } +} + +// UserTenantRelation represents a user-tenant relationship response +// This structure matches the Python implementation's return format +type UserTenantRelation struct { + ID string `json:"id"` + UserID string `json:"user_id"` + TenantID string `json:"tenant_id"` + Role string `json:"role"` +} + +// GetUserTenantRelationByUserID retrieves all user-tenant relationships for a given user ID +/** + * This method returns a list of user-tenant relationships with selected fields: + * - id: the relationship ID + * - user_id: the user ID + * - tenant_id: the tenant ID + * - role: the user's role in the tenant + * + * Parameters: + * - userID: the unique identifier of the user + * + * Returns: + * - []*UserTenantRelation: list of user-tenant relationships + * - error: error if the operation fails, nil otherwise + * + * Example: + * + * service := NewUserTenantService() + * relations, err := service.GetUserTenantRelationByUserID("user123") + * if err != nil { + * log.Printf("Failed to get user tenant relations: %v", err) + * return + * } + * for _, rel := range relations { + * fmt.Printf("User %s has role %s in tenant %s\n", rel.UserID, rel.Role, rel.TenantID) + * } + */ +func (s 
*UserTenantService) GetUserTenantRelationByUserID(userID string) ([]*UserTenantRelation, error) { + relations, err := s.userTenantDAO.GetByUserID(userID) + if err != nil { + return nil, err + } + + result := make([]*UserTenantRelation, len(relations)) + for i, rel := range relations { + result[i] = convertToUserTenantRelation(rel) + } + + return result, nil +} + +// convertToUserTenantRelation converts model.UserTenant to UserTenantRelation +/** + * Parameters: + * - userTenant: the model.UserTenant to convert + * + * Returns: + * - *UserTenantRelation: the converted UserTenantRelation + */ +func convertToUserTenantRelation(userTenant *model.UserTenant) *UserTenantRelation { + return &UserTenantRelation{ + ID: userTenant.ID, + UserID: userTenant.UserID, + TenantID: userTenant.TenantID, + Role: userTenant.Role, + } +} + // GetUserByAPIToken gets user by access key from Authorization header // This is used for API token authentication // The authorization parameter should be in format: "Bearer " or just "" @@ -963,4 +1052,5 @@ func (s *UserService) GetUserByAPIToken(authorization string) (*model.User, comm } return user, common.CodeSuccess, nil + } diff --git a/web/.env.development b/web/.env.development index f33f3bef5..bc3a84770 100644 --- a/web/.env.development +++ b/web/.env.development @@ -1 +1,2 @@ -VITE_BASE_URL='/' \ No newline at end of file +VITE_BASE_URL='/' +API_PROXY_SCHEME='python' \ No newline at end of file diff --git a/web/vite.config.ts b/web/vite.config.ts index 741806ef9..ab87c62a3 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -49,66 +49,18 @@ export default defineConfig(({ mode }) => { }, }, hybrid: { - '/v1/system/token_list': { - target: 'http://127.0.0.1:9384/', - changeOrigin: true, - ws: true, - }, - '/v1/system/new_token': { - target: 'http://127.0.0.1:9384/', - changeOrigin: true, - ws: true, - }, - '/v1/system/token': { - target: 'http://127.0.0.1:9384/', - changeOrigin: true, - ws: true, - }, - '/v1/system/config': { - 
target: 'http://127.0.0.1:9384/', - changeOrigin: true, - ws: true, - }, - '/v1/user/login': { - target: 'http://127.0.0.1:9384/', - changeOrigin: true, - ws: true, - }, - '/v1/user/logout': { - target: 'http://127.0.0.1:9384/', - changeOrigin: true, - ws: true, - }, - '/api/v1/admin/sandbox': { - target: 'http://127.0.0.1:9381/', - changeOrigin: true, - ws: true, - }, - '/api/v1/admin/roles': { - target: 'http://127.0.0.1:9381/', - changeOrigin: true, - ws: true, - }, - '/api/v1/admin/roles/owner/permission': { - target: 'http://127.0.0.1:9381/', - changeOrigin: true, - ws: true, - }, - '/api/v1/admin/roles_with_permission': { - target: 'http://127.0.0.1:9381/', - changeOrigin: true, - ws: true, - }, - '/api/v1/admin/whitelist': { - target: 'http://127.0.0.1:9381/', - changeOrigin: true, - ws: true, - }, - '/api/v1/admin/variables': { - target: 'http://127.0.0.1:9381/', - changeOrigin: true, - ws: true, - }, + '^(/api/v1/memories)|^(/v1/user/info)|^(/v1/user/tenant_info)|^(/v1/tenant/list)|^(/v1/system/config)|^(/v1/user/login)|^(/v1/user/logout)': + { + target: 'http://127.0.0.1:9384/', + changeOrigin: true, + ws: true, + }, + '^(/api/v1/admin/sandbox)|^(/api/v1/admin/roles)|^(/api/v1/admin/roles/owner/permission)|^(/api/v1/admin/roles_with_permission)|^(/api/v1/admin/whitelist)|^(/api/v1/admin/variables)': + { + target: 'http://127.0.0.1:9381/', + changeOrigin: true, + ws: true, + }, '/api/v1/admin': { target: 'http://127.0.0.1:9383/', changeOrigin: true,