Feat: add memory function by go (#13754)

### What problem does this PR solve?

Feat: Add Memory function by go

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
This commit is contained in:
chanx
2026-03-27 09:49:50 +08:00
committed by GitHub
parent 406339af1f
commit 8a9bbf3d6d
11 changed files with 2350 additions and 62 deletions

View File

@ -175,6 +175,7 @@ func startServer(config *server.Config) {
connectorService := service.NewConnectorService()
searchService := service.NewSearchService()
fileService := service.NewFileService()
memoryService := service.NewMemoryService()
// Initialize handler layer
authHandler := handler.NewAuthHandler()
@ -191,9 +192,10 @@ func startServer(config *server.Config) {
connectorHandler := handler.NewConnectorHandler(connectorService, userService)
searchHandler := handler.NewSearchHandler(searchService, userService)
fileHandler := handler.NewFileHandler(fileService, userService)
memoryHandler := handler.NewMemoryHandler(memoryService)
// Initialize router
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler)
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler)
// Create Gin engine
ginEngine := gin.New()

370
internal/dao/memory.go Normal file
View File

@ -0,0 +1,370 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package dao implements the data access layer
// This file implements Memory-related database operations
// Consistent with Python memory_service.py
package dao
import (
"fmt"
"strings"
"ragflow/internal/model"
)
// Memory type bit flag constants, consistent with Python MemoryType enum
const (
MemoryTypeRaw = 0b0001 // Raw memory (binary: 0001)
MemoryTypeSemantic = 0b0010 // Semantic memory (binary: 0010)
MemoryTypeEpisodic = 0b0100 // Episodic memory (binary: 0100)
MemoryTypeProcedural = 0b1000 // Procedural memory (binary: 1000)
)
// MemoryTypeMap maps memory type names to bit flags
// Exported for use by service package
var MemoryTypeMap = map[string]int{
"raw": MemoryTypeRaw,
"semantic": MemoryTypeSemantic,
"episodic": MemoryTypeEpisodic,
"procedural": MemoryTypeProcedural,
}
// CalculateMemoryType converts memory type names array to bit flags integer
//
// Parameters:
// - memoryTypeNames: Memory type names array
//
// Returns:
// - int64: Bit flags integer
//
// Example:
//
// CalculateMemoryType([]string{"raw", "semantic"}) returns 3 (0b0011)
func CalculateMemoryType(memoryTypeNames []string) int64 {
memoryType := 0
for _, name := range memoryTypeNames {
lowerName := strings.ToLower(name)
if mt, ok := MemoryTypeMap[lowerName]; ok {
memoryType |= mt
}
}
return int64(memoryType)
}
// GetMemoryTypeHuman converts memory type bit flags to human-readable names
//
// Parameters:
// - memoryType: Bit flags integer representing memory types
//
// Returns:
// - []string: Array of human-readable memory type names
//
// Example:
//
// GetMemoryTypeHuman(3) returns ["raw", "semantic"]
func GetMemoryTypeHuman(memoryType int64) []string {
var result []string
if memoryType&int64(MemoryTypeRaw) != 0 {
result = append(result, "raw")
}
if memoryType&int64(MemoryTypeSemantic) != 0 {
result = append(result, "semantic")
}
if memoryType&int64(MemoryTypeEpisodic) != 0 {
result = append(result, "episodic")
}
if memoryType&int64(MemoryTypeProcedural) != 0 {
result = append(result, "procedural")
}
return result
}
// MemoryDAO handles all Memory-related database operations
type MemoryDAO struct{}
// NewMemoryDAO creates a new MemoryDAO instance
//
// Returns:
// - *MemoryDAO: Initialized DAO instance
func NewMemoryDAO() *MemoryDAO {
return &MemoryDAO{}
}
// Create inserts a new memory record into the database
//
// Parameters:
// - memory: Memory model pointer
//
// Returns:
// - error: Database operation error
func (dao *MemoryDAO) Create(memory *model.Memory) error {
return DB.Create(memory).Error
}
// GetByID retrieves a memory record by ID from database
//
// Parameters:
// - id: Memory ID
//
// Returns:
// - *model.Memory: Memory model pointer
// - error: Database operation error
func (dao *MemoryDAO) GetByID(id string) (*model.Memory, error) {
var memory model.Memory
err := DB.Where("id = ?", id).First(&memory).Error
if err != nil {
return nil, err
}
return &memory, nil
}
// GetByTenantID retrieves all memories for a tenant
//
// Parameters:
// - tenantID: Tenant ID
//
// Returns:
// - []*model.Memory: Memory model pointer array
// - error: Database operation error
func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*model.Memory, error) {
var memories []*model.Memory
err := DB.Where("tenant_id = ?", tenantID).Find(&memories).Error
return memories, err
}
// GetByNameAndTenant checks if memory exists by name and tenant ID
// Used for duplicate name deduplication
//
// Parameters:
// - name: Memory name
// - tenantID: Tenant ID
//
// Returns:
// - []*model.Memory: Matching memory list (for existence check)
// - error: Database operation error
func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*model.Memory, error) {
var memories []*model.Memory
err := DB.Where("name = ? AND tenant_id = ?", name, tenantID).Find(&memories).Error
return memories, err
}
// GetByIDs retrieves memories by multiple IDs
//
// Parameters:
// - ids: Memory ID list
//
// Returns:
// - []*model.Memory: Memory model pointer array
// - error: Database operation error
func (dao *MemoryDAO) GetByIDs(ids []string) ([]*model.Memory, error) {
var memories []*model.Memory
err := DB.Where("id IN ?", ids).Find(&memories).Error
return memories, err
}
// UpdateByID updates a memory by ID
// Supports partial updates - only updates passed fields
// Automatically handles field type conversions
//
// Parameters:
// - id: Memory ID
// - updates: Fields to update map
//
// Returns:
// - error: Database operation error
//
// Field type handling:
// - memory_type: []string converts to bit flags integer
// - temperature: string converts to float64
// - name: Uses string value directly
// - permissions, forgetting_policy: Uses string value directly
//
// Example:
//
// updates := map[string]interface{}{"name": "NewName", "memory_type": []string{"semantic"}}
// err := dao.UpdateByID("memory123", updates)
func (dao *MemoryDAO) UpdateByID(id string, updates map[string]interface{}) error {
if updates == nil || len(updates) == 0 {
return nil
}
for key, value := range updates {
switch key {
case "memory_type":
if types, ok := value.([]string); ok {
updates[key] = CalculateMemoryType(types)
}
case "temperature":
if tempStr, ok := value.(string); ok {
var temp float64
fmt.Sscanf(tempStr, "%f", &temp)
updates[key] = temp
}
}
}
return DB.Model(&model.Memory{}).Where("id = ?", id).Updates(updates).Error
}
// DeleteByID deletes a memory by ID
//
// Parameters:
// - id: Memory ID
//
// Returns:
// - error: Database operation error
//
// Example:
//
// err := dao.DeleteByID("memory123")
func (dao *MemoryDAO) DeleteByID(id string) error {
return DB.Where("id = ?", id).Delete(&model.Memory{}).Error
}
// GetWithOwnerNameByID retrieves a memory with owner name by ID
// Joins with User table to get owner's nickname
//
// Parameters:
// - id: Memory ID
//
// Returns:
// - *model.MemoryListItem: Memory detail with owner name populated
// - error: Database operation error
//
// Example:
//
// memory, err := dao.GetWithOwnerNameByID("memory123")
func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, error) {
querySQL := `
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
m.permissions, m.description, m.memory_size, m.forgetting_policy,
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
m.update_time, m.update_date,
u.nickname as owner_name
FROM memory m
LEFT JOIN user u ON m.tenant_id = u.id
WHERE m.id = ?
`
var rawResult struct {
model.Memory
OwnerName *string `gorm:"column:owner_name"`
}
if err := DB.Raw(querySQL, id).Scan(&rawResult).Error; err != nil {
return nil, err
}
return &model.MemoryListItem{
Memory: rawResult.Memory,
OwnerName: rawResult.OwnerName,
}, nil
}
// GetByFilter retrieves memories with optional filters
// Supports filtering by tenant_id, memory_type, storage_type, and keywords
// Returns paginated results with owner_name from user table JOIN
//
// Parameters:
// - tenantIDs: Array of tenant IDs to filter by (empty means all tenants)
// - memoryTypes: Array of memory type names to filter by (empty means all types)
// - storageType: Storage type to filter by (empty means all types)
// - keywords: Keywords to search in memory names (empty means no keyword filter)
// - page: Page number (1-based)
// - pageSize: Number of items per page
//
// Returns:
// - []*model.MemoryListItem: Memory list items with owner name populated
// - int64: Total count of matching memories
// - error: Database operation error
//
// Example:
//
// memories, total, err := dao.GetByFilter([]string{"tenant1"}, []string{"semantic"}, "table", "test", 1, 10)
func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*model.MemoryListItem, int64, error) {
var conditions []string
var args []interface{}
if len(tenantIDs) > 0 {
conditions = append(conditions, "m.tenant_id IN ?")
args = append(args, tenantIDs)
}
if len(memoryTypes) > 0 {
memoryTypeInt := CalculateMemoryType(memoryTypes)
conditions = append(conditions, "m.memory_type & ? > 0")
args = append(args, memoryTypeInt)
}
if storageType != "" {
conditions = append(conditions, "m.storage_type = ?")
args = append(args, storageType)
}
if keywords != "" {
conditions = append(conditions, "m.name LIKE ?")
args = append(args, "%"+keywords+"%")
}
whereClause := ""
if len(conditions) > 0 {
whereClause = "WHERE " + strings.Join(conditions, " AND ")
}
countSQL := fmt.Sprintf("SELECT COUNT(*) FROM memory m %s", whereClause)
var total int64
if err := DB.Raw(countSQL, args...).Scan(&total).Error; err != nil {
return nil, 0, err
}
offset := (page - 1) * pageSize
querySQL := fmt.Sprintf(`
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
m.permissions, m.description, m.memory_size, m.forgetting_policy,
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
m.update_time, m.update_date,
u.nickname as owner_name
FROM memory m
LEFT JOIN user u ON m.tenant_id = u.id
%s
ORDER BY m.update_time DESC
LIMIT ? OFFSET ?
`, whereClause)
queryArgs := append(args, pageSize, offset)
var rawResults []struct {
model.Memory
OwnerName *string `gorm:"column:owner_name"`
}
if err := DB.Raw(querySQL, queryArgs...).Scan(&rawResults).Error; err != nil {
return nil, 0, err
}
memories := make([]*model.MemoryListItem, len(rawResults))
for i, r := range rawResults {
memories[i] = &model.MemoryListItem{
Memory: r.Memory,
OwnerName: r.OwnerName,
}
}
return memories, total, nil
}

View File

@ -141,3 +141,130 @@ func (dao *TenantLLMDAO) DeleteByTenantID(tenantID string) (int64, error) {
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.TenantLLM{})
return result.RowsAffected, result.Error
}
// splitModelNameAndFactory splits model name and factory from combined format
// This matches Python's split_model_name_and_factory logic
//
// Parameters:
// - modelName: The model name which can be in format "ModelName" or "ModelName@Factory"
//
// Returns:
// - string: The model name without factory prefix
// - string: The factory name (empty string if not specified)
//
// Example:
//
// modelName, factory := splitModelNameAndFactory("gpt-4")
// // Returns: "gpt-4", ""
//
// modelName, factory := splitModelNameAndFactory("gpt-4@OpenAI")
// // Returns: "gpt-4", "OpenAI"
func splitModelNameAndFactory(modelName string) (string, string) {
// Split by "@" separator
// Handle cases like "model@factory" or "model@sub@factory"
lastAtIndex := -1
for i := len(modelName) - 1; i >= 0; i-- {
if modelName[i] == '@' {
lastAtIndex = i
break
}
}
// No "@" found, return original name
if lastAtIndex == -1 {
return modelName, ""
}
// Split into model name and potential factory
modelNamePart := modelName[:lastAtIndex]
factory := modelName[lastAtIndex+1:]
// Validate if factory exists in llm_factories table
// This matches Python's logic of checking against model providers
var factoryCount int64
DB.Model(&model.LLMFactories{}).Where("name = ?", factory).Count(&factoryCount)
// If factory doesn't exist in database, treat the whole string as model name
if factoryCount == 0 {
return modelName, ""
}
return modelNamePart, factory
}
// GetByTenantIDAndLLMName gets tenant LLM by tenant ID and LLM name
// This is used to resolve tenant_llm_id from llm_id
// It supports both simple model names and factory-prefixed names (e.g., "gpt-4@OpenAI")
//
// Parameters:
// - tenantID: The tenant identifier
// - llmName: The LLM model name (can include factory prefix like "OpenAI@gpt-4")
//
// Returns:
// - *model.TenantLLM: The tenant LLM record
// - error: Error if not found
//
// Example:
//
// // Simple model name
// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4")
//
// // Model name with factory prefix
// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4@OpenAI")
func (dao *TenantLLMDAO) GetByTenantIDAndLLMName(tenantID string, llmName string) (*model.TenantLLM, error) {
var tenantLLM model.TenantLLM
// Split model name and factory from the combined format
modelName, factory := splitModelNameAndFactory(llmName)
// First attempt: try to find with model name only
err := DB.Where("tenant_id = ? AND llm_name = ?", tenantID, modelName).First(&tenantLLM).Error
if err == nil {
return &tenantLLM, nil
}
// Second attempt: if factory is specified, try with both model name and factory
if factory != "" {
err = DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, modelName, factory).First(&tenantLLM).Error
if err == nil {
return &tenantLLM, nil
}
// Special handling for LocalAI and HuggingFace (matching Python logic)
// These factories append "___FactoryName" to the model name
if factory == "LocalAI" || factory == "HuggingFace" || factory == "OpenAI-API-Compatible" {
specialModelName := modelName + "___" + factory
err = DB.Where("tenant_id = ? AND llm_name = ?", tenantID, specialModelName).First(&tenantLLM).Error
if err == nil {
return &tenantLLM, nil
}
}
}
// Return the last error (record not found)
return nil, err
}
// GetByTenantIDLLMNameAndFactory gets tenant LLM by tenant ID, LLM name and factory
// This is used when model name includes factory suffix (e.g., "model@factory")
//
// Parameters:
// - tenantID: The tenant identifier
// - llmName: The LLM model name
// - factory: The LLM factory name
//
// Returns:
// - *model.TenantLLM: The tenant LLM record
// - error: Error if not found
//
// Example:
//
// tenantLLM, err := dao.GetByTenantIDLLMNameAndFactory("tenant123", "gpt-4", "OpenAI")
func (dao *TenantLLMDAO) GetByTenantIDLLMNameAndFactory(tenantID, llmName, factory string) (*model.TenantLLM, error) {
var tenantLLM model.TenantLLM
err := DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, llmName, factory).First(&tenantLLM).Error
if err != nil {
return nil, err
}
return &tenantLLM, nil
}

687
internal/handler/memory.go Normal file
View File

@ -0,0 +1,687 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package handler contains all HTTP request handlers
// This file implements Memory-related API endpoint handlers
// Each method corresponds to an API endpoint in the Python memory_api.py
package handler
import (
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"ragflow/internal/common"
"ragflow/internal/service"
)
// MemoryHandler handles Memory-related HTTP requests
// Responsible for processing all Memory-related HTTP requests
// Each method corresponds to an API endpoint, implementing the same logic as Python memory_api.py
type MemoryHandler struct {
memoryService *service.MemoryService // Reference to Memory business service layer
}
// NewMemoryHandler creates a new MemoryHandler instance
//
// Parameters:
// - memoryService: Pointer to MemoryService business service layer
//
// Returns:
// - *MemoryHandler: Initialized handler instance
func NewMemoryHandler(memoryService *service.MemoryService) *MemoryHandler {
return &MemoryHandler{
memoryService: memoryService,
}
}
// CreateMemory handles POST request for creating Memory
// API Path: POST /api/v1/memories
//
// Function:
// - Creates a new memory record
// - Supports automatic system_prompt generation
// - Supports name deduplication (if name exists, adds sequence number)
//
// Request Parameters (JSON Body):
// - name (required): Memory name, max 128 characters
// - memory_type (required): Memory type array, supports ["raw", "semantic", "episodic", "procedural"]
// - embd_id (required): Embedding model ID
// - llm_id (required): LLM model ID
// - tenant_embd_id (optional): Tenant embedding model ID
// - tenant_llm_id (optional): Tenant LLM model ID
//
// Response Format:
// - code: Status code (0=success, other=error)
// - message: true on success, error message on failure
// - data: Memory object on success
//
// Business Logic (matching Python create_memory):
// 1. Validate user login status
// 2. Parse and validate request parameters
// 3. Call service layer to create memory
// 4. Return creation result
func (h *MemoryHandler) CreateMemory(c *gin.Context) {
// Check if API timing is enabled
// If RAGFLOW_API_TIMING environment variable is set, request processing time will be logged
timingEnabled := os.Getenv("RAGFLOW_API_TIMING")
var tStart time.Time
if timingEnabled != "" {
tStart = time.Now()
}
// Get current logged-in user information
// GetUser is a context value set by the authentication middleware
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
// Parse JSON request body
var req service.CreateMemoryRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeBadRequest,
"message": err.Error(),
"data": nil,
})
return
}
// Validate required field: name
if req.Name == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "name is required",
"data": nil,
})
return
}
// Validate required field: memory_type (must be non-empty array)
if len(req.MemoryType) == 0 {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "memory_type is required and must be a list",
"data": nil,
})
return
}
// Validate required field: embd_id
if req.EmbdID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "embd_id is required",
"data": nil,
})
return
}
// Validate required field: llm_id
if req.LLMID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "llm_id is required",
"data": nil,
})
return
}
// Record request parsing completion time (for timing)
tParsed := time.Now()
// Call service layer to create memory
result, err := h.memoryService.CreateMemory(userID, &req)
if err != nil {
// Log error if timing is enabled
if timingEnabled != "" {
totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0
parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0
_ = parseMs
_ = totalMs
}
errMsg := err.Error()
// Determine if it's an argument error and return appropriate error code
if isArgumentError(errMsg) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": errMsg,
"data": nil,
})
return
}
// Other errors return server error
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": errMsg,
"data": nil,
})
return
}
// Log success if timing is enabled
if timingEnabled != "" {
totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0
parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0
validateAndDbMs := totalMs - parseMs
_ = parseMs
_ = validateAndDbMs
_ = totalMs
}
// Return success response
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": true,
"data": result,
})
}
// UpdateMemory handles PUT request for updating Memory
// API Path: PUT /api/v1/memories/:memory_id
//
// Function:
// - Updates configuration information for the specified memory
// - Supports partial updates: only update passed fields
//
// Request Parameters (JSON Body):
// - name (optional): Memory name
// - permissions (optional): Permission setting ["me", "team", "all"]
// - llm_id (optional): LLM model ID
// - embd_id (optional): Embedding model ID
// - tenant_llm_id (optional): Tenant LLM model ID
// - tenant_embd_id (optional): Tenant embedding model ID
// - memory_type (optional): Memory type array
// - memory_size (optional): Memory size, range (0, 5242880]
// - forgetting_policy (optional): Forgetting policy, default "FIFO"
// - temperature (optional): Temperature parameter, range [0, 1]
// - avatar (optional): Avatar URL
// - description (optional): Description
// - system_prompt (optional): System prompt
// - user_prompt (optional): User prompt
//
// Business Rules:
// - name length <= 128 characters
// - Cannot update tenant_embd_id, embd_id, memory_type when memory_size > 0
// - When updating memory_type, system_prompt is automatically regenerated if it's the default
func (h *MemoryHandler) UpdateMemory(c *gin.Context) {
// Get memory_id from URL path
memoryID := c.Param("memory_id")
if memoryID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "memory_id is required",
"data": nil,
})
return
}
// Get current logged-in user information
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
// Parse JSON request body
var req service.UpdateMemoryRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeBadRequest,
"message": err.Error(),
"data": nil,
})
return
}
// Call service layer to update memory
result, err := h.memoryService.UpdateMemory(userID, memoryID, &req)
if err != nil {
errMsg := err.Error()
// Check if it's a "not found" error
if strings.Contains(errMsg, "not found") {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeNotFound,
"message": errMsg,
"data": nil,
})
return
}
// Check if it's an argument error
if isArgumentError(errMsg) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": errMsg,
"data": nil,
})
return
}
// Other errors return server error
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": errMsg,
"data": nil,
})
return
}
// Return success response
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": true,
"data": result,
})
}
// DeleteMemory handles DELETE request for deleting Memory
// API Path: DELETE /api/v1/memories/:memory_id
//
// Function:
// - Deletes the specified memory record
// - Also deletes associated message data
//
// Business Logic:
// 1. Check if memory exists
// 2. Delete memory record
// 3. Delete associated message index
func (h *MemoryHandler) DeleteMemory(c *gin.Context) {
// Get memory_id from URL path
memoryID := c.Param("memory_id")
if memoryID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "memory_id is required",
"data": nil,
})
return
}
// Call service layer to delete memory
err := h.memoryService.DeleteMemory(memoryID)
if err != nil {
errMsg := err.Error()
// Check if it's a "not found" error
if strings.Contains(errMsg, "not found") {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeNotFound,
"message": errMsg,
"data": nil,
})
return
}
// Other errors return server error
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": errMsg,
"data": nil,
})
return
}
// Return success response
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": true,
"data": nil,
})
}
// ListMemories handles GET request for listing Memories
// API Path: GET /api/v1/memories
//
// Function:
// - Lists memories accessible to the current user
// - Supports multiple filter conditions
// - Supports pagination and keyword search
//
// Query Parameters:
// - memory_type (optional): Memory type filter, supports comma-separated multiple types
// - tenant_id (optional): Tenant ID filter
// - storage_type (optional): Storage type filter
// - keywords (optional): Keyword search (fuzzy match on name)
// - page (optional): Page number, default 1
// - page_size (optional): Items per page, default 50
//
// Response Format:
// - code: Status code
// - message: true
// - data.memory_list: Array of Memory objects
// - data.total_count: Total record count
func (h *MemoryHandler) ListMemories(c *gin.Context) {
// Get current logged-in user information
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
// Parse query parameters
memoryTypesParam := c.Query("memory_type")
tenantIDsParam := c.Query("tenant_id")
storageType := c.Query("storage_type")
keywords := c.Query("keywords")
pageStr := c.DefaultQuery("page", "1")
pageSizeStr := c.DefaultQuery("page_size", "50")
// Convert pagination parameters to integers
page, _ := strconv.Atoi(pageStr)
pageSize, _ := strconv.Atoi(pageSizeStr)
// Validate pagination parameters
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 50
}
// Parse memory_type parameter (supports comma separation)
var memoryTypes []string
if memoryTypesParam != "" {
if strings.Contains(memoryTypesParam, ",") {
memoryTypes = strings.Split(memoryTypesParam, ",")
} else {
memoryTypes = []string{memoryTypesParam}
}
}
// Parse tenant_id parameter
// If not specified, service will get all tenants associated with the user
var tenantIDs []string
if tenantIDsParam != "" {
if strings.Contains(tenantIDsParam, ",") {
tenantIDs = strings.Split(tenantIDsParam, ",")
} else {
tenantIDs = []string{tenantIDsParam}
}
}
// Call service layer to get memory list
result, err := h.memoryService.ListMemories(user.ID, tenantIDs, memoryTypes, storageType, keywords, page, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": err.Error(),
"data": nil,
})
return
}
// Return success response
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": true,
"data": result,
})
}
// GetMemoryConfig handles GET request for getting Memory configuration
// API Path: GET /api/v1/memories/:memory_id/config
//
// Function:
// - Gets complete configuration information for the specified memory
// - Includes owner name (obtained via JOIN with user table)
//
// Response Format:
// - code: Status code
// - message: true
// - data: Memory object, including owner_name field
func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) {
// Get memory_id from URL path
memoryID := c.Param("memory_id")
if memoryID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "memory_id is required",
"data": nil,
})
return
}
// Call service layer to get memory configuration
result, err := h.memoryService.GetMemoryConfig(memoryID)
if err != nil {
errMsg := err.Error()
// Check if it's a "not found" error
if strings.Contains(errMsg, "not found") {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeNotFound,
"message": errMsg,
"data": nil,
})
return
}
// Other errors return server error
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": errMsg,
"data": nil,
})
return
}
// Return success response
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": true,
"data": result,
})
}
// GetMemoryMessages handles GET request for getting Memory messages
// API Path: GET /api/v1/memories/:memory_id
//
// Function:
// - Gets message list associated with the specified memory
// - Supports filtering by agent_id
// - Supports keyword search and pagination
//
// Query Parameters:
// - agent_id (optional): Agent ID filter, supports comma-separated multiple
// - keywords (optional): Keyword search
// - page (optional): Page number, default 1
// - page_size (optional): Items per page, default 50
//
// Response Format:
// - code: Status code
// - message: true
// - data.messages: Array of message objects
// - data.storage_type: Storage type
//
// TODO: Implementation pending - depends on CanvasService and TaskService
func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "GetMemoryMessages not implemented - pending CanvasService and TaskService dependencies",
"data": nil,
})
}
// AddMessage handles POST request for adding messages
// API Path: POST /api/v1/messages
//
// Function:
// - Adds messages to one or more memories
// - Messages will be embedded and saved to vector database
// - Creates asynchronous task for processing
//
// Request Parameters (JSON Body):
// - memory_id (required): Memory ID or ID array
// - agent_id (required): Agent ID
// - session_id (required): Session ID
// - user_input (required): User input
// - agent_response (required): Agent response
// - user_id (optional): User ID
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) AddMessage(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "AddMessage not implemented - pending embedding engine dependency",
"data": nil,
})
}
// ForgetMessage handles DELETE request for forgetting messages
// API Path: DELETE /api/v1/messages/:memory_id/:message_id
//
// Function:
// - Soft-deletes the specified message (sets forget_at timestamp)
// - Message is not immediately deleted from database, but marked as "forgotten"
//
// Parameter Format:
// - memory_id: Memory ID
// - message_id: Message ID (integer)
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) ForgetMessage(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "ForgetMessage not implemented - pending embedding engine dependency",
"data": nil,
})
}
// UpdateMessage handles PUT request for updating message status
// API Path: PUT /api/v1/messages/:memory_id/:message_id
//
// Function:
// - Updates status of the specified message
// - status is a boolean, converted to integer for storage (true=1, false=0)
//
// Request Parameters (JSON Body):
// - status (required): Message status, boolean
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "UpdateMessage not implemented - pending embedding engine dependency",
"data": nil,
})
}
// SearchMessage handles GET request for searching messages
// API Path: GET /api/v1/messages/search
//
// Function:
// - Searches messages across multiple memories
// - Supports vector similarity search and keyword search
// - Fuses results from both search methods
//
// Query Parameters:
// - memory_id (optional): Memory ID list, supports comma separation
// - query (optional): Search query text
// - similarity_threshold (optional): Similarity threshold, default 0.2
// - keywords_similarity_weight (optional): Keyword weight, default 0.7
// - top_n (optional): Number of results to return, default 5
// - agent_id (optional): Agent ID filter
// - session_id (optional): Session ID filter
// - user_id (optional): User ID filter
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) SearchMessage(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "SearchMessage not implemented - pending embedding engine dependency",
"data": nil,
})
}
// GetMessages handles GET request for getting message list
// API Path: GET /api/v1/messages
//
// Function:
// - Gets recent messages from specified memories
// - Supports filtering by agent_id and session_id
//
// Query Parameters:
// - memory_id (required): Memory ID list, supports comma separation
// - agent_id (optional): Agent ID filter
// - session_id (optional): Session ID filter
// - limit (optional): Number of results to return, default 10
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) GetMessages(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "GetMessages not implemented - pending embedding engine dependency",
"data": nil,
})
}
// GetMessageContent handles GET request for getting message content
// API Path: GET /api/v1/messages/:memory_id/:message_id/content
//
// Function:
// - Gets complete content of the specified message
// - doc_id format: memory_id + "_" + message_id
//
// Parameter Format:
// - memory_id: Memory ID
// - message_id: Message ID (integer)
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) GetMessageContent(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "GetMessageContent not implemented - pending embedding engine dependency",
"data": nil,
})
}
// isArgumentError determines if an error message is an argument error
//
// Function:
// - Checks if the error message contains any argument validation-related prefixes
// - Used to distinguish argument errors from server errors
//
// Parameters:
// - msg: Error message string
//
// Returns:
// - bool: true if it's an argument error, false otherwise
func isArgumentError(msg string) bool {
// Define list of argument error prefixes
// Matches Python ArgumentException error messages
argumentErrorPrefixes := []string{
"memory name cannot be empty", // Memory name cannot be empty
"memory name exceeds limit", // Memory name exceeds limit
"memory type must be a list", // memory_type must be a list
"memory type is not supported", // Unsupported memory_type
}
// Check if error message starts with any prefix
for _, prefix := range argumentErrorPrefixes {
if len(msg) >= len(prefix) && msg[:len(prefix)] == prefix {
return true
}
}
return false
}

View File

@ -42,3 +42,12 @@ type Memory struct {
func (Memory) TableName() string {
return "memory"
}
// MemoryListItem represents a memory record with owner name from JOIN query.
// Uses struct embedding to extend Memory struct with owner_name from user table JOIN.
// Note: MemoryType is kept as int64 from Memory embedding; conversion to []string
// happens in the Service layer via CreateMemoryResponse.
type MemoryListItem struct {
Memory
OwnerName *string `json:"owner_name,omitempty"`
}

View File

@ -38,6 +38,7 @@ type Router struct {
connectorHandler *handler.ConnectorHandler
searchHandler *handler.SearchHandler
fileHandler *handler.FileHandler
memoryHandler *handler.MemoryHandler
}
// NewRouter create router
@ -56,6 +57,7 @@ func NewRouter(
connectorHandler *handler.ConnectorHandler,
searchHandler *handler.SearchHandler,
fileHandler *handler.FileHandler,
memoryHandler *handler.MemoryHandler,
) *Router {
return &Router{
authHandler: authHandler,
@ -72,6 +74,7 @@ func NewRouter(
connectorHandler: connectorHandler,
searchHandler: searchHandler,
fileHandler: fileHandler,
memoryHandler: memoryHandler,
}
}
@ -163,6 +166,28 @@ func (r *Router) Setup(engine *gin.Engine) {
{
authors.GET("/:author_id/documents", r.documentHandler.GetDocumentsByAuthorID)
}
// Memory routes
memory := v1.Group("/memories")
{
memory.POST("", r.memoryHandler.CreateMemory)
memory.PUT("/:memory_id", r.memoryHandler.UpdateMemory)
memory.DELETE("/:memory_id", r.memoryHandler.DeleteMemory)
memory.GET("", r.memoryHandler.ListMemories)
memory.GET("/:memory_id/config", r.memoryHandler.GetMemoryConfig)
memory.GET("/:memory_id", r.memoryHandler.GetMemoryMessages)
}
// TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine
// message := v1.Group("/messages")
// {
// message.POST("", r.memoryHandler.AddMessage)
// message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage)
// message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage)
// message.GET("/search", r.memoryHandler.SearchMessage)
// message.GET("", r.memoryHandler.GetMessages)
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
// }
}
// Knowledge base routes
@ -260,6 +285,7 @@ func (r *Router) Setup(engine *gin.Engine) {
file.GET("/parent_folder", r.fileHandler.GetParentFolder)
file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders)
}
}
// Handle undefined routes

892
internal/service/memory.go Normal file
View File

@ -0,0 +1,892 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"errors"
"fmt"
"path"
"regexp"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"ragflow/internal/dao"
"ragflow/internal/model"
)
const (
// MemoryNameLimit is the maximum length allowed for memory names
MemoryNameLimit = 128
// MemorySizeLimit is the maximum memory size in bytes (5MB)
MemorySizeLimit = 5242880
)
// Note: MemoryType, MemoryTypeRaw, MemoryTypeSemantic, MemoryTypeEpisodic,
// MemoryTypeProcedural, and CalculateMemoryType are defined in the dao package
// and imported as dao.MemoryType, dao.MemoryTypeRaw, etc.
// TenantPermission defines the access permission levels for memory resources
// Note: This type is specific to the service layer
type TenantPermission string
const (
// TenantPermissionMe restricts access to the owner only
TenantPermissionMe TenantPermission = "me"
// TenantPermissionTeam allows access within the same team
TenantPermissionTeam TenantPermission = "team"
// TenantPermissionAll allows access to all tenants
TenantPermissionAll TenantPermission = "all"
)
// validPermissions defines which permission values are valid
var validPermissions = map[TenantPermission]bool{
TenantPermissionMe: true,
TenantPermissionTeam: true,
TenantPermissionAll: true,
}
// ForgettingPolicy defines the strategy for forgetting old memory entries
type ForgettingPolicy string
const (
// ForgettingPolicyFIFO uses First-In-First-Out strategy for forgetting
ForgettingPolicyFIFO ForgettingPolicy = "FIFO"
)
// validForgettingPolicies defines which forgetting policies are valid
var validForgettingPolicies = map[ForgettingPolicy]bool{
ForgettingPolicyFIFO: true,
}
//
// Note: CalculateMemoryType and GetMemoryTypeHuman functions have been moved to dao package
// Use dao.CalculateMemoryType() and dao.GetMemoryTypeHuman() instead
// PromptAssembler handles the assembly of system prompts for memory extraction
type PromptAssembler struct{}
// SYSTEM_BASE_TEMPLATE is the base template for the system prompt used in memory extraction
// It includes placeholders for type-specific instructions, timestamp format, and max items
var SYSTEM_BASE_TEMPLATE = `**Memory Extraction Specialist**
You are an expert at analyzing conversations to extract structured memory.
{type_specific_instructions}
**OUTPUT REQUIREMENTS:**
1. Output MUST be valid JSON
2. Follow the specified output format exactly
3. Each extracted item MUST have: content, valid_at, invalid_at
4. Timestamps in {timestamp_format} format
5. Only extract memory types specified above
6. Maximum {max_items} items per type
`
// TYPE_INSTRUCTIONS contains specific instructions for each memory type extraction
var TYPE_INSTRUCTIONS = map[string]string{
"semantic": `
**EXTRACT SEMANTIC KNOWLEDGE:**
- Universal facts, definitions, concepts, relationships
- Time-invariant, generally true information
**Timestamp Rules:**
- valid_at: When the fact became true
- invalid_at: When it becomes false or empty if still true
`,
"episodic": `
**EXTRACT EPISODIC KNOWLEDGE:**
- Specific experiences, events, personal stories
- Time-bound, person-specific, contextual
**Timestamp Rules:**
- valid_at: Event start/occurrence time
- invalid_at: Event end time or empty if instantaneous
`,
"procedural": `
**EXTRACT PROCEDURAL KNOWLEDGE:**
- Processes, methods, step-by-step instructions
- Goal-oriented, actionable, often includes conditions
**Timestamp Rules:**
- valid_at: When procedure becomes valid/effective
- invalid_at: When it expires/becomes obsolete or empty if current
`,
}
// OUTPUT_TEMPLATES defines the output format for each memory type
var OUTPUT_TEMPLATES = map[string]string{
"semantic": `"semantic": [{"content": "Clear factual statement", "valid_at": "timestamp or empty", "invalid_at": "timestamp or empty"}]`,
"episodic": `"episodic": [{"content": "Narrative event description", "valid_at": "event start timestamp", "invalid_at": "event end timestamp or empty"}]`,
"procedural": `"procedural": [{"content": "Actionable instructions", "valid_at": "procedure effective timestamp", "invalid_at": "procedure expiration timestamp or empty"}]`,
}
// AssembleSystemPrompt generates a complete system prompt for memory extraction
//
// Parameters:
// - memoryTypes: Array of memory type names to extract (e.g., ["semantic", "episodic"])
//
// Returns:
// - string: Complete system prompt with type-specific instructions and output format
//
// Example:
//
// AssembleSystemPrompt([]string{"semantic", "episodic"}) returns a prompt with instructions
// for both semantic and episodic memory extraction
func (PromptAssembler) AssembleSystemPrompt(memoryTypes []string) string {
typesToExtract := getTypesToExtract(memoryTypes)
if len(typesToExtract) == 0 {
typesToExtract = []string{"raw"}
}
typeInstructions := generateTypeInstructions(typesToExtract)
outputFormat := generateOutputFormat(typesToExtract)
fullPrompt := strings.Replace(SYSTEM_BASE_TEMPLATE, "{type_specific_instructions}", typeInstructions, 1)
fullPrompt = strings.Replace(fullPrompt, "{timestamp_format}", "ISO 8601", 1)
fullPrompt = strings.Replace(fullPrompt, "{max_items}", "5", 1)
fullPrompt += fmt.Sprintf("\n**REQUIRED OUTPUT FORMAT (JSON):\n```json\n{\n%s\n}\n```\n", outputFormat)
return fullPrompt
}
// getTypesToExtract filters out "raw" type and returns valid memory types
//
// Parameters:
// - requestedTypes: Array of requested memory type names
//
// Returns:
// - []string: Filtered array of memory type names (excluding "raw")
func getTypesToExtract(requestedTypes []string) []string {
types := make(map[string]bool)
for _, rt := range requestedTypes {
lowerRT := strings.ToLower(rt)
if lowerRT != "raw" {
if _, ok := dao.MemoryTypeMap[lowerRT]; ok {
types[lowerRT] = true
}
}
}
result := make([]string, 0, len(types))
for t := range types {
result = append(result, t)
}
return result
}
// generateTypeInstructions concatenates type-specific instructions
//
// Parameters:
// - typesToExtract: Array of memory type names
//
// Returns:
// - string: Concatenated instructions for all specified types
func generateTypeInstructions(typesToExtract []string) string {
var instructions []string
for _, mt := range typesToExtract {
if instr, ok := TYPE_INSTRUCTIONS[mt]; ok {
instructions = append(instructions, instr)
}
}
return strings.Join(instructions, "\n")
}
// generateOutputFormat concatenates output format templates
//
// Parameters:
// - typesToExtract: Array of memory type names
//
// Returns:
// - string: Concatenated output format templates
func generateOutputFormat(typesToExtract []string) string {
var outputParts []string
for _, mt := range typesToExtract {
if tmpl, ok := OUTPUT_TEMPLATES[mt]; ok {
outputParts = append(outputParts, tmpl)
}
}
return strings.Join(outputParts, ",\n")
}
// MemoryService handles business logic for memory operations
// It provides methods for creating, updating, deleting, and querying memories
type MemoryService struct {
memoryDAO *dao.MemoryDAO
}
// NewMemoryService creates a new MemoryService instance
//
// Returns:
// - *MemoryService: Initialized service instance with DAO
func NewMemoryService() *MemoryService {
return &MemoryService{
memoryDAO: dao.NewMemoryDAO(),
}
}
// splitNameCounter splits a filename into base name and counter
// Handles names in format "filename(123)" pattern
//
// Parameters:
// - filename: The filename to split
//
// Returns:
// - string: The base name without counter
// - *int: The counter value, or nil if no counter exists
//
// Example:
//
// splitNameCounter("test(5)") returns ("test", 5)
// splitNameCounter("test") returns ("test", nil)
func splitNameCounter(filename string) (string, *int) {
re := regexp.MustCompile(`^(.+)\((\d+)\)$`)
matches := re.FindStringSubmatch(filename)
if len(matches) >= 3 {
counter := -1
fmt.Sscanf(matches[2], "%d", &counter)
stem := strings.TrimRight(matches[1], " ")
return stem, &counter
}
return filename, nil
}
// duplicateName generates a unique name by appending a counter if the name already exists
// It tries up to 1000 times to generate a unique name
//
// Parameters:
// - queryFunc: Function to check if a name already exists (returns true if exists)
// - name: The original name
// - tenantID: The tenant ID for name uniqueness check
//
// Returns:
// - string: A unique name (either original or with counter appended)
//
// Example:
//
// duplicateName(func(name string, tid string) bool { return false }, "test", "tenant1") returns "test"
// duplicateName(func(name string, tid string) bool { return true }, "test", "tenant1") returns "test(1)"
func duplicateName(queryFunc func(name string, tenantID string) bool, name string, tenantID string) string {
const maxRetries = 1000
originalName := name
currentName := name
retries := 0
for retries < maxRetries {
if !queryFunc(currentName, tenantID) {
return currentName
}
stem, counter := splitNameCounter(currentName)
ext := path.Ext(stem)
stemBase := strings.TrimSuffix(stem, ext)
newCounter := 1
if counter != nil {
newCounter = *counter + 1
}
currentName = fmt.Sprintf("%s(%d)%s", stemBase, newCounter, ext)
retries++
}
panic(fmt.Sprintf("Failed to generate unique name within %d attempts. Original: %s", maxRetries, originalName))
}
// CreateMemoryRequest defines the request structure for creating a memory
type CreateMemoryRequest struct {
// Name is the memory name (required, max 128 characters)
Name string `json:"name" binding:"required"`
// MemoryType is the array of memory type names (required)
MemoryType []string `json:"memory_type" binding:"required"`
// EmbdID is the embedding model ID (required)
EmbdID string `json:"embd_id" binding:"required"`
// LLMID is the language model ID (required)
LLMID string `json:"llm_id" binding:"required"`
// TenantEmbdID is the tenant-specific embedding model ID (optional)
TenantEmbdID *string `json:"tenant_embd_id"`
// TenantLLMID is the tenant-specific language model ID (optional)
TenantLLMID *string `json:"tenant_llm_id"`
}
// UpdateMemoryRequest defines the request structure for updating a memory
// All fields are optional, only provided fields will be updated
type UpdateMemoryRequest struct {
// Name is the new memory name (optional)
Name *string `json:"name"`
// Permissions is the new permission level (optional)
Permissions *string `json:"permissions"`
// LLMID is the new language model ID (optional)
LLMID *string `json:"llm_id"`
// EmbdID is the new embedding model ID (optional)
EmbdID *string `json:"embd_id"`
// TenantLLMID is the new tenant-specific language model ID (optional)
TenantLLMID *string `json:"tenant_llm_id"`
// TenantEmbdID is the new tenant-specific embedding model ID (optional)
TenantEmbdID *string `json:"tenant_embd_id"`
// MemoryType is the new array of memory type names (optional)
MemoryType []string `json:"memory_type"`
// MemorySize is the new memory size in bytes (optional, max 5MB)
MemorySize *int64 `json:"memory_size"`
// ForgettingPolicy is the new forgetting policy (optional)
ForgettingPolicy *string `json:"forgetting_policy"`
// Temperature is the new temperature value (optional, range [0, 1])
Temperature *float64 `json:"temperature"`
// Avatar is the new avatar URL (optional)
Avatar *string `json:"avatar"`
// Description is the new description (optional)
Description *string `json:"description"`
// SystemPrompt is the new system prompt (optional)
SystemPrompt *string `json:"system_prompt"`
// UserPrompt is the new user prompt (optional)
UserPrompt *string `json:"user_prompt"`
}
// CreateMemoryResponse defines the response structure for memory operations
// Uses struct embedding to extend Memory struct with API-specific fields
type CreateMemoryResponse struct {
model.Memory
OwnerName *string `json:"owner_name,omitempty"`
MemoryType []string `json:"memory_type"`
}
// ListMemoryResponse defines the response structure for listing memories
type ListMemoryResponse struct {
// MemoryList is the array of memory objects
MemoryList []map[string]interface{} `json:"memory_list"`
// TotalCount is the total number of memories
TotalCount int64 `json:"total_count"`
}
// CreateMemory creates a new memory with the given parameters
// It validates the request, generates a unique name if needed, and creates the memory record
//
// Parameters:
// - tenantID: The tenant ID for which to create the memory
// - req: The memory creation request containing name, memory_type, embd_id, llm_id, etc.
//
// Returns:
// - *CreateMemoryResponse: The created memory details
// - error: Error if validation fails or creation fails
//
// Example:
//
// req := &CreateMemoryRequest{Name: "MyMemory", MemoryType: []string{"semantic"}, EmbdID: "embd1", LLMID: "llm1"}
// resp, err := service.CreateMemory("tenant123", req)
func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest) (*CreateMemoryResponse, error) {
// Ensure tenant model IDs are populated for LLM and embedding model parameters
// This automatically fills tenant_llm_id and tenant_embd_id based on llm_id and embd_id
tenantLLMService := NewTenantLLMService()
params := map[string]interface{}{
"llm_id": req.LLMID,
"embd_id": req.EmbdID,
}
params = tenantLLMService.EnsureTenantModelIDForParams(tenantID, params)
// Update request with tenant model IDs from the processed params
if tenantLLMID, ok := params["tenant_llm_id"].(int64); ok {
tenantLLMIDStr := strconv.FormatInt(tenantLLMID, 10)
req.TenantLLMID = &tenantLLMIDStr
}
if tenantEmbdID, ok := params["tenant_embd_id"].(int64); ok {
tenantEmbdIDStr := strconv.FormatInt(tenantEmbdID, 10)
req.TenantEmbdID = &tenantEmbdIDStr
}
memoryName := strings.TrimSpace(req.Name)
if len(memoryName) == 0 {
return nil, errors.New("memory name cannot be empty or whitespace")
}
if len(memoryName) > MemoryNameLimit {
return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit)
}
if !isList(req.MemoryType) {
return nil, errors.New("memory type must be a list")
}
memoryTypeSet := make(map[string]bool)
for _, mt := range req.MemoryType {
lowerMT := strings.ToLower(mt)
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
}
memoryTypeSet[lowerMT] = true
}
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
for mt := range memoryTypeSet {
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
}
memoryName = duplicateName(func(name string, tid string) bool {
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
return len(existing) > 0
}, memoryName, tenantID)
if len(memoryName) > MemoryNameLimit {
return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit)
}
memoryTypeInt := dao.CalculateMemoryType(uniqueMemoryTypes)
timestamp := time.Now().UnixMilli()
systemPrompt := PromptAssembler{}.AssembleSystemPrompt(uniqueMemoryTypes)
newID := strings.ReplaceAll(uuid.New().String(), "-", "")
if len(newID) > 32 {
newID = newID[:32]
}
memory := &model.Memory{
ID: newID,
Name: memoryName,
TenantID: tenantID,
MemoryType: memoryTypeInt,
StorageType: "table",
EmbdID: req.EmbdID,
LLMID: req.LLMID,
Permissions: "me",
MemorySize: MemorySizeLimit,
ForgettingPolicy: string(ForgettingPolicyFIFO),
Temperature: 0.5,
SystemPrompt: &systemPrompt,
}
// Convert tenant model IDs from string to int64 for database
if req.TenantEmbdID != nil {
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
memory.TenantEmbdID = &embdID
}
}
if req.TenantLLMID != nil {
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
memory.TenantLLMID = &llmID
}
}
memory.CreateTime = &timestamp
memory.UpdateTime = &timestamp
if err := s.memoryDAO.Create(memory); err != nil {
return nil, errors.New("could not create new memory")
}
createdMemory, err := s.memoryDAO.GetByID(newID)
if err != nil {
return nil, errors.New("could not create new memory")
}
return formatRetDataFromMemory(createdMemory), nil
}
// UpdateMemory updates an existing memory with the provided fields
// Only the fields specified in the request will be updated (partial update)
//
// Parameters:
// - tenantID: The tenant ID for ownership verification
// - memoryID: The ID of the memory to update
// - req: The update request with optional fields to update
//
// Returns:
// - *CreateMemoryResponse: The updated memory details
// - error: Error if validation fails or update fails
//
// Example:
//
// req := &UpdateMemoryRequest{Name: ptr("NewName"), MemorySize: ptr(int64(1000000))}
// resp, err := service.UpdateMemory("tenant123", "memory456", req)
func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *UpdateMemoryRequest) (*CreateMemoryResponse, error) {
updateDict := make(map[string]interface{})
if req.Name != nil {
memoryName := strings.TrimSpace(*req.Name)
if len(memoryName) == 0 {
return nil, errors.New("memory name cannot be empty or whitespace")
}
if len(memoryName) > MemoryNameLimit {
return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit)
}
memoryName = duplicateName(func(name string, tid string) bool {
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
return len(existing) > 0
}, memoryName, tenantID)
if len(memoryName) > MemoryNameLimit {
return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit)
}
updateDict["name"] = memoryName
}
if req.Permissions != nil {
perm := TenantPermission(strings.ToLower(*req.Permissions))
if !validPermissions[perm] {
return nil, fmt.Errorf("unknown permission '%s'", *req.Permissions)
}
updateDict["permissions"] = perm
}
if req.LLMID != nil {
updateDict["llm_id"] = *req.LLMID
}
if req.EmbdID != nil {
updateDict["embd_id"] = *req.EmbdID
}
if req.TenantLLMID != nil {
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
updateDict["tenant_llm_id"] = llmID
}
}
if req.TenantEmbdID != nil {
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
updateDict["tenant_embd_id"] = embdID
}
}
if req.MemoryType != nil && len(req.MemoryType) > 0 {
memoryTypeSet := make(map[string]bool)
for _, mt := range req.MemoryType {
lowerMT := strings.ToLower(mt)
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
}
memoryTypeSet[lowerMT] = true
}
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
for mt := range memoryTypeSet {
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
}
updateDict["memory_type"] = uniqueMemoryTypes
}
if req.MemorySize != nil {
memorySize := *req.MemorySize
if !(memorySize > 0 && memorySize <= MemorySizeLimit) {
return nil, fmt.Errorf("memory size should be in range (0, %d] Bytes", MemorySizeLimit)
}
updateDict["memory_size"] = memorySize
}
if req.ForgettingPolicy != nil {
fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy))
if !validForgettingPolicies[fp] {
return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy)
}
updateDict["forgetting_policy"] = fp
}
if req.Temperature != nil {
temp := *req.Temperature
if !(temp >= 0 && temp <= 1) {
return nil, errors.New("temperature should be in range [0, 1]")
}
updateDict["temperature"] = temp
}
for _, field := range []string{"avatar", "description", "system_prompt", "user_prompt"} {
switch field {
case "avatar":
if req.Avatar != nil {
updateDict["avatar"] = *req.Avatar
}
case "description":
if req.Description != nil {
updateDict["description"] = *req.Description
}
case "system_prompt":
if req.SystemPrompt != nil {
updateDict["system_prompt"] = *req.SystemPrompt
}
case "user_prompt":
if req.UserPrompt != nil {
updateDict["user_prompt"] = *req.UserPrompt
}
}
}
currentMemory, err := s.memoryDAO.GetByID(memoryID)
if err != nil {
return nil, fmt.Errorf("memory '%s' not found", memoryID)
}
if len(updateDict) == 0 {
return formatRetDataFromMemory(currentMemory), nil
}
memorySize := currentMemory.MemorySize
notAllowedUpdate := []string{}
for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} {
if _, ok := updateDict[f]; ok && memorySize > 0 {
notAllowedUpdate = append(notAllowedUpdate, f)
}
}
if len(notAllowedUpdate) > 0 {
return nil, fmt.Errorf("can't update %v when memory isn't empty", notAllowedUpdate)
}
if _, ok := updateDict["memory_type"]; ok {
if _, ok := updateDict["system_prompt"]; !ok {
memoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType)
if len(memoryTypes) > 0 && currentMemory.SystemPrompt != nil {
defaultPrompt := PromptAssembler{}.AssembleSystemPrompt(memoryTypes)
if *currentMemory.SystemPrompt == defaultPrompt {
if types, ok := updateDict["memory_type"].([]string); ok {
updateDict["system_prompt"] = PromptAssembler{}.AssembleSystemPrompt(types)
}
}
}
}
}
if err := s.memoryDAO.UpdateByID(memoryID, updateDict); err != nil {
return nil, errors.New("failed to update memory")
}
updatedMemory, err := s.memoryDAO.GetByID(memoryID)
if err != nil {
return nil, errors.New("failed to get updated memory")
}
return formatRetDataFromMemory(updatedMemory), nil
}
// DeleteMemory deletes a memory by ID
// It also deletes associated message indexes before removing the memory record
//
// Parameters:
// - memoryID: The ID of the memory to delete
//
// Returns:
// - error: Error if memory not found or deletion fails
//
// Example:
//
// err := service.DeleteMemory("memory456")
func (s *MemoryService) DeleteMemory(memoryID string) error {
_, err := s.memoryDAO.GetByID(memoryID)
if err != nil {
return fmt.Errorf("memory '%s' not found", memoryID)
}
// TODO: Delete associated message index - Implementation pending MessageService
// messageService := NewMessageService()
// hasIndex, _ := messageService.HasIndex(memory.TenantID, memoryID)
// if hasIndex {
// messageService.DeleteMessage(nil, memory.TenantID, memoryID)
// }
// Delete memory record
if err := s.memoryDAO.DeleteByID(memoryID); err != nil {
return errors.New("failed to delete memory")
}
return nil
}
// ListMemories retrieves a paginated list of memories with optional filters
// When tenantIDs is empty, it retrieves all tenants associated with the user
//
// Parameters:
// - userID: The user ID for tenant filtering when tenantIDs is empty
// - tenantIDs: Array of tenant IDs to filter by (empty means all user's tenants)
// - memoryTypes: Array of memory type names to filter by (empty means all types)
// - storageType: Storage type to filter by (empty means all types)
// - keywords: Keywords to search in memory names (empty means no keyword filter)
// - page: Page number (1-based)
// - pageSize: Number of items per page
//
// Returns:
// - *ListMemoryResponse: Contains memory list and total count
// - error: Error if query fails
//
// Example:
//
// resp, err := service.ListMemories("user123", []string{}, []string{"semantic"}, "table", "test", 1, 10)
func (s *MemoryService) ListMemories(userID string, tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) (*ListMemoryResponse, error) {
// If tenantIDs is empty, get all tenants associated with the user
if len(tenantIDs) == 0 {
userTenantService := NewUserTenantService()
userTenants, err := userTenantService.GetUserTenantRelationByUserID(userID)
if err != nil {
return nil, fmt.Errorf("failed to get user tenants: %w", err)
}
tenantIDs = make([]string, len(userTenants))
for i, tenant := range userTenants {
tenantIDs[i] = tenant.TenantID
}
}
memories, total, err := s.memoryDAO.GetByFilter(tenantIDs, memoryTypes, storageType, keywords, page, pageSize)
if err != nil {
return nil, err
}
memoryList := make([]map[string]interface{}, 0, len(memories))
for _, m := range memories {
resp := formatRetDataFromMemoryListItem(m)
var createDateStr *string
if resp.CreateTime != nil {
createDateStr = formatDateToString(*resp.CreateTime)
}
memoryMap := map[string]interface{}{
"id": resp.ID,
"name": resp.Name,
"avatar": resp.Avatar,
"tenant_id": resp.TenantID,
"owner_name": resp.OwnerName,
"memory_type": resp.MemoryType,
"storage_type": resp.StorageType,
"permissions": resp.Permissions,
"description": resp.Description,
"create_time": resp.CreateTime,
"create_date": createDateStr,
}
memoryList = append(memoryList, memoryMap)
}
return &ListMemoryResponse{
MemoryList: memoryList,
TotalCount: total,
}, nil
}
// GetMemoryConfig retrieves the full configuration of a memory by ID
//
// Parameters:
// - memoryID: The ID of the memory to retrieve
//
// Returns:
// - *CreateMemoryResponse: The memory configuration details
// - error: Error if memory not found
//
// Example:
//
// resp, err := service.GetMemoryConfig("memory456")
func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse, error) {
memory, err := s.memoryDAO.GetWithOwnerNameByID(memoryID)
if err != nil {
return nil, fmt.Errorf("memory '%s' not found", memoryID)
}
return formatRetDataFromMemoryListItem(memory), nil
}
// TODO: GetMemoryMessages - Implementation pending - depends on CanvasService and TaskService
// func (s *MemoryService) GetMemoryMessages(memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { ... }
// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService
// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... }
// TODO: AddMessage - Implementation pending - depends on embedding engine
// func (s *MemoryService) AddMessage(memoryIDs []string, messageDict map[string]interface{}) (bool, string, error) { ... }
// TODO: ForgetMessage - Implementation pending - depends on embedding engine
// func (s *MemoryService) ForgetMessage(memoryID string, messageID int) (bool, error) { ... }
// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine
// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... }
// TODO: SearchMessage - Implementation pending - depends on embedding engine
// func (s *MemoryService) SearchMessage(filterDict map[string]interface{}, params map[string]interface{}) ([]map[string]interface{}, error) { ... }
// TODO: GetMessages - Implementation pending - depends on embedding engine
// func (s *MemoryService) GetMessages(memoryIDs []string, agentID string, sessionID string, limit int) ([]map[string]interface{}, error) { ... }
// TODO: GetMessageContent - Implementation pending - depends on embedding engine
// func (s *MemoryService) GetMessageContent(memoryID string, messageID int) (map[string]interface{}, error) { ... }
// isList checks if a value is a list or array type
// This is a utility function for type validation
//
// Parameters:
// - v: The value to check
//
// Returns:
// - bool: true if v is []interface{} or []string, false otherwise
//
// Example:
//
// isList([]string{"a", "b"}) returns true
// isList("test") returns false
func isList(v interface{}) bool {
switch v.(type) {
case []interface{}, []string:
return true
default:
return false
}
}
// formatRetDataFromMemory converts a Memory model to CreateMemoryResponse format
// This is a utility function for formatting memory data for API responses
//
// Parameters:
// - memory: The Memory model to format
//
// Returns:
// - *CreateMemoryResponse: Formatted memory response with human-readable types and dates
//
// Example:
//
// resp := formatRetDataFromMemory(memoryModel)
func formatRetDataFromMemory(memory *model.Memory) *CreateMemoryResponse {
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
resp := &CreateMemoryResponse{
Memory: *memory,
OwnerName: nil,
MemoryType: memoryTypes,
}
return resp
}
func formatDateToString(t int64) *string {
if t == 0 {
return nil
}
// Database stores timestamps in milliseconds, convert to seconds
if t > 1e10 {
t = t / 1000
}
timeObj := time.Unix(t, 0)
s := timeObj.Format("2006-01-02 15:04:05")
return &s
}
// formatRetDataFromMemoryListItem converts a MemoryListItem to CreateMemoryResponse
// This function is used for both list and detail memory responses where owner_name is from JOIN query
//
// Parameters:
// - memory: MemoryListItem pointer with owner_name from JOIN
//
// Returns:
// - *CreateMemoryResponse: Formatted response with owner_name populated
//
// Example:
//
// resp := formatRetDataFromMemoryListItem(memoryItem)
func formatRetDataFromMemoryListItem(memory *model.MemoryListItem) *CreateMemoryResponse {
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
resp := &CreateMemoryResponse{
Memory: memory.Memory,
OwnerName: memory.OwnerName,
MemoryType: memoryTypes,
}
return resp
}

View File

@ -17,12 +17,14 @@
package service
import (
"strings"
"context"
"fmt"
"time"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/model"
"ragflow/internal/engine"
)
@ -92,6 +94,136 @@ type TenantListItem struct {
DeltaSeconds float64 `json:"delta_seconds"`
}
// TenantLLMService tenant LLM service
// This service handles operations related to tenant-specific LLM configurations
type TenantLLMService struct {
tenantLLMDAO *dao.TenantLLMDAO
}
// NewTenantLLMService creates a new TenantLLMService instance
func NewTenantLLMService() *TenantLLMService {
return &TenantLLMService{
tenantLLMDAO: dao.NewTenantLLMDAO(),
}
}
// GetAPIKey retrieves the tenant LLM record by tenant ID and model name
/**
* This method splits the model name into name and factory parts using the "@" separator,
* then queries the database for the matching tenant LLM configuration.
*
* Parameters:
* - tenantID: the unique identifier of the tenant
* - modelName: the model name, optionally including factory suffix (e.g., "gpt-4@OpenAI")
*
* Returns:
* - *model.TenantLLM: the tenant LLM record if found, nil otherwise
* - error: an error if the query fails, nil otherwise
*
* Example:
*
* service := NewTenantLLMService()
*
* // Get API key for model with factory
* tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4@OpenAI")
* if err != nil {
* log.Printf("Error: %v", err)
* }
*
* // Get API key for model without factory
* tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4")
*/
func (s *TenantLLMService) GetAPIKey(tenantID, modelName string) (*model.TenantLLM, error) {
modelName, factory := s.SplitModelNameAndFactory(modelName)
var tenantLLM *model.TenantLLM
var err error
if factory == "" {
tenantLLM, err = s.tenantLLMDAO.GetByTenantIDAndLLMName(tenantID, modelName)
} else {
tenantLLM, err = s.tenantLLMDAO.GetByTenantIDLLMNameAndFactory(tenantID, modelName, factory)
}
if err != nil {
return nil, err
}
return tenantLLM, nil
}
// SplitModelNameAndFactory splits a model name into name and factory parts
func (s *TenantLLMService) SplitModelNameAndFactory(modelName string) (string, string) {
arr := strings.Split(modelName, "@")
if len(arr) < 2 {
return modelName, ""
}
if len(arr) > 2 {
return strings.Join(arr[0:len(arr)-1], "@"), arr[len(arr)-1]
}
return arr[0], arr[1]
}
// EnsureTenantModelIDForParams ensures tenant model IDs are populated for LLM-related parameters
/**
* This method iterates through a predefined list of LLM-related parameter keys (llm_id, embd_id,
* asr_id, img2txt_id, rerank_id, tts_id) and automatically populates the corresponding tenant_*
* fields (tenant_llm_id, tenant_embd_id, etc.) with the tenant LLM record IDs.
*
* If a parameter key exists and its corresponding tenant_* key doesn't exist, this method will:
* 1. Query the tenant LLM record using GetAPIKey
* 2. If found, set the tenant_* key to the record's ID
* 3. If not found, set the tenant_* key to 0
*
* Parameters:
* - tenantID: the unique identifier of the tenant
* - params: a map of parameters to be updated (will be modified in place)
*
* Returns:
* - map[string]interface{}: the updated parameters map (same as input, modified in place)
*
* Example:
*
* service := NewTenantLLMService()
* params := map[string]interface{}{
* "llm_id": "gpt-4@OpenAI",
* "embd_id": "text-embedding-3-small@OpenAI",
* }
* result := service.EnsureTenantModelIDForParams("tenant-123", params)
* // result will contain:
* // {
* // "llm_id": "gpt-4@OpenAI",
* // "embd_id": "text-embedding-3-small@OpenAI",
* // "tenant_llm_id": 123, // ID from tenant_llm table
* // "tenant_embd_id": 456, // ID from tenant_llm table
* // }
*/
func (s *TenantLLMService) EnsureTenantModelIDForParams(tenantID string, params map[string]interface{}) map[string]interface{} {
paramKeys := []string{"llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"}
for _, key := range paramKeys {
tenantKey := "tenant_" + key
if value, exists := params[key]; exists && value != nil && value != "" {
if _, tenantExists := params[tenantKey]; !tenantExists {
modelName, ok := value.(string)
if !ok || modelName == "" {
continue
}
tenantLLM, err := s.GetAPIKey(tenantID, modelName)
if err == nil && tenantLLM != nil {
params[tenantKey] = tenantLLM.ID
} else {
params[tenantKey] = int64(0)
}
}
}
}
return params
}
// GetTenantList get tenant list for a user
func (s *TenantService) GetTenantList(userID string) ([]*TenantListItem, error) {
tenants, err := s.userTenantDAO.GetTenantsByUserID(userID)

View File

@ -924,6 +924,95 @@ func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) er
return nil
}
// UserTenantService user tenant service
// Provides business logic for user-tenant relationship management
type UserTenantService struct {
userTenantDAO *dao.UserTenantDAO
}
// NewUserTenantService creates a new UserTenantService instance
/**
* Returns:
* - *UserTenantService: a new UserTenantService instance
*
* Example:
*
* service := NewUserTenantService()
* relations, err := service.GetUserTenantRelationByUserID("user123")
*/
func NewUserTenantService() *UserTenantService {
return &UserTenantService{
userTenantDAO: dao.NewUserTenantDAO(),
}
}
// UserTenantRelation represents a user-tenant relationship response
// This structure matches the Python implementation's return format
type UserTenantRelation struct {
ID string `json:"id"`
UserID string `json:"user_id"`
TenantID string `json:"tenant_id"`
Role string `json:"role"`
}
// GetUserTenantRelationByUserID retrieves all user-tenant relationships for a given user ID
/**
* This method returns a list of user-tenant relationships with selected fields:
* - id: the relationship ID
* - user_id: the user ID
* - tenant_id: the tenant ID
* - role: the user's role in the tenant
*
* Parameters:
* - userID: the unique identifier of the user
*
* Returns:
* - []*UserTenantRelation: list of user-tenant relationships
* - error: error if the operation fails, nil otherwise
*
* Example:
*
* service := NewUserTenantService()
* relations, err := service.GetUserTenantRelationByUserID("user123")
* if err != nil {
* log.Printf("Failed to get user tenant relations: %v", err)
* return
* }
* for _, rel := range relations {
* fmt.Printf("User %s has role %s in tenant %s\n", rel.UserID, rel.Role, rel.TenantID)
* }
*/
func (s *UserTenantService) GetUserTenantRelationByUserID(userID string) ([]*UserTenantRelation, error) {
relations, err := s.userTenantDAO.GetByUserID(userID)
if err != nil {
return nil, err
}
result := make([]*UserTenantRelation, len(relations))
for i, rel := range relations {
result[i] = convertToUserTenantRelation(rel)
}
return result, nil
}
// convertToUserTenantRelation converts model.UserTenant to UserTenantRelation
/**
* Parameters:
* - userTenant: the model.UserTenant to convert
*
* Returns:
* - *UserTenantRelation: the converted UserTenantRelation
*/
func convertToUserTenantRelation(userTenant *model.UserTenant) *UserTenantRelation {
return &UserTenantRelation{
ID: userTenant.ID,
UserID: userTenant.UserID,
TenantID: userTenant.TenantID,
Role: userTenant.Role,
}
}
// GetUserByAPIToken gets user by access key from Authorization header
// This is used for API token authentication
// The authorization parameter should be in format: "Bearer <token>" or just "<token>"
@ -963,4 +1052,5 @@ func (s *UserService) GetUserByAPIToken(authorization string) (*model.User, comm
}
return user, common.CodeSuccess, nil
}

View File

@ -1 +1,2 @@
VITE_BASE_URL='/'
VITE_BASE_URL='/'
API_PROXY_SCHEME='python'

View File

@ -49,66 +49,18 @@ export default defineConfig(({ mode }) => {
},
},
hybrid: {
'/v1/system/token_list': {
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/v1/system/new_token': {
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/v1/system/token': {
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/v1/system/config': {
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/v1/user/login': {
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/v1/user/logout': {
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/api/v1/admin/sandbox': {
target: 'http://127.0.0.1:9381/',
changeOrigin: true,
ws: true,
},
'/api/v1/admin/roles': {
target: 'http://127.0.0.1:9381/',
changeOrigin: true,
ws: true,
},
'/api/v1/admin/roles/owner/permission': {
target: 'http://127.0.0.1:9381/',
changeOrigin: true,
ws: true,
},
'/api/v1/admin/roles_with_permission': {
target: 'http://127.0.0.1:9381/',
changeOrigin: true,
ws: true,
},
'/api/v1/admin/whitelist': {
target: 'http://127.0.0.1:9381/',
changeOrigin: true,
ws: true,
},
'/api/v1/admin/variables': {
target: 'http://127.0.0.1:9381/',
changeOrigin: true,
ws: true,
},
'^(/api/v1/memories)|^(/v1/user/info)|^(/v1/user/tenant_info)|^(/v1/tenant/list)|^(/v1/system/config)|^(/v1/user/login)|^(/v1/user/logout)':
{
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'^(/api/v1/admin/sandbox)|^(/api/v1/admin/roles)|^(/api/v1/admin/roles/owner/permission)|^(/api/v1/admin/roles_with_permission)|^(/api/v1/admin/whitelist)|^(/api/v1/admin/variables)':
{
target: 'http://127.0.0.1:9381/',
changeOrigin: true,
ws: true,
},
'/api/v1/admin': {
target: 'http://127.0.0.1:9383/',
changeOrigin: true,