mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-06 02:07:49 +08:00
Feat: add memory function by go (#13754)
### What problem does this PR solve? Feat: Add Memory function by go ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
This commit is contained in:
@ -175,6 +175,7 @@ func startServer(config *server.Config) {
|
|||||||
connectorService := service.NewConnectorService()
|
connectorService := service.NewConnectorService()
|
||||||
searchService := service.NewSearchService()
|
searchService := service.NewSearchService()
|
||||||
fileService := service.NewFileService()
|
fileService := service.NewFileService()
|
||||||
|
memoryService := service.NewMemoryService()
|
||||||
|
|
||||||
// Initialize handler layer
|
// Initialize handler layer
|
||||||
authHandler := handler.NewAuthHandler()
|
authHandler := handler.NewAuthHandler()
|
||||||
@ -191,9 +192,10 @@ func startServer(config *server.Config) {
|
|||||||
connectorHandler := handler.NewConnectorHandler(connectorService, userService)
|
connectorHandler := handler.NewConnectorHandler(connectorService, userService)
|
||||||
searchHandler := handler.NewSearchHandler(searchService, userService)
|
searchHandler := handler.NewSearchHandler(searchService, userService)
|
||||||
fileHandler := handler.NewFileHandler(fileService, userService)
|
fileHandler := handler.NewFileHandler(fileService, userService)
|
||||||
|
memoryHandler := handler.NewMemoryHandler(memoryService)
|
||||||
|
|
||||||
// Initialize router
|
// Initialize router
|
||||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler)
|
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler)
|
||||||
|
|
||||||
// Create Gin engine
|
// Create Gin engine
|
||||||
ginEngine := gin.New()
|
ginEngine := gin.New()
|
||||||
|
|||||||
370
internal/dao/memory.go
Normal file
370
internal/dao/memory.go
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Package dao implements the data access layer
|
||||||
|
// This file implements Memory-related database operations
|
||||||
|
// Consistent with Python memory_service.py
|
||||||
|
package dao
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ragflow/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Memory type bit flag constants, consistent with Python MemoryType enum
|
||||||
|
const (
|
||||||
|
MemoryTypeRaw = 0b0001 // Raw memory (binary: 0001)
|
||||||
|
MemoryTypeSemantic = 0b0010 // Semantic memory (binary: 0010)
|
||||||
|
MemoryTypeEpisodic = 0b0100 // Episodic memory (binary: 0100)
|
||||||
|
MemoryTypeProcedural = 0b1000 // Procedural memory (binary: 1000)
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryTypeMap maps memory type names to bit flags
|
||||||
|
// Exported for use by service package
|
||||||
|
var MemoryTypeMap = map[string]int{
|
||||||
|
"raw": MemoryTypeRaw,
|
||||||
|
"semantic": MemoryTypeSemantic,
|
||||||
|
"episodic": MemoryTypeEpisodic,
|
||||||
|
"procedural": MemoryTypeProcedural,
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateMemoryType converts memory type names array to bit flags integer
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memoryTypeNames: Memory type names array
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - int64: Bit flags integer
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// CalculateMemoryType([]string{"raw", "semantic"}) returns 3 (0b0011)
|
||||||
|
func CalculateMemoryType(memoryTypeNames []string) int64 {
|
||||||
|
memoryType := 0
|
||||||
|
for _, name := range memoryTypeNames {
|
||||||
|
lowerName := strings.ToLower(name)
|
||||||
|
if mt, ok := MemoryTypeMap[lowerName]; ok {
|
||||||
|
memoryType |= mt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int64(memoryType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMemoryTypeHuman converts memory type bit flags to human-readable names
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memoryType: Bit flags integer representing memory types
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []string: Array of human-readable memory type names
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// GetMemoryTypeHuman(3) returns ["raw", "semantic"]
|
||||||
|
func GetMemoryTypeHuman(memoryType int64) []string {
|
||||||
|
var result []string
|
||||||
|
if memoryType&int64(MemoryTypeRaw) != 0 {
|
||||||
|
result = append(result, "raw")
|
||||||
|
}
|
||||||
|
if memoryType&int64(MemoryTypeSemantic) != 0 {
|
||||||
|
result = append(result, "semantic")
|
||||||
|
}
|
||||||
|
if memoryType&int64(MemoryTypeEpisodic) != 0 {
|
||||||
|
result = append(result, "episodic")
|
||||||
|
}
|
||||||
|
if memoryType&int64(MemoryTypeProcedural) != 0 {
|
||||||
|
result = append(result, "procedural")
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryDAO handles all Memory-related database operations
|
||||||
|
type MemoryDAO struct{}
|
||||||
|
|
||||||
|
// NewMemoryDAO creates a new MemoryDAO instance
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *MemoryDAO: Initialized DAO instance
|
||||||
|
func NewMemoryDAO() *MemoryDAO {
|
||||||
|
return &MemoryDAO{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create inserts a new memory record into the database
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memory: Memory model pointer
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Database operation error
|
||||||
|
func (dao *MemoryDAO) Create(memory *model.Memory) error {
|
||||||
|
return DB.Create(memory).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID retrieves a memory record by ID from database
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - id: Memory ID
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *model.Memory: Memory model pointer
|
||||||
|
// - error: Database operation error
|
||||||
|
func (dao *MemoryDAO) GetByID(id string) (*model.Memory, error) {
|
||||||
|
var memory model.Memory
|
||||||
|
err := DB.Where("id = ?", id).First(&memory).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &memory, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByTenantID retrieves all memories for a tenant
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tenantID: Tenant ID
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*model.Memory: Memory model pointer array
|
||||||
|
// - error: Database operation error
|
||||||
|
func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*model.Memory, error) {
|
||||||
|
var memories []*model.Memory
|
||||||
|
err := DB.Where("tenant_id = ?", tenantID).Find(&memories).Error
|
||||||
|
return memories, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByNameAndTenant checks if memory exists by name and tenant ID
|
||||||
|
// Used for duplicate name deduplication
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - name: Memory name
|
||||||
|
// - tenantID: Tenant ID
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*model.Memory: Matching memory list (for existence check)
|
||||||
|
// - error: Database operation error
|
||||||
|
func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*model.Memory, error) {
|
||||||
|
var memories []*model.Memory
|
||||||
|
err := DB.Where("name = ? AND tenant_id = ?", name, tenantID).Find(&memories).Error
|
||||||
|
return memories, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByIDs retrieves memories by multiple IDs
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ids: Memory ID list
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*model.Memory: Memory model pointer array
|
||||||
|
// - error: Database operation error
|
||||||
|
func (dao *MemoryDAO) GetByIDs(ids []string) ([]*model.Memory, error) {
|
||||||
|
var memories []*model.Memory
|
||||||
|
err := DB.Where("id IN ?", ids).Find(&memories).Error
|
||||||
|
return memories, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateByID updates a memory by ID
|
||||||
|
// Supports partial updates - only updates passed fields
|
||||||
|
// Automatically handles field type conversions
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - id: Memory ID
|
||||||
|
// - updates: Fields to update map
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Database operation error
|
||||||
|
//
|
||||||
|
// Field type handling:
|
||||||
|
// - memory_type: []string converts to bit flags integer
|
||||||
|
// - temperature: string converts to float64
|
||||||
|
// - name: Uses string value directly
|
||||||
|
// - permissions, forgetting_policy: Uses string value directly
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// updates := map[string]interface{}{"name": "NewName", "memory_type": []string{"semantic"}}
|
||||||
|
// err := dao.UpdateByID("memory123", updates)
|
||||||
|
func (dao *MemoryDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||||
|
if updates == nil || len(updates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range updates {
|
||||||
|
switch key {
|
||||||
|
case "memory_type":
|
||||||
|
if types, ok := value.([]string); ok {
|
||||||
|
updates[key] = CalculateMemoryType(types)
|
||||||
|
}
|
||||||
|
case "temperature":
|
||||||
|
if tempStr, ok := value.(string); ok {
|
||||||
|
var temp float64
|
||||||
|
fmt.Sscanf(tempStr, "%f", &temp)
|
||||||
|
updates[key] = temp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return DB.Model(&model.Memory{}).Where("id = ?", id).Updates(updates).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByID deletes a memory by ID
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - id: Memory ID
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Database operation error
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// err := dao.DeleteByID("memory123")
|
||||||
|
func (dao *MemoryDAO) DeleteByID(id string) error {
|
||||||
|
return DB.Where("id = ?", id).Delete(&model.Memory{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWithOwnerNameByID retrieves a memory with owner name by ID
|
||||||
|
// Joins with User table to get owner's nickname
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - id: Memory ID
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *model.MemoryListItem: Memory detail with owner name populated
|
||||||
|
// - error: Database operation error
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// memory, err := dao.GetWithOwnerNameByID("memory123")
|
||||||
|
func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, error) {
|
||||||
|
querySQL := `
|
||||||
|
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
|
||||||
|
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
|
||||||
|
m.permissions, m.description, m.memory_size, m.forgetting_policy,
|
||||||
|
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
|
||||||
|
m.update_time, m.update_date,
|
||||||
|
u.nickname as owner_name
|
||||||
|
FROM memory m
|
||||||
|
LEFT JOIN user u ON m.tenant_id = u.id
|
||||||
|
WHERE m.id = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
var rawResult struct {
|
||||||
|
model.Memory
|
||||||
|
OwnerName *string `gorm:"column:owner_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Raw(querySQL, id).Scan(&rawResult).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.MemoryListItem{
|
||||||
|
Memory: rawResult.Memory,
|
||||||
|
OwnerName: rawResult.OwnerName,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByFilter retrieves memories with optional filters
|
||||||
|
// Supports filtering by tenant_id, memory_type, storage_type, and keywords
|
||||||
|
// Returns paginated results with owner_name from user table JOIN
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tenantIDs: Array of tenant IDs to filter by (empty means all tenants)
|
||||||
|
// - memoryTypes: Array of memory type names to filter by (empty means all types)
|
||||||
|
// - storageType: Storage type to filter by (empty means all types)
|
||||||
|
// - keywords: Keywords to search in memory names (empty means no keyword filter)
|
||||||
|
// - page: Page number (1-based)
|
||||||
|
// - pageSize: Number of items per page
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*model.MemoryListItem: Memory list items with owner name populated
|
||||||
|
// - int64: Total count of matching memories
|
||||||
|
// - error: Database operation error
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// memories, total, err := dao.GetByFilter([]string{"tenant1"}, []string{"semantic"}, "table", "test", 1, 10)
|
||||||
|
func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*model.MemoryListItem, int64, error) {
|
||||||
|
var conditions []string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
if len(tenantIDs) > 0 {
|
||||||
|
conditions = append(conditions, "m.tenant_id IN ?")
|
||||||
|
args = append(args, tenantIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(memoryTypes) > 0 {
|
||||||
|
memoryTypeInt := CalculateMemoryType(memoryTypes)
|
||||||
|
conditions = append(conditions, "m.memory_type & ? > 0")
|
||||||
|
args = append(args, memoryTypeInt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if storageType != "" {
|
||||||
|
conditions = append(conditions, "m.storage_type = ?")
|
||||||
|
args = append(args, storageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keywords != "" {
|
||||||
|
conditions = append(conditions, "m.name LIKE ?")
|
||||||
|
args = append(args, "%"+keywords+"%")
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause := ""
|
||||||
|
if len(conditions) > 0 {
|
||||||
|
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
countSQL := fmt.Sprintf("SELECT COUNT(*) FROM memory m %s", whereClause)
|
||||||
|
var total int64
|
||||||
|
if err := DB.Raw(countSQL, args...).Scan(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
querySQL := fmt.Sprintf(`
|
||||||
|
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
|
||||||
|
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
|
||||||
|
m.permissions, m.description, m.memory_size, m.forgetting_policy,
|
||||||
|
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
|
||||||
|
m.update_time, m.update_date,
|
||||||
|
u.nickname as owner_name
|
||||||
|
FROM memory m
|
||||||
|
LEFT JOIN user u ON m.tenant_id = u.id
|
||||||
|
%s
|
||||||
|
ORDER BY m.update_time DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
`, whereClause)
|
||||||
|
|
||||||
|
queryArgs := append(args, pageSize, offset)
|
||||||
|
|
||||||
|
var rawResults []struct {
|
||||||
|
model.Memory
|
||||||
|
OwnerName *string `gorm:"column:owner_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Raw(querySQL, queryArgs...).Scan(&rawResults).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
memories := make([]*model.MemoryListItem, len(rawResults))
|
||||||
|
for i, r := range rawResults {
|
||||||
|
memories[i] = &model.MemoryListItem{
|
||||||
|
Memory: r.Memory,
|
||||||
|
OwnerName: r.OwnerName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return memories, total, nil
|
||||||
|
}
|
||||||
@ -141,3 +141,130 @@ func (dao *TenantLLMDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
|||||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.TenantLLM{})
|
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.TenantLLM{})
|
||||||
return result.RowsAffected, result.Error
|
return result.RowsAffected, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// splitModelNameAndFactory splits model name and factory from combined format
|
||||||
|
// This matches Python's split_model_name_and_factory logic
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - modelName: The model name which can be in format "ModelName" or "ModelName@Factory"
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: The model name without factory prefix
|
||||||
|
// - string: The factory name (empty string if not specified)
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// modelName, factory := splitModelNameAndFactory("gpt-4")
|
||||||
|
// // Returns: "gpt-4", ""
|
||||||
|
//
|
||||||
|
// modelName, factory := splitModelNameAndFactory("gpt-4@OpenAI")
|
||||||
|
// // Returns: "gpt-4", "OpenAI"
|
||||||
|
func splitModelNameAndFactory(modelName string) (string, string) {
|
||||||
|
// Split by "@" separator
|
||||||
|
// Handle cases like "model@factory" or "model@sub@factory"
|
||||||
|
lastAtIndex := -1
|
||||||
|
for i := len(modelName) - 1; i >= 0; i-- {
|
||||||
|
if modelName[i] == '@' {
|
||||||
|
lastAtIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No "@" found, return original name
|
||||||
|
if lastAtIndex == -1 {
|
||||||
|
return modelName, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split into model name and potential factory
|
||||||
|
modelNamePart := modelName[:lastAtIndex]
|
||||||
|
factory := modelName[lastAtIndex+1:]
|
||||||
|
|
||||||
|
// Validate if factory exists in llm_factories table
|
||||||
|
// This matches Python's logic of checking against model providers
|
||||||
|
var factoryCount int64
|
||||||
|
DB.Model(&model.LLMFactories{}).Where("name = ?", factory).Count(&factoryCount)
|
||||||
|
|
||||||
|
// If factory doesn't exist in database, treat the whole string as model name
|
||||||
|
if factoryCount == 0 {
|
||||||
|
return modelName, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelNamePart, factory
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByTenantIDAndLLMName gets tenant LLM by tenant ID and LLM name
|
||||||
|
// This is used to resolve tenant_llm_id from llm_id
|
||||||
|
// It supports both simple model names and factory-prefixed names (e.g., "gpt-4@OpenAI")
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tenantID: The tenant identifier
|
||||||
|
// - llmName: The LLM model name (can include factory prefix like "OpenAI@gpt-4")
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *model.TenantLLM: The tenant LLM record
|
||||||
|
// - error: Error if not found
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// // Simple model name
|
||||||
|
// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4")
|
||||||
|
//
|
||||||
|
// // Model name with factory prefix
|
||||||
|
// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4@OpenAI")
|
||||||
|
func (dao *TenantLLMDAO) GetByTenantIDAndLLMName(tenantID string, llmName string) (*model.TenantLLM, error) {
|
||||||
|
var tenantLLM model.TenantLLM
|
||||||
|
|
||||||
|
// Split model name and factory from the combined format
|
||||||
|
modelName, factory := splitModelNameAndFactory(llmName)
|
||||||
|
|
||||||
|
// First attempt: try to find with model name only
|
||||||
|
err := DB.Where("tenant_id = ? AND llm_name = ?", tenantID, modelName).First(&tenantLLM).Error
|
||||||
|
if err == nil {
|
||||||
|
return &tenantLLM, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second attempt: if factory is specified, try with both model name and factory
|
||||||
|
if factory != "" {
|
||||||
|
err = DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, modelName, factory).First(&tenantLLM).Error
|
||||||
|
if err == nil {
|
||||||
|
return &tenantLLM, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special handling for LocalAI and HuggingFace (matching Python logic)
|
||||||
|
// These factories append "___FactoryName" to the model name
|
||||||
|
if factory == "LocalAI" || factory == "HuggingFace" || factory == "OpenAI-API-Compatible" {
|
||||||
|
specialModelName := modelName + "___" + factory
|
||||||
|
err = DB.Where("tenant_id = ? AND llm_name = ?", tenantID, specialModelName).First(&tenantLLM).Error
|
||||||
|
if err == nil {
|
||||||
|
return &tenantLLM, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the last error (record not found)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByTenantIDLLMNameAndFactory gets tenant LLM by tenant ID, LLM name and factory
|
||||||
|
// This is used when model name includes factory suffix (e.g., "model@factory")
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tenantID: The tenant identifier
|
||||||
|
// - llmName: The LLM model name
|
||||||
|
// - factory: The LLM factory name
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *model.TenantLLM: The tenant LLM record
|
||||||
|
// - error: Error if not found
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tenantLLM, err := dao.GetByTenantIDLLMNameAndFactory("tenant123", "gpt-4", "OpenAI")
|
||||||
|
func (dao *TenantLLMDAO) GetByTenantIDLLMNameAndFactory(tenantID, llmName, factory string) (*model.TenantLLM, error) {
|
||||||
|
var tenantLLM model.TenantLLM
|
||||||
|
err := DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, llmName, factory).First(&tenantLLM).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &tenantLLM, nil
|
||||||
|
}
|
||||||
|
|||||||
687
internal/handler/memory.go
Normal file
687
internal/handler/memory.go
Normal file
@ -0,0 +1,687 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Package handler contains all HTTP request handlers
|
||||||
|
// This file implements Memory-related API endpoint handlers
|
||||||
|
// Each method corresponds to an API endpoint in the Python memory_api.py
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"ragflow/internal/common"
|
||||||
|
"ragflow/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryHandler handles Memory-related HTTP requests
|
||||||
|
// Responsible for processing all Memory-related HTTP requests
|
||||||
|
// Each method corresponds to an API endpoint, implementing the same logic as Python memory_api.py
|
||||||
|
type MemoryHandler struct {
|
||||||
|
memoryService *service.MemoryService // Reference to Memory business service layer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryHandler creates a new MemoryHandler instance
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memoryService: Pointer to MemoryService business service layer
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *MemoryHandler: Initialized handler instance
|
||||||
|
func NewMemoryHandler(memoryService *service.MemoryService) *MemoryHandler {
|
||||||
|
return &MemoryHandler{
|
||||||
|
memoryService: memoryService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMemory handles POST request for creating Memory
|
||||||
|
// API Path: POST /api/v1/memories
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Creates a new memory record
|
||||||
|
// - Supports automatic system_prompt generation
|
||||||
|
// - Supports name deduplication (if name exists, adds sequence number)
|
||||||
|
//
|
||||||
|
// Request Parameters (JSON Body):
|
||||||
|
// - name (required): Memory name, max 128 characters
|
||||||
|
// - memory_type (required): Memory type array, supports ["raw", "semantic", "episodic", "procedural"]
|
||||||
|
// - embd_id (required): Embedding model ID
|
||||||
|
// - llm_id (required): LLM model ID
|
||||||
|
// - tenant_embd_id (optional): Tenant embedding model ID
|
||||||
|
// - tenant_llm_id (optional): Tenant LLM model ID
|
||||||
|
//
|
||||||
|
// Response Format:
|
||||||
|
// - code: Status code (0=success, other=error)
|
||||||
|
// - message: true on success, error message on failure
|
||||||
|
// - data: Memory object on success
|
||||||
|
//
|
||||||
|
// Business Logic (matching Python create_memory):
|
||||||
|
// 1. Validate user login status
|
||||||
|
// 2. Parse and validate request parameters
|
||||||
|
// 3. Call service layer to create memory
|
||||||
|
// 4. Return creation result
|
||||||
|
func (h *MemoryHandler) CreateMemory(c *gin.Context) {
|
||||||
|
// Check if API timing is enabled
|
||||||
|
// If RAGFLOW_API_TIMING environment variable is set, request processing time will be logged
|
||||||
|
timingEnabled := os.Getenv("RAGFLOW_API_TIMING")
|
||||||
|
var tStart time.Time
|
||||||
|
if timingEnabled != "" {
|
||||||
|
tStart = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current logged-in user information
|
||||||
|
// GetUser is a context value set by the authentication middleware
|
||||||
|
user, errorCode, errorMessage := GetUser(c)
|
||||||
|
if errorCode != common.CodeSuccess {
|
||||||
|
jsonError(c, errorCode, errorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := user.ID
|
||||||
|
|
||||||
|
// Parse JSON request body
|
||||||
|
var req service.CreateMemoryRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeBadRequest,
|
||||||
|
"message": err.Error(),
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required field: name
|
||||||
|
if req.Name == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": "name is required",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required field: memory_type (must be non-empty array)
|
||||||
|
if len(req.MemoryType) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": "memory_type is required and must be a list",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required field: embd_id
|
||||||
|
if req.EmbdID == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": "embd_id is required",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required field: llm_id
|
||||||
|
if req.LLMID == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": "llm_id is required",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record request parsing completion time (for timing)
|
||||||
|
tParsed := time.Now()
|
||||||
|
|
||||||
|
// Call service layer to create memory
|
||||||
|
result, err := h.memoryService.CreateMemory(userID, &req)
|
||||||
|
if err != nil {
|
||||||
|
// Log error if timing is enabled
|
||||||
|
if timingEnabled != "" {
|
||||||
|
totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0
|
||||||
|
parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0
|
||||||
|
_ = parseMs
|
||||||
|
_ = totalMs
|
||||||
|
}
|
||||||
|
|
||||||
|
errMsg := err.Error()
|
||||||
|
// Determine if it's an argument error and return appropriate error code
|
||||||
|
if isArgumentError(errMsg) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other errors return server error
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log success if timing is enabled
|
||||||
|
if timingEnabled != "" {
|
||||||
|
totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0
|
||||||
|
parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0
|
||||||
|
validateAndDbMs := totalMs - parseMs
|
||||||
|
_ = parseMs
|
||||||
|
_ = validateAndDbMs
|
||||||
|
_ = totalMs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success response
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeSuccess,
|
||||||
|
"message": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMemory handles PUT request for updating Memory
|
||||||
|
// API Path: PUT /api/v1/memories/:memory_id
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Updates configuration information for the specified memory
|
||||||
|
// - Supports partial updates: only update passed fields
|
||||||
|
//
|
||||||
|
// Request Parameters (JSON Body):
|
||||||
|
// - name (optional): Memory name
|
||||||
|
// - permissions (optional): Permission setting ["me", "team", "all"]
|
||||||
|
// - llm_id (optional): LLM model ID
|
||||||
|
// - embd_id (optional): Embedding model ID
|
||||||
|
// - tenant_llm_id (optional): Tenant LLM model ID
|
||||||
|
// - tenant_embd_id (optional): Tenant embedding model ID
|
||||||
|
// - memory_type (optional): Memory type array
|
||||||
|
// - memory_size (optional): Memory size, range (0, 5242880]
|
||||||
|
// - forgetting_policy (optional): Forgetting policy, default "FIFO"
|
||||||
|
// - temperature (optional): Temperature parameter, range [0, 1]
|
||||||
|
// - avatar (optional): Avatar URL
|
||||||
|
// - description (optional): Description
|
||||||
|
// - system_prompt (optional): System prompt
|
||||||
|
// - user_prompt (optional): User prompt
|
||||||
|
//
|
||||||
|
// Business Rules:
|
||||||
|
// - name length <= 128 characters
|
||||||
|
// - Cannot update tenant_embd_id, embd_id, memory_type when memory_size > 0
|
||||||
|
// - When updating memory_type, system_prompt is automatically regenerated if it's the default
|
||||||
|
func (h *MemoryHandler) UpdateMemory(c *gin.Context) {
|
||||||
|
// Get memory_id from URL path
|
||||||
|
memoryID := c.Param("memory_id")
|
||||||
|
if memoryID == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": "memory_id is required",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current logged-in user information
|
||||||
|
user, errorCode, errorMessage := GetUser(c)
|
||||||
|
if errorCode != common.CodeSuccess {
|
||||||
|
jsonError(c, errorCode, errorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := user.ID
|
||||||
|
|
||||||
|
// Parse JSON request body
|
||||||
|
var req service.UpdateMemoryRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeBadRequest,
|
||||||
|
"message": err.Error(),
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call service layer to update memory
|
||||||
|
result, err := h.memoryService.UpdateMemory(userID, memoryID, &req)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := err.Error()
|
||||||
|
// Check if it's a "not found" error
|
||||||
|
if strings.Contains(errMsg, "not found") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeNotFound,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's an argument error
|
||||||
|
if isArgumentError(errMsg) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other errors return server error
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success response
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeSuccess,
|
||||||
|
"message": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMemory handles DELETE request for deleting Memory
|
||||||
|
// API Path: DELETE /api/v1/memories/:memory_id
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Deletes the specified memory record
|
||||||
|
// - Also deletes associated message data
|
||||||
|
//
|
||||||
|
// Business Logic:
|
||||||
|
// 1. Check if memory exists
|
||||||
|
// 2. Delete memory record
|
||||||
|
// 3. Delete associated message index
|
||||||
|
func (h *MemoryHandler) DeleteMemory(c *gin.Context) {
|
||||||
|
// Get memory_id from URL path
|
||||||
|
memoryID := c.Param("memory_id")
|
||||||
|
if memoryID == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": "memory_id is required",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call service layer to delete memory
|
||||||
|
err := h.memoryService.DeleteMemory(memoryID)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := err.Error()
|
||||||
|
// Check if it's a "not found" error
|
||||||
|
if strings.Contains(errMsg, "not found") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeNotFound,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other errors return server error
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success response
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeSuccess,
|
||||||
|
"message": true,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMemories handles GET request for listing Memories
|
||||||
|
// API Path: GET /api/v1/memories
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Lists memories accessible to the current user
|
||||||
|
// - Supports multiple filter conditions
|
||||||
|
// - Supports pagination and keyword search
|
||||||
|
//
|
||||||
|
// Query Parameters:
|
||||||
|
// - memory_type (optional): Memory type filter, supports comma-separated multiple types
|
||||||
|
// - tenant_id (optional): Tenant ID filter
|
||||||
|
// - storage_type (optional): Storage type filter
|
||||||
|
// - keywords (optional): Keyword search (fuzzy match on name)
|
||||||
|
// - page (optional): Page number, default 1
|
||||||
|
// - page_size (optional): Items per page, default 50
|
||||||
|
//
|
||||||
|
// Response Format:
|
||||||
|
// - code: Status code
|
||||||
|
// - message: true
|
||||||
|
// - data.memory_list: Array of Memory objects
|
||||||
|
// - data.total_count: Total record count
|
||||||
|
func (h *MemoryHandler) ListMemories(c *gin.Context) {
|
||||||
|
// Get current logged-in user information
|
||||||
|
user, errorCode, errorMessage := GetUser(c)
|
||||||
|
if errorCode != common.CodeSuccess {
|
||||||
|
jsonError(c, errorCode, errorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse query parameters
|
||||||
|
memoryTypesParam := c.Query("memory_type")
|
||||||
|
tenantIDsParam := c.Query("tenant_id")
|
||||||
|
storageType := c.Query("storage_type")
|
||||||
|
keywords := c.Query("keywords")
|
||||||
|
pageStr := c.DefaultQuery("page", "1")
|
||||||
|
pageSizeStr := c.DefaultQuery("page_size", "50")
|
||||||
|
|
||||||
|
// Convert pagination parameters to integers
|
||||||
|
page, _ := strconv.Atoi(pageStr)
|
||||||
|
pageSize, _ := strconv.Atoi(pageSizeStr)
|
||||||
|
|
||||||
|
// Validate pagination parameters
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize < 1 {
|
||||||
|
pageSize = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse memory_type parameter (supports comma separation)
|
||||||
|
var memoryTypes []string
|
||||||
|
if memoryTypesParam != "" {
|
||||||
|
if strings.Contains(memoryTypesParam, ",") {
|
||||||
|
memoryTypes = strings.Split(memoryTypesParam, ",")
|
||||||
|
} else {
|
||||||
|
memoryTypes = []string{memoryTypesParam}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse tenant_id parameter
|
||||||
|
// If not specified, service will get all tenants associated with the user
|
||||||
|
var tenantIDs []string
|
||||||
|
if tenantIDsParam != "" {
|
||||||
|
if strings.Contains(tenantIDsParam, ",") {
|
||||||
|
tenantIDs = strings.Split(tenantIDsParam, ",")
|
||||||
|
} else {
|
||||||
|
tenantIDs = []string{tenantIDsParam}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call service layer to get memory list
|
||||||
|
result, err := h.memoryService.ListMemories(user.ID, tenantIDs, memoryTypes, storageType, keywords, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": err.Error(),
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success response
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeSuccess,
|
||||||
|
"message": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMemoryConfig handles GET request for getting Memory configuration
|
||||||
|
// API Path: GET /api/v1/memories/:memory_id/config
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Gets complete configuration information for the specified memory
|
||||||
|
// - Includes owner name (obtained via JOIN with user table)
|
||||||
|
//
|
||||||
|
// Response Format:
|
||||||
|
// - code: Status code
|
||||||
|
// - message: true
|
||||||
|
// - data: Memory object, including owner_name field
|
||||||
|
func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) {
|
||||||
|
// Get memory_id from URL path
|
||||||
|
memoryID := c.Param("memory_id")
|
||||||
|
if memoryID == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeArgumentError,
|
||||||
|
"message": "memory_id is required",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call service layer to get memory configuration
|
||||||
|
result, err := h.memoryService.GetMemoryConfig(memoryID)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := err.Error()
|
||||||
|
// Check if it's a "not found" error
|
||||||
|
if strings.Contains(errMsg, "not found") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeNotFound,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other errors return server error
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": errMsg,
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success response
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeSuccess,
|
||||||
|
"message": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMemoryMessages handles GET request for getting Memory messages
|
||||||
|
// API Path: GET /api/v1/memories/:memory_id
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Gets message list associated with the specified memory
|
||||||
|
// - Supports filtering by agent_id
|
||||||
|
// - Supports keyword search and pagination
|
||||||
|
//
|
||||||
|
// Query Parameters:
|
||||||
|
// - agent_id (optional): Agent ID filter, supports comma-separated multiple
|
||||||
|
// - keywords (optional): Keyword search
|
||||||
|
// - page (optional): Page number, default 1
|
||||||
|
// - page_size (optional): Items per page, default 50
|
||||||
|
//
|
||||||
|
// Response Format:
|
||||||
|
// - code: Status code
|
||||||
|
// - message: true
|
||||||
|
// - data.messages: Array of message objects
|
||||||
|
// - data.storage_type: Storage type
|
||||||
|
//
|
||||||
|
// TODO: Implementation pending - depends on CanvasService and TaskService
|
||||||
|
func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": "GetMemoryMessages not implemented - pending CanvasService and TaskService dependencies",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMessage handles POST request for adding messages
|
||||||
|
// API Path: POST /api/v1/messages
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Adds messages to one or more memories
|
||||||
|
// - Messages will be embedded and saved to vector database
|
||||||
|
// - Creates asynchronous task for processing
|
||||||
|
//
|
||||||
|
// Request Parameters (JSON Body):
|
||||||
|
// - memory_id (required): Memory ID or ID array
|
||||||
|
// - agent_id (required): Agent ID
|
||||||
|
// - session_id (required): Session ID
|
||||||
|
// - user_input (required): User input
|
||||||
|
// - agent_response (required): Agent response
|
||||||
|
// - user_id (optional): User ID
|
||||||
|
//
|
||||||
|
// TODO: Implementation pending - depends on embedding engine
|
||||||
|
func (h *MemoryHandler) AddMessage(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": "AddMessage not implemented - pending embedding engine dependency",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForgetMessage handles DELETE request for forgetting messages
|
||||||
|
// API Path: DELETE /api/v1/messages/:memory_id/:message_id
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Soft-deletes the specified message (sets forget_at timestamp)
|
||||||
|
// - Message is not immediately deleted from database, but marked as "forgotten"
|
||||||
|
//
|
||||||
|
// Parameter Format:
|
||||||
|
// - memory_id: Memory ID
|
||||||
|
// - message_id: Message ID (integer)
|
||||||
|
//
|
||||||
|
// TODO: Implementation pending - depends on embedding engine
|
||||||
|
func (h *MemoryHandler) ForgetMessage(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": "ForgetMessage not implemented - pending embedding engine dependency",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMessage handles PUT request for updating message status
|
||||||
|
// API Path: PUT /api/v1/messages/:memory_id/:message_id
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Updates status of the specified message
|
||||||
|
// - status is a boolean, converted to integer for storage (true=1, false=0)
|
||||||
|
//
|
||||||
|
// Request Parameters (JSON Body):
|
||||||
|
// - status (required): Message status, boolean
|
||||||
|
//
|
||||||
|
// TODO: Implementation pending - depends on embedding engine
|
||||||
|
func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": "UpdateMessage not implemented - pending embedding engine dependency",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchMessage handles GET request for searching messages
|
||||||
|
// API Path: GET /api/v1/messages/search
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Searches messages across multiple memories
|
||||||
|
// - Supports vector similarity search and keyword search
|
||||||
|
// - Fuses results from both search methods
|
||||||
|
//
|
||||||
|
// Query Parameters:
|
||||||
|
// - memory_id (optional): Memory ID list, supports comma separation
|
||||||
|
// - query (optional): Search query text
|
||||||
|
// - similarity_threshold (optional): Similarity threshold, default 0.2
|
||||||
|
// - keywords_similarity_weight (optional): Keyword weight, default 0.7
|
||||||
|
// - top_n (optional): Number of results to return, default 5
|
||||||
|
// - agent_id (optional): Agent ID filter
|
||||||
|
// - session_id (optional): Session ID filter
|
||||||
|
// - user_id (optional): User ID filter
|
||||||
|
//
|
||||||
|
// TODO: Implementation pending - depends on embedding engine
|
||||||
|
func (h *MemoryHandler) SearchMessage(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": "SearchMessage not implemented - pending embedding engine dependency",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages handles GET request for getting message list
|
||||||
|
// API Path: GET /api/v1/messages
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Gets recent messages from specified memories
|
||||||
|
// - Supports filtering by agent_id and session_id
|
||||||
|
//
|
||||||
|
// Query Parameters:
|
||||||
|
// - memory_id (required): Memory ID list, supports comma separation
|
||||||
|
// - agent_id (optional): Agent ID filter
|
||||||
|
// - session_id (optional): Session ID filter
|
||||||
|
// - limit (optional): Number of results to return, default 10
|
||||||
|
//
|
||||||
|
// TODO: Implementation pending - depends on embedding engine
|
||||||
|
func (h *MemoryHandler) GetMessages(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": "GetMessages not implemented - pending embedding engine dependency",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessageContent handles GET request for getting message content
|
||||||
|
// API Path: GET /api/v1/messages/:memory_id/:message_id/content
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Gets complete content of the specified message
|
||||||
|
// - doc_id format: memory_id + "_" + message_id
|
||||||
|
//
|
||||||
|
// Parameter Format:
|
||||||
|
// - memory_id: Memory ID
|
||||||
|
// - message_id: Message ID (integer)
|
||||||
|
//
|
||||||
|
// TODO: Implementation pending - depends on embedding engine
|
||||||
|
func (h *MemoryHandler) GetMessageContent(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": common.CodeServerError,
|
||||||
|
"message": "GetMessageContent not implemented - pending embedding engine dependency",
|
||||||
|
"data": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// isArgumentError determines if an error message is an argument error
|
||||||
|
//
|
||||||
|
// Function:
|
||||||
|
// - Checks if the error message contains any argument validation-related prefixes
|
||||||
|
// - Used to distinguish argument errors from server errors
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - msg: Error message string
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - bool: true if it's an argument error, false otherwise
|
||||||
|
func isArgumentError(msg string) bool {
|
||||||
|
// Define list of argument error prefixes
|
||||||
|
// Matches Python ArgumentException error messages
|
||||||
|
argumentErrorPrefixes := []string{
|
||||||
|
"memory name cannot be empty", // Memory name cannot be empty
|
||||||
|
"memory name exceeds limit", // Memory name exceeds limit
|
||||||
|
"memory type must be a list", // memory_type must be a list
|
||||||
|
"memory type is not supported", // Unsupported memory_type
|
||||||
|
}
|
||||||
|
// Check if error message starts with any prefix
|
||||||
|
for _, prefix := range argumentErrorPrefixes {
|
||||||
|
if len(msg) >= len(prefix) && msg[:len(prefix)] == prefix {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@ -42,3 +42,12 @@ type Memory struct {
|
|||||||
func (Memory) TableName() string {
|
func (Memory) TableName() string {
|
||||||
return "memory"
|
return "memory"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MemoryListItem represents a memory record with owner name from JOIN query.
|
||||||
|
// Uses struct embedding to extend Memory struct with owner_name from user table JOIN.
|
||||||
|
// Note: MemoryType is kept as int64 from Memory embedding; conversion to []string
|
||||||
|
// happens in the Service layer via CreateMemoryResponse.
|
||||||
|
type MemoryListItem struct {
|
||||||
|
Memory
|
||||||
|
OwnerName *string `json:"owner_name,omitempty"`
|
||||||
|
}
|
||||||
|
|||||||
@ -38,6 +38,7 @@ type Router struct {
|
|||||||
connectorHandler *handler.ConnectorHandler
|
connectorHandler *handler.ConnectorHandler
|
||||||
searchHandler *handler.SearchHandler
|
searchHandler *handler.SearchHandler
|
||||||
fileHandler *handler.FileHandler
|
fileHandler *handler.FileHandler
|
||||||
|
memoryHandler *handler.MemoryHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouter create router
|
// NewRouter create router
|
||||||
@ -56,6 +57,7 @@ func NewRouter(
|
|||||||
connectorHandler *handler.ConnectorHandler,
|
connectorHandler *handler.ConnectorHandler,
|
||||||
searchHandler *handler.SearchHandler,
|
searchHandler *handler.SearchHandler,
|
||||||
fileHandler *handler.FileHandler,
|
fileHandler *handler.FileHandler,
|
||||||
|
memoryHandler *handler.MemoryHandler,
|
||||||
) *Router {
|
) *Router {
|
||||||
return &Router{
|
return &Router{
|
||||||
authHandler: authHandler,
|
authHandler: authHandler,
|
||||||
@ -72,6 +74,7 @@ func NewRouter(
|
|||||||
connectorHandler: connectorHandler,
|
connectorHandler: connectorHandler,
|
||||||
searchHandler: searchHandler,
|
searchHandler: searchHandler,
|
||||||
fileHandler: fileHandler,
|
fileHandler: fileHandler,
|
||||||
|
memoryHandler: memoryHandler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,6 +166,28 @@ func (r *Router) Setup(engine *gin.Engine) {
|
|||||||
{
|
{
|
||||||
authors.GET("/:author_id/documents", r.documentHandler.GetDocumentsByAuthorID)
|
authors.GET("/:author_id/documents", r.documentHandler.GetDocumentsByAuthorID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Memory routes
|
||||||
|
memory := v1.Group("/memories")
|
||||||
|
{
|
||||||
|
memory.POST("", r.memoryHandler.CreateMemory)
|
||||||
|
memory.PUT("/:memory_id", r.memoryHandler.UpdateMemory)
|
||||||
|
memory.DELETE("/:memory_id", r.memoryHandler.DeleteMemory)
|
||||||
|
memory.GET("", r.memoryHandler.ListMemories)
|
||||||
|
memory.GET("/:memory_id/config", r.memoryHandler.GetMemoryConfig)
|
||||||
|
memory.GET("/:memory_id", r.memoryHandler.GetMemoryMessages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine
|
||||||
|
// message := v1.Group("/messages")
|
||||||
|
// {
|
||||||
|
// message.POST("", r.memoryHandler.AddMessage)
|
||||||
|
// message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage)
|
||||||
|
// message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage)
|
||||||
|
// message.GET("/search", r.memoryHandler.SearchMessage)
|
||||||
|
// message.GET("", r.memoryHandler.GetMessages)
|
||||||
|
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
// Knowledge base routes
|
// Knowledge base routes
|
||||||
@ -260,6 +285,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
|||||||
file.GET("/parent_folder", r.fileHandler.GetParentFolder)
|
file.GET("/parent_folder", r.fileHandler.GetParentFolder)
|
||||||
file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders)
|
file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle undefined routes
|
// Handle undefined routes
|
||||||
|
|||||||
892
internal/service/memory.go
Normal file
892
internal/service/memory.go
Normal file
@ -0,0 +1,892 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"ragflow/internal/dao"
|
||||||
|
"ragflow/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MemoryNameLimit is the maximum length allowed for memory names
|
||||||
|
MemoryNameLimit = 128
|
||||||
|
// MemorySizeLimit is the maximum memory size in bytes (5MB)
|
||||||
|
MemorySizeLimit = 5242880
|
||||||
|
)
|
||||||
|
|
||||||
|
// Note: MemoryType, MemoryTypeRaw, MemoryTypeSemantic, MemoryTypeEpisodic,
|
||||||
|
// MemoryTypeProcedural, and CalculateMemoryType are defined in the dao package
|
||||||
|
// and imported as dao.MemoryType, dao.MemoryTypeRaw, etc.
|
||||||
|
|
||||||
|
// TenantPermission defines the access permission levels for memory resources
|
||||||
|
// Note: This type is specific to the service layer
|
||||||
|
type TenantPermission string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TenantPermissionMe restricts access to the owner only
|
||||||
|
TenantPermissionMe TenantPermission = "me"
|
||||||
|
// TenantPermissionTeam allows access within the same team
|
||||||
|
TenantPermissionTeam TenantPermission = "team"
|
||||||
|
// TenantPermissionAll allows access to all tenants
|
||||||
|
TenantPermissionAll TenantPermission = "all"
|
||||||
|
)
|
||||||
|
|
||||||
|
// validPermissions defines which permission values are valid
|
||||||
|
var validPermissions = map[TenantPermission]bool{
|
||||||
|
TenantPermissionMe: true,
|
||||||
|
TenantPermissionTeam: true,
|
||||||
|
TenantPermissionAll: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForgettingPolicy defines the strategy for forgetting old memory entries
|
||||||
|
type ForgettingPolicy string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ForgettingPolicyFIFO uses First-In-First-Out strategy for forgetting
|
||||||
|
ForgettingPolicyFIFO ForgettingPolicy = "FIFO"
|
||||||
|
)
|
||||||
|
|
||||||
|
// validForgettingPolicies defines which forgetting policies are valid
|
||||||
|
var validForgettingPolicies = map[ForgettingPolicy]bool{
|
||||||
|
ForgettingPolicyFIFO: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Note: CalculateMemoryType and GetMemoryTypeHuman functions have been moved to dao package
|
||||||
|
// Use dao.CalculateMemoryType() and dao.GetMemoryTypeHuman() instead
|
||||||
|
|
||||||
|
// PromptAssembler handles the assembly of system prompts for memory extraction
|
||||||
|
type PromptAssembler struct{}
|
||||||
|
|
||||||
|
// SYSTEM_BASE_TEMPLATE is the base template for the system prompt used in memory extraction
|
||||||
|
// It includes placeholders for type-specific instructions, timestamp format, and max items
|
||||||
|
var SYSTEM_BASE_TEMPLATE = `**Memory Extraction Specialist**
|
||||||
|
You are an expert at analyzing conversations to extract structured memory.
|
||||||
|
|
||||||
|
{type_specific_instructions}
|
||||||
|
|
||||||
|
|
||||||
|
**OUTPUT REQUIREMENTS:**
|
||||||
|
1. Output MUST be valid JSON
|
||||||
|
2. Follow the specified output format exactly
|
||||||
|
3. Each extracted item MUST have: content, valid_at, invalid_at
|
||||||
|
4. Timestamps in {timestamp_format} format
|
||||||
|
5. Only extract memory types specified above
|
||||||
|
6. Maximum {max_items} items per type
|
||||||
|
`
|
||||||
|
|
||||||
|
// TYPE_INSTRUCTIONS contains specific instructions for each memory type extraction
|
||||||
|
var TYPE_INSTRUCTIONS = map[string]string{
|
||||||
|
"semantic": `
|
||||||
|
**EXTRACT SEMANTIC KNOWLEDGE:**
|
||||||
|
- Universal facts, definitions, concepts, relationships
|
||||||
|
- Time-invariant, generally true information
|
||||||
|
|
||||||
|
**Timestamp Rules:**
|
||||||
|
- valid_at: When the fact became true
|
||||||
|
- invalid_at: When it becomes false or empty if still true
|
||||||
|
`,
|
||||||
|
"episodic": `
|
||||||
|
**EXTRACT EPISODIC KNOWLEDGE:**
|
||||||
|
- Specific experiences, events, personal stories
|
||||||
|
- Time-bound, person-specific, contextual
|
||||||
|
|
||||||
|
**Timestamp Rules:**
|
||||||
|
- valid_at: Event start/occurrence time
|
||||||
|
- invalid_at: Event end time or empty if instantaneous
|
||||||
|
`,
|
||||||
|
"procedural": `
|
||||||
|
**EXTRACT PROCEDURAL KNOWLEDGE:**
|
||||||
|
- Processes, methods, step-by-step instructions
|
||||||
|
- Goal-oriented, actionable, often includes conditions
|
||||||
|
|
||||||
|
**Timestamp Rules:**
|
||||||
|
- valid_at: When procedure becomes valid/effective
|
||||||
|
- invalid_at: When it expires/becomes obsolete or empty if current
|
||||||
|
`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// OUTPUT_TEMPLATES defines the output format for each memory type
|
||||||
|
var OUTPUT_TEMPLATES = map[string]string{
|
||||||
|
"semantic": `"semantic": [{"content": "Clear factual statement", "valid_at": "timestamp or empty", "invalid_at": "timestamp or empty"}]`,
|
||||||
|
"episodic": `"episodic": [{"content": "Narrative event description", "valid_at": "event start timestamp", "invalid_at": "event end timestamp or empty"}]`,
|
||||||
|
"procedural": `"procedural": [{"content": "Actionable instructions", "valid_at": "procedure effective timestamp", "invalid_at": "procedure expiration timestamp or empty"}]`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssembleSystemPrompt generates a complete system prompt for memory extraction
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memoryTypes: Array of memory type names to extract (e.g., ["semantic", "episodic"])
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: Complete system prompt with type-specific instructions and output format
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// AssembleSystemPrompt([]string{"semantic", "episodic"}) returns a prompt with instructions
|
||||||
|
// for both semantic and episodic memory extraction
|
||||||
|
func (PromptAssembler) AssembleSystemPrompt(memoryTypes []string) string {
|
||||||
|
typesToExtract := getTypesToExtract(memoryTypes)
|
||||||
|
if len(typesToExtract) == 0 {
|
||||||
|
typesToExtract = []string{"raw"}
|
||||||
|
}
|
||||||
|
|
||||||
|
typeInstructions := generateTypeInstructions(typesToExtract)
|
||||||
|
outputFormat := generateOutputFormat(typesToExtract)
|
||||||
|
|
||||||
|
fullPrompt := strings.Replace(SYSTEM_BASE_TEMPLATE, "{type_specific_instructions}", typeInstructions, 1)
|
||||||
|
fullPrompt = strings.Replace(fullPrompt, "{timestamp_format}", "ISO 8601", 1)
|
||||||
|
fullPrompt = strings.Replace(fullPrompt, "{max_items}", "5", 1)
|
||||||
|
|
||||||
|
fullPrompt += fmt.Sprintf("\n**REQUIRED OUTPUT FORMAT (JSON):\n```json\n{\n%s\n}\n```\n", outputFormat)
|
||||||
|
|
||||||
|
return fullPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTypesToExtract filters out "raw" type and returns valid memory types
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - requestedTypes: Array of requested memory type names
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []string: Filtered array of memory type names (excluding "raw")
|
||||||
|
func getTypesToExtract(requestedTypes []string) []string {
|
||||||
|
types := make(map[string]bool)
|
||||||
|
for _, rt := range requestedTypes {
|
||||||
|
lowerRT := strings.ToLower(rt)
|
||||||
|
if lowerRT != "raw" {
|
||||||
|
if _, ok := dao.MemoryTypeMap[lowerRT]; ok {
|
||||||
|
types[lowerRT] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(types))
|
||||||
|
for t := range types {
|
||||||
|
result = append(result, t)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTypeInstructions concatenates type-specific instructions
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - typesToExtract: Array of memory type names
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: Concatenated instructions for all specified types
|
||||||
|
func generateTypeInstructions(typesToExtract []string) string {
|
||||||
|
var instructions []string
|
||||||
|
for _, mt := range typesToExtract {
|
||||||
|
if instr, ok := TYPE_INSTRUCTIONS[mt]; ok {
|
||||||
|
instructions = append(instructions, instr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(instructions, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateOutputFormat concatenates output format templates
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - typesToExtract: Array of memory type names
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: Concatenated output format templates
|
||||||
|
func generateOutputFormat(typesToExtract []string) string {
|
||||||
|
var outputParts []string
|
||||||
|
for _, mt := range typesToExtract {
|
||||||
|
if tmpl, ok := OUTPUT_TEMPLATES[mt]; ok {
|
||||||
|
outputParts = append(outputParts, tmpl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(outputParts, ",\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryService handles business logic for memory operations
|
||||||
|
// It provides methods for creating, updating, deleting, and querying memories
|
||||||
|
type MemoryService struct {
|
||||||
|
memoryDAO *dao.MemoryDAO
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryService creates a new MemoryService instance
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *MemoryService: Initialized service instance with DAO
|
||||||
|
func NewMemoryService() *MemoryService {
|
||||||
|
return &MemoryService{
|
||||||
|
memoryDAO: dao.NewMemoryDAO(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitNameCounter splits a filename into base name and counter
|
||||||
|
// Handles names in format "filename(123)" pattern
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - filename: The filename to split
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: The base name without counter
|
||||||
|
// - *int: The counter value, or nil if no counter exists
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// splitNameCounter("test(5)") returns ("test", 5)
|
||||||
|
// splitNameCounter("test") returns ("test", nil)
|
||||||
|
func splitNameCounter(filename string) (string, *int) {
|
||||||
|
re := regexp.MustCompile(`^(.+)\((\d+)\)$`)
|
||||||
|
matches := re.FindStringSubmatch(filename)
|
||||||
|
if len(matches) >= 3 {
|
||||||
|
counter := -1
|
||||||
|
fmt.Sscanf(matches[2], "%d", &counter)
|
||||||
|
stem := strings.TrimRight(matches[1], " ")
|
||||||
|
return stem, &counter
|
||||||
|
}
|
||||||
|
return filename, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// duplicateName generates a unique name by appending a counter if the name already exists
|
||||||
|
// It tries up to 1000 times to generate a unique name
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - queryFunc: Function to check if a name already exists (returns true if exists)
|
||||||
|
// - name: The original name
|
||||||
|
// - tenantID: The tenant ID for name uniqueness check
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: A unique name (either original or with counter appended)
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// duplicateName(func(name string, tid string) bool { return false }, "test", "tenant1") returns "test"
|
||||||
|
// duplicateName(func(name string, tid string) bool { return true }, "test", "tenant1") returns "test(1)"
|
||||||
|
func duplicateName(queryFunc func(name string, tenantID string) bool, name string, tenantID string) string {
|
||||||
|
const maxRetries = 1000
|
||||||
|
|
||||||
|
originalName := name
|
||||||
|
currentName := name
|
||||||
|
retries := 0
|
||||||
|
|
||||||
|
for retries < maxRetries {
|
||||||
|
if !queryFunc(currentName, tenantID) {
|
||||||
|
return currentName
|
||||||
|
}
|
||||||
|
|
||||||
|
stem, counter := splitNameCounter(currentName)
|
||||||
|
ext := path.Ext(stem)
|
||||||
|
stemBase := strings.TrimSuffix(stem, ext)
|
||||||
|
|
||||||
|
newCounter := 1
|
||||||
|
if counter != nil {
|
||||||
|
newCounter = *counter + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
currentName = fmt.Sprintf("%s(%d)%s", stemBase, newCounter, ext)
|
||||||
|
retries++
|
||||||
|
}
|
||||||
|
|
||||||
|
panic(fmt.Sprintf("Failed to generate unique name within %d attempts. Original: %s", maxRetries, originalName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMemoryRequest defines the request structure for creating a memory
|
||||||
|
type CreateMemoryRequest struct {
|
||||||
|
// Name is the memory name (required, max 128 characters)
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
// MemoryType is the array of memory type names (required)
|
||||||
|
MemoryType []string `json:"memory_type" binding:"required"`
|
||||||
|
// EmbdID is the embedding model ID (required)
|
||||||
|
EmbdID string `json:"embd_id" binding:"required"`
|
||||||
|
// LLMID is the language model ID (required)
|
||||||
|
LLMID string `json:"llm_id" binding:"required"`
|
||||||
|
// TenantEmbdID is the tenant-specific embedding model ID (optional)
|
||||||
|
TenantEmbdID *string `json:"tenant_embd_id"`
|
||||||
|
// TenantLLMID is the tenant-specific language model ID (optional)
|
||||||
|
TenantLLMID *string `json:"tenant_llm_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMemoryRequest defines the request structure for updating a memory
|
||||||
|
// All fields are optional, only provided fields will be updated
|
||||||
|
type UpdateMemoryRequest struct {
|
||||||
|
// Name is the new memory name (optional)
|
||||||
|
Name *string `json:"name"`
|
||||||
|
// Permissions is the new permission level (optional)
|
||||||
|
Permissions *string `json:"permissions"`
|
||||||
|
// LLMID is the new language model ID (optional)
|
||||||
|
LLMID *string `json:"llm_id"`
|
||||||
|
// EmbdID is the new embedding model ID (optional)
|
||||||
|
EmbdID *string `json:"embd_id"`
|
||||||
|
// TenantLLMID is the new tenant-specific language model ID (optional)
|
||||||
|
TenantLLMID *string `json:"tenant_llm_id"`
|
||||||
|
// TenantEmbdID is the new tenant-specific embedding model ID (optional)
|
||||||
|
TenantEmbdID *string `json:"tenant_embd_id"`
|
||||||
|
// MemoryType is the new array of memory type names (optional)
|
||||||
|
MemoryType []string `json:"memory_type"`
|
||||||
|
// MemorySize is the new memory size in bytes (optional, max 5MB)
|
||||||
|
MemorySize *int64 `json:"memory_size"`
|
||||||
|
// ForgettingPolicy is the new forgetting policy (optional)
|
||||||
|
ForgettingPolicy *string `json:"forgetting_policy"`
|
||||||
|
// Temperature is the new temperature value (optional, range [0, 1])
|
||||||
|
Temperature *float64 `json:"temperature"`
|
||||||
|
// Avatar is the new avatar URL (optional)
|
||||||
|
Avatar *string `json:"avatar"`
|
||||||
|
// Description is the new description (optional)
|
||||||
|
Description *string `json:"description"`
|
||||||
|
// SystemPrompt is the new system prompt (optional)
|
||||||
|
SystemPrompt *string `json:"system_prompt"`
|
||||||
|
// UserPrompt is the new user prompt (optional)
|
||||||
|
UserPrompt *string `json:"user_prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMemoryResponse defines the response structure for memory operations
|
||||||
|
// Uses struct embedding to extend Memory struct with API-specific fields
|
||||||
|
type CreateMemoryResponse struct {
|
||||||
|
model.Memory
|
||||||
|
OwnerName *string `json:"owner_name,omitempty"`
|
||||||
|
MemoryType []string `json:"memory_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMemoryResponse defines the response structure for listing memories
|
||||||
|
type ListMemoryResponse struct {
|
||||||
|
// MemoryList is the array of memory objects
|
||||||
|
MemoryList []map[string]interface{} `json:"memory_list"`
|
||||||
|
// TotalCount is the total number of memories
|
||||||
|
TotalCount int64 `json:"total_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMemory creates a new memory with the given parameters
|
||||||
|
// It validates the request, generates a unique name if needed, and creates the memory record
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tenantID: The tenant ID for which to create the memory
|
||||||
|
// - req: The memory creation request containing name, memory_type, embd_id, llm_id, etc.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *CreateMemoryResponse: The created memory details
|
||||||
|
// - error: Error if validation fails or creation fails
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// req := &CreateMemoryRequest{Name: "MyMemory", MemoryType: []string{"semantic"}, EmbdID: "embd1", LLMID: "llm1"}
|
||||||
|
// resp, err := service.CreateMemory("tenant123", req)
|
||||||
|
func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest) (*CreateMemoryResponse, error) {
|
||||||
|
// Ensure tenant model IDs are populated for LLM and embedding model parameters
|
||||||
|
// This automatically fills tenant_llm_id and tenant_embd_id based on llm_id and embd_id
|
||||||
|
tenantLLMService := NewTenantLLMService()
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"llm_id": req.LLMID,
|
||||||
|
"embd_id": req.EmbdID,
|
||||||
|
}
|
||||||
|
params = tenantLLMService.EnsureTenantModelIDForParams(tenantID, params)
|
||||||
|
|
||||||
|
// Update request with tenant model IDs from the processed params
|
||||||
|
if tenantLLMID, ok := params["tenant_llm_id"].(int64); ok {
|
||||||
|
tenantLLMIDStr := strconv.FormatInt(tenantLLMID, 10)
|
||||||
|
req.TenantLLMID = &tenantLLMIDStr
|
||||||
|
}
|
||||||
|
if tenantEmbdID, ok := params["tenant_embd_id"].(int64); ok {
|
||||||
|
tenantEmbdIDStr := strconv.FormatInt(tenantEmbdID, 10)
|
||||||
|
req.TenantEmbdID = &tenantEmbdIDStr
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryName := strings.TrimSpace(req.Name)
|
||||||
|
if len(memoryName) == 0 {
|
||||||
|
return nil, errors.New("memory name cannot be empty or whitespace")
|
||||||
|
}
|
||||||
|
if len(memoryName) > MemoryNameLimit {
|
||||||
|
return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isList(req.MemoryType) {
|
||||||
|
return nil, errors.New("memory type must be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryTypeSet := make(map[string]bool)
|
||||||
|
for _, mt := range req.MemoryType {
|
||||||
|
lowerMT := strings.ToLower(mt)
|
||||||
|
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
|
||||||
|
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
|
||||||
|
}
|
||||||
|
memoryTypeSet[lowerMT] = true
|
||||||
|
}
|
||||||
|
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
|
||||||
|
for mt := range memoryTypeSet {
|
||||||
|
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryName = duplicateName(func(name string, tid string) bool {
|
||||||
|
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
|
||||||
|
return len(existing) > 0
|
||||||
|
}, memoryName, tenantID)
|
||||||
|
|
||||||
|
if len(memoryName) > MemoryNameLimit {
|
||||||
|
return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryTypeInt := dao.CalculateMemoryType(uniqueMemoryTypes)
|
||||||
|
timestamp := time.Now().UnixMilli()
|
||||||
|
|
||||||
|
systemPrompt := PromptAssembler{}.AssembleSystemPrompt(uniqueMemoryTypes)
|
||||||
|
|
||||||
|
newID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
|
if len(newID) > 32 {
|
||||||
|
newID = newID[:32]
|
||||||
|
}
|
||||||
|
|
||||||
|
memory := &model.Memory{
|
||||||
|
ID: newID,
|
||||||
|
Name: memoryName,
|
||||||
|
TenantID: tenantID,
|
||||||
|
MemoryType: memoryTypeInt,
|
||||||
|
StorageType: "table",
|
||||||
|
EmbdID: req.EmbdID,
|
||||||
|
LLMID: req.LLMID,
|
||||||
|
Permissions: "me",
|
||||||
|
MemorySize: MemorySizeLimit,
|
||||||
|
ForgettingPolicy: string(ForgettingPolicyFIFO),
|
||||||
|
Temperature: 0.5,
|
||||||
|
SystemPrompt: &systemPrompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert tenant model IDs from string to int64 for database
|
||||||
|
if req.TenantEmbdID != nil {
|
||||||
|
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
|
||||||
|
memory.TenantEmbdID = &embdID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.TenantLLMID != nil {
|
||||||
|
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
|
||||||
|
memory.TenantLLMID = &llmID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
memory.CreateTime = ×tamp
|
||||||
|
memory.UpdateTime = ×tamp
|
||||||
|
|
||||||
|
if err := s.memoryDAO.Create(memory); err != nil {
|
||||||
|
return nil, errors.New("could not create new memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
createdMemory, err := s.memoryDAO.GetByID(newID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("could not create new memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
return formatRetDataFromMemory(createdMemory), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMemory updates an existing memory with the provided fields
|
||||||
|
// Only the fields specified in the request will be updated (partial update)
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tenantID: The tenant ID for ownership verification
|
||||||
|
// - memoryID: The ID of the memory to update
|
||||||
|
// - req: The update request with optional fields to update
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *CreateMemoryResponse: The updated memory details
|
||||||
|
// - error: Error if validation fails or update fails
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// req := &UpdateMemoryRequest{Name: ptr("NewName"), MemorySize: ptr(int64(1000000))}
|
||||||
|
// resp, err := service.UpdateMemory("tenant123", "memory456", req)
|
||||||
|
func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *UpdateMemoryRequest) (*CreateMemoryResponse, error) {
|
||||||
|
updateDict := make(map[string]interface{})
|
||||||
|
|
||||||
|
if req.Name != nil {
|
||||||
|
memoryName := strings.TrimSpace(*req.Name)
|
||||||
|
if len(memoryName) == 0 {
|
||||||
|
return nil, errors.New("memory name cannot be empty or whitespace")
|
||||||
|
}
|
||||||
|
if len(memoryName) > MemoryNameLimit {
|
||||||
|
return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||||
|
}
|
||||||
|
memoryName = duplicateName(func(name string, tid string) bool {
|
||||||
|
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
|
||||||
|
return len(existing) > 0
|
||||||
|
}, memoryName, tenantID)
|
||||||
|
if len(memoryName) > MemoryNameLimit {
|
||||||
|
return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||||
|
}
|
||||||
|
updateDict["name"] = memoryName
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Permissions != nil {
|
||||||
|
perm := TenantPermission(strings.ToLower(*req.Permissions))
|
||||||
|
if !validPermissions[perm] {
|
||||||
|
return nil, fmt.Errorf("unknown permission '%s'", *req.Permissions)
|
||||||
|
}
|
||||||
|
updateDict["permissions"] = perm
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.LLMID != nil {
|
||||||
|
updateDict["llm_id"] = *req.LLMID
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.EmbdID != nil {
|
||||||
|
updateDict["embd_id"] = *req.EmbdID
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.TenantLLMID != nil {
|
||||||
|
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
|
||||||
|
updateDict["tenant_llm_id"] = llmID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.TenantEmbdID != nil {
|
||||||
|
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
|
||||||
|
updateDict["tenant_embd_id"] = embdID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MemoryType != nil && len(req.MemoryType) > 0 {
|
||||||
|
memoryTypeSet := make(map[string]bool)
|
||||||
|
for _, mt := range req.MemoryType {
|
||||||
|
lowerMT := strings.ToLower(mt)
|
||||||
|
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
|
||||||
|
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
|
||||||
|
}
|
||||||
|
memoryTypeSet[lowerMT] = true
|
||||||
|
}
|
||||||
|
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
|
||||||
|
for mt := range memoryTypeSet {
|
||||||
|
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
|
||||||
|
}
|
||||||
|
updateDict["memory_type"] = uniqueMemoryTypes
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MemorySize != nil {
|
||||||
|
memorySize := *req.MemorySize
|
||||||
|
if !(memorySize > 0 && memorySize <= MemorySizeLimit) {
|
||||||
|
return nil, fmt.Errorf("memory size should be in range (0, %d] Bytes", MemorySizeLimit)
|
||||||
|
}
|
||||||
|
updateDict["memory_size"] = memorySize
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ForgettingPolicy != nil {
|
||||||
|
fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy))
|
||||||
|
if !validForgettingPolicies[fp] {
|
||||||
|
return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy)
|
||||||
|
}
|
||||||
|
updateDict["forgetting_policy"] = fp
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Temperature != nil {
|
||||||
|
temp := *req.Temperature
|
||||||
|
if !(temp >= 0 && temp <= 1) {
|
||||||
|
return nil, errors.New("temperature should be in range [0, 1]")
|
||||||
|
}
|
||||||
|
updateDict["temperature"] = temp
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range []string{"avatar", "description", "system_prompt", "user_prompt"} {
|
||||||
|
switch field {
|
||||||
|
case "avatar":
|
||||||
|
if req.Avatar != nil {
|
||||||
|
updateDict["avatar"] = *req.Avatar
|
||||||
|
}
|
||||||
|
case "description":
|
||||||
|
if req.Description != nil {
|
||||||
|
updateDict["description"] = *req.Description
|
||||||
|
}
|
||||||
|
case "system_prompt":
|
||||||
|
if req.SystemPrompt != nil {
|
||||||
|
updateDict["system_prompt"] = *req.SystemPrompt
|
||||||
|
}
|
||||||
|
case "user_prompt":
|
||||||
|
if req.UserPrompt != nil {
|
||||||
|
updateDict["user_prompt"] = *req.UserPrompt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentMemory, err := s.memoryDAO.GetByID(memoryID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("memory '%s' not found", memoryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(updateDict) == 0 {
|
||||||
|
return formatRetDataFromMemory(currentMemory), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
memorySize := currentMemory.MemorySize
|
||||||
|
notAllowedUpdate := []string{}
|
||||||
|
for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} {
|
||||||
|
if _, ok := updateDict[f]; ok && memorySize > 0 {
|
||||||
|
notAllowedUpdate = append(notAllowedUpdate, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(notAllowedUpdate) > 0 {
|
||||||
|
return nil, fmt.Errorf("can't update %v when memory isn't empty", notAllowedUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := updateDict["memory_type"]; ok {
|
||||||
|
if _, ok := updateDict["system_prompt"]; !ok {
|
||||||
|
memoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType)
|
||||||
|
if len(memoryTypes) > 0 && currentMemory.SystemPrompt != nil {
|
||||||
|
defaultPrompt := PromptAssembler{}.AssembleSystemPrompt(memoryTypes)
|
||||||
|
if *currentMemory.SystemPrompt == defaultPrompt {
|
||||||
|
if types, ok := updateDict["memory_type"].([]string); ok {
|
||||||
|
updateDict["system_prompt"] = PromptAssembler{}.AssembleSystemPrompt(types)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.memoryDAO.UpdateByID(memoryID, updateDict); err != nil {
|
||||||
|
return nil, errors.New("failed to update memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedMemory, err := s.memoryDAO.GetByID(memoryID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("failed to get updated memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
return formatRetDataFromMemory(updatedMemory), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMemory deletes a memory by ID
|
||||||
|
// It also deletes associated message indexes before removing the memory record
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memoryID: The ID of the memory to delete
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Error if memory not found or deletion fails
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// err := service.DeleteMemory("memory456")
|
||||||
|
func (s *MemoryService) DeleteMemory(memoryID string) error {
|
||||||
|
_, err := s.memoryDAO.GetByID(memoryID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("memory '%s' not found", memoryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Delete associated message index - Implementation pending MessageService
|
||||||
|
// messageService := NewMessageService()
|
||||||
|
// hasIndex, _ := messageService.HasIndex(memory.TenantID, memoryID)
|
||||||
|
// if hasIndex {
|
||||||
|
// messageService.DeleteMessage(nil, memory.TenantID, memoryID)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Delete memory record
|
||||||
|
if err := s.memoryDAO.DeleteByID(memoryID); err != nil {
|
||||||
|
return errors.New("failed to delete memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMemories retrieves a paginated list of memories with optional filters
|
||||||
|
// When tenantIDs is empty, it retrieves all tenants associated with the user
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - userID: The user ID for tenant filtering when tenantIDs is empty
|
||||||
|
// - tenantIDs: Array of tenant IDs to filter by (empty means all user's tenants)
|
||||||
|
// - memoryTypes: Array of memory type names to filter by (empty means all types)
|
||||||
|
// - storageType: Storage type to filter by (empty means all types)
|
||||||
|
// - keywords: Keywords to search in memory names (empty means no keyword filter)
|
||||||
|
// - page: Page number (1-based)
|
||||||
|
// - pageSize: Number of items per page
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *ListMemoryResponse: Contains memory list and total count
|
||||||
|
// - error: Error if query fails
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// resp, err := service.ListMemories("user123", []string{}, []string{"semantic"}, "table", "test", 1, 10)
|
||||||
|
func (s *MemoryService) ListMemories(userID string, tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) (*ListMemoryResponse, error) {
|
||||||
|
// If tenantIDs is empty, get all tenants associated with the user
|
||||||
|
if len(tenantIDs) == 0 {
|
||||||
|
userTenantService := NewUserTenantService()
|
||||||
|
userTenants, err := userTenantService.GetUserTenantRelationByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get user tenants: %w", err)
|
||||||
|
}
|
||||||
|
tenantIDs = make([]string, len(userTenants))
|
||||||
|
for i, tenant := range userTenants {
|
||||||
|
tenantIDs[i] = tenant.TenantID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
memories, total, err := s.memoryDAO.GetByFilter(tenantIDs, memoryTypes, storageType, keywords, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryList := make([]map[string]interface{}, 0, len(memories))
|
||||||
|
for _, m := range memories {
|
||||||
|
resp := formatRetDataFromMemoryListItem(m)
|
||||||
|
var createDateStr *string
|
||||||
|
if resp.CreateTime != nil {
|
||||||
|
createDateStr = formatDateToString(*resp.CreateTime)
|
||||||
|
}
|
||||||
|
memoryMap := map[string]interface{}{
|
||||||
|
"id": resp.ID,
|
||||||
|
"name": resp.Name,
|
||||||
|
"avatar": resp.Avatar,
|
||||||
|
"tenant_id": resp.TenantID,
|
||||||
|
"owner_name": resp.OwnerName,
|
||||||
|
"memory_type": resp.MemoryType,
|
||||||
|
"storage_type": resp.StorageType,
|
||||||
|
"permissions": resp.Permissions,
|
||||||
|
"description": resp.Description,
|
||||||
|
"create_time": resp.CreateTime,
|
||||||
|
"create_date": createDateStr,
|
||||||
|
}
|
||||||
|
memoryList = append(memoryList, memoryMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ListMemoryResponse{
|
||||||
|
MemoryList: memoryList,
|
||||||
|
TotalCount: total,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMemoryConfig retrieves the full configuration of a memory by ID
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memoryID: The ID of the memory to retrieve
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *CreateMemoryResponse: The memory configuration details
|
||||||
|
// - error: Error if memory not found
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// resp, err := service.GetMemoryConfig("memory456")
|
||||||
|
func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse, error) {
|
||||||
|
memory, err := s.memoryDAO.GetWithOwnerNameByID(memoryID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("memory '%s' not found", memoryID)
|
||||||
|
}
|
||||||
|
return formatRetDataFromMemoryListItem(memory), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: GetMemoryMessages - Implementation pending - depends on CanvasService and TaskService
|
||||||
|
// func (s *MemoryService) GetMemoryMessages(memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { ... }
|
||||||
|
|
||||||
|
// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService
|
||||||
|
// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... }
|
||||||
|
|
||||||
|
// TODO: AddMessage - Implementation pending - depends on embedding engine
|
||||||
|
// func (s *MemoryService) AddMessage(memoryIDs []string, messageDict map[string]interface{}) (bool, string, error) { ... }
|
||||||
|
|
||||||
|
// TODO: ForgetMessage - Implementation pending - depends on embedding engine
|
||||||
|
// func (s *MemoryService) ForgetMessage(memoryID string, messageID int) (bool, error) { ... }
|
||||||
|
|
||||||
|
// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine
|
||||||
|
// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... }
|
||||||
|
|
||||||
|
// TODO: SearchMessage - Implementation pending - depends on embedding engine
|
||||||
|
// func (s *MemoryService) SearchMessage(filterDict map[string]interface{}, params map[string]interface{}) ([]map[string]interface{}, error) { ... }
|
||||||
|
|
||||||
|
// TODO: GetMessages - Implementation pending - depends on embedding engine
|
||||||
|
// func (s *MemoryService) GetMessages(memoryIDs []string, agentID string, sessionID string, limit int) ([]map[string]interface{}, error) { ... }
|
||||||
|
|
||||||
|
// TODO: GetMessageContent - Implementation pending - depends on embedding engine
|
||||||
|
// func (s *MemoryService) GetMessageContent(memoryID string, messageID int) (map[string]interface{}, error) { ... }
|
||||||
|
|
||||||
|
// isList checks if a value is a list or array type
|
||||||
|
// This is a utility function for type validation
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - v: The value to check
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - bool: true if v is []interface{} or []string, false otherwise
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// isList([]string{"a", "b"}) returns true
|
||||||
|
// isList("test") returns false
|
||||||
|
func isList(v interface{}) bool {
|
||||||
|
switch v.(type) {
|
||||||
|
case []interface{}, []string:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatRetDataFromMemory converts a Memory model to CreateMemoryResponse format
|
||||||
|
// This is a utility function for formatting memory data for API responses
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memory: The Memory model to format
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *CreateMemoryResponse: Formatted memory response with human-readable types and dates
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// resp := formatRetDataFromMemory(memoryModel)
|
||||||
|
func formatRetDataFromMemory(memory *model.Memory) *CreateMemoryResponse {
|
||||||
|
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
||||||
|
|
||||||
|
resp := &CreateMemoryResponse{
|
||||||
|
Memory: *memory,
|
||||||
|
OwnerName: nil,
|
||||||
|
MemoryType: memoryTypes,
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatDateToString(t int64) *string {
|
||||||
|
if t == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Database stores timestamps in milliseconds, convert to seconds
|
||||||
|
if t > 1e10 {
|
||||||
|
t = t / 1000
|
||||||
|
}
|
||||||
|
timeObj := time.Unix(t, 0)
|
||||||
|
s := timeObj.Format("2006-01-02 15:04:05")
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatRetDataFromMemoryListItem converts a MemoryListItem to CreateMemoryResponse
|
||||||
|
// This function is used for both list and detail memory responses where owner_name is from JOIN query
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - memory: MemoryListItem pointer with owner_name from JOIN
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *CreateMemoryResponse: Formatted response with owner_name populated
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// resp := formatRetDataFromMemoryListItem(memoryItem)
|
||||||
|
func formatRetDataFromMemoryListItem(memory *model.MemoryListItem) *CreateMemoryResponse {
|
||||||
|
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
||||||
|
resp := &CreateMemoryResponse{
|
||||||
|
Memory: memory.Memory,
|
||||||
|
OwnerName: memory.OwnerName,
|
||||||
|
MemoryType: memoryTypes,
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
@ -17,12 +17,14 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"ragflow/internal/common"
|
"ragflow/internal/common"
|
||||||
"ragflow/internal/dao"
|
"ragflow/internal/dao"
|
||||||
|
"ragflow/internal/model"
|
||||||
"ragflow/internal/engine"
|
"ragflow/internal/engine"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -92,6 +94,136 @@ type TenantListItem struct {
|
|||||||
DeltaSeconds float64 `json:"delta_seconds"`
|
DeltaSeconds float64 `json:"delta_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TenantLLMService tenant LLM service
|
||||||
|
// This service handles operations related to tenant-specific LLM configurations
|
||||||
|
type TenantLLMService struct {
|
||||||
|
tenantLLMDAO *dao.TenantLLMDAO
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTenantLLMService creates a new TenantLLMService instance
|
||||||
|
func NewTenantLLMService() *TenantLLMService {
|
||||||
|
return &TenantLLMService{
|
||||||
|
tenantLLMDAO: dao.NewTenantLLMDAO(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAPIKey retrieves the tenant LLM record by tenant ID and model name
|
||||||
|
/**
|
||||||
|
* This method splits the model name into name and factory parts using the "@" separator,
|
||||||
|
* then queries the database for the matching tenant LLM configuration.
|
||||||
|
*
|
||||||
|
* Parameters:
|
||||||
|
* - tenantID: the unique identifier of the tenant
|
||||||
|
* - modelName: the model name, optionally including factory suffix (e.g., "gpt-4@OpenAI")
|
||||||
|
*
|
||||||
|
* Returns:
|
||||||
|
* - *model.TenantLLM: the tenant LLM record if found, nil otherwise
|
||||||
|
* - error: an error if the query fails, nil otherwise
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
*
|
||||||
|
* service := NewTenantLLMService()
|
||||||
|
*
|
||||||
|
* // Get API key for model with factory
|
||||||
|
* tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4@OpenAI")
|
||||||
|
* if err != nil {
|
||||||
|
* log.Printf("Error: %v", err)
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* // Get API key for model without factory
|
||||||
|
* tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4")
|
||||||
|
*/
|
||||||
|
func (s *TenantLLMService) GetAPIKey(tenantID, modelName string) (*model.TenantLLM, error) {
|
||||||
|
modelName, factory := s.SplitModelNameAndFactory(modelName)
|
||||||
|
|
||||||
|
var tenantLLM *model.TenantLLM
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if factory == "" {
|
||||||
|
tenantLLM, err = s.tenantLLMDAO.GetByTenantIDAndLLMName(tenantID, modelName)
|
||||||
|
} else {
|
||||||
|
tenantLLM, err = s.tenantLLMDAO.GetByTenantIDLLMNameAndFactory(tenantID, modelName, factory)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tenantLLM, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SplitModelNameAndFactory splits a model name into name and factory parts
|
||||||
|
func (s *TenantLLMService) SplitModelNameAndFactory(modelName string) (string, string) {
|
||||||
|
arr := strings.Split(modelName, "@")
|
||||||
|
if len(arr) < 2 {
|
||||||
|
return modelName, ""
|
||||||
|
}
|
||||||
|
if len(arr) > 2 {
|
||||||
|
return strings.Join(arr[0:len(arr)-1], "@"), arr[len(arr)-1]
|
||||||
|
}
|
||||||
|
return arr[0], arr[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureTenantModelIDForParams ensures tenant model IDs are populated for LLM-related parameters
|
||||||
|
/**
|
||||||
|
* This method iterates through a predefined list of LLM-related parameter keys (llm_id, embd_id,
|
||||||
|
* asr_id, img2txt_id, rerank_id, tts_id) and automatically populates the corresponding tenant_*
|
||||||
|
* fields (tenant_llm_id, tenant_embd_id, etc.) with the tenant LLM record IDs.
|
||||||
|
*
|
||||||
|
* If a parameter key exists and its corresponding tenant_* key doesn't exist, this method will:
|
||||||
|
* 1. Query the tenant LLM record using GetAPIKey
|
||||||
|
* 2. If found, set the tenant_* key to the record's ID
|
||||||
|
* 3. If not found, set the tenant_* key to 0
|
||||||
|
*
|
||||||
|
* Parameters:
|
||||||
|
* - tenantID: the unique identifier of the tenant
|
||||||
|
* - params: a map of parameters to be updated (will be modified in place)
|
||||||
|
*
|
||||||
|
* Returns:
|
||||||
|
* - map[string]interface{}: the updated parameters map (same as input, modified in place)
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
*
|
||||||
|
* service := NewTenantLLMService()
|
||||||
|
* params := map[string]interface{}{
|
||||||
|
* "llm_id": "gpt-4@OpenAI",
|
||||||
|
* "embd_id": "text-embedding-3-small@OpenAI",
|
||||||
|
* }
|
||||||
|
* result := service.EnsureTenantModelIDForParams("tenant-123", params)
|
||||||
|
* // result will contain:
|
||||||
|
* // {
|
||||||
|
* // "llm_id": "gpt-4@OpenAI",
|
||||||
|
* // "embd_id": "text-embedding-3-small@OpenAI",
|
||||||
|
* // "tenant_llm_id": 123, // ID from tenant_llm table
|
||||||
|
* // "tenant_embd_id": 456, // ID from tenant_llm table
|
||||||
|
* // }
|
||||||
|
*/
|
||||||
|
func (s *TenantLLMService) EnsureTenantModelIDForParams(tenantID string, params map[string]interface{}) map[string]interface{} {
|
||||||
|
paramKeys := []string{"llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"}
|
||||||
|
|
||||||
|
for _, key := range paramKeys {
|
||||||
|
tenantKey := "tenant_" + key
|
||||||
|
|
||||||
|
if value, exists := params[key]; exists && value != nil && value != "" {
|
||||||
|
if _, tenantExists := params[tenantKey]; !tenantExists {
|
||||||
|
modelName, ok := value.(string)
|
||||||
|
if !ok || modelName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tenantLLM, err := s.GetAPIKey(tenantID, modelName)
|
||||||
|
if err == nil && tenantLLM != nil {
|
||||||
|
params[tenantKey] = tenantLLM.ID
|
||||||
|
} else {
|
||||||
|
params[tenantKey] = int64(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
// GetTenantList get tenant list for a user
|
// GetTenantList get tenant list for a user
|
||||||
func (s *TenantService) GetTenantList(userID string) ([]*TenantListItem, error) {
|
func (s *TenantService) GetTenantList(userID string) ([]*TenantListItem, error) {
|
||||||
tenants, err := s.userTenantDAO.GetTenantsByUserID(userID)
|
tenants, err := s.userTenantDAO.GetTenantsByUserID(userID)
|
||||||
|
|||||||
@ -924,6 +924,95 @@ func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserTenantService user tenant service
|
||||||
|
// Provides business logic for user-tenant relationship management
|
||||||
|
type UserTenantService struct {
|
||||||
|
userTenantDAO *dao.UserTenantDAO
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserTenantService creates a new UserTenantService instance
|
||||||
|
/**
|
||||||
|
* Returns:
|
||||||
|
* - *UserTenantService: a new UserTenantService instance
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
*
|
||||||
|
* service := NewUserTenantService()
|
||||||
|
* relations, err := service.GetUserTenantRelationByUserID("user123")
|
||||||
|
*/
|
||||||
|
func NewUserTenantService() *UserTenantService {
|
||||||
|
return &UserTenantService{
|
||||||
|
userTenantDAO: dao.NewUserTenantDAO(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserTenantRelation represents a user-tenant relationship response
|
||||||
|
// This structure matches the Python implementation's return format
|
||||||
|
type UserTenantRelation struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
TenantID string `json:"tenant_id"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserTenantRelationByUserID retrieves all user-tenant relationships for a given user ID
|
||||||
|
/**
|
||||||
|
* This method returns a list of user-tenant relationships with selected fields:
|
||||||
|
* - id: the relationship ID
|
||||||
|
* - user_id: the user ID
|
||||||
|
* - tenant_id: the tenant ID
|
||||||
|
* - role: the user's role in the tenant
|
||||||
|
*
|
||||||
|
* Parameters:
|
||||||
|
* - userID: the unique identifier of the user
|
||||||
|
*
|
||||||
|
* Returns:
|
||||||
|
* - []*UserTenantRelation: list of user-tenant relationships
|
||||||
|
* - error: error if the operation fails, nil otherwise
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
*
|
||||||
|
* service := NewUserTenantService()
|
||||||
|
* relations, err := service.GetUserTenantRelationByUserID("user123")
|
||||||
|
* if err != nil {
|
||||||
|
* log.Printf("Failed to get user tenant relations: %v", err)
|
||||||
|
* return
|
||||||
|
* }
|
||||||
|
* for _, rel := range relations {
|
||||||
|
* fmt.Printf("User %s has role %s in tenant %s\n", rel.UserID, rel.Role, rel.TenantID)
|
||||||
|
* }
|
||||||
|
*/
|
||||||
|
func (s *UserTenantService) GetUserTenantRelationByUserID(userID string) ([]*UserTenantRelation, error) {
|
||||||
|
relations, err := s.userTenantDAO.GetByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]*UserTenantRelation, len(relations))
|
||||||
|
for i, rel := range relations {
|
||||||
|
result[i] = convertToUserTenantRelation(rel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToUserTenantRelation converts model.UserTenant to UserTenantRelation
|
||||||
|
/**
|
||||||
|
* Parameters:
|
||||||
|
* - userTenant: the model.UserTenant to convert
|
||||||
|
*
|
||||||
|
* Returns:
|
||||||
|
* - *UserTenantRelation: the converted UserTenantRelation
|
||||||
|
*/
|
||||||
|
func convertToUserTenantRelation(userTenant *model.UserTenant) *UserTenantRelation {
|
||||||
|
return &UserTenantRelation{
|
||||||
|
ID: userTenant.ID,
|
||||||
|
UserID: userTenant.UserID,
|
||||||
|
TenantID: userTenant.TenantID,
|
||||||
|
Role: userTenant.Role,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetUserByAPIToken gets user by access key from Authorization header
|
// GetUserByAPIToken gets user by access key from Authorization header
|
||||||
// This is used for API token authentication
|
// This is used for API token authentication
|
||||||
// The authorization parameter should be in format: "Bearer <token>" or just "<token>"
|
// The authorization parameter should be in format: "Bearer <token>" or just "<token>"
|
||||||
@ -963,4 +1052,5 @@ func (s *UserService) GetUserByAPIToken(authorization string) (*model.User, comm
|
|||||||
}
|
}
|
||||||
|
|
||||||
return user, common.CodeSuccess, nil
|
return user, common.CodeSuccess, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1 +1,2 @@
|
|||||||
VITE_BASE_URL='/'
|
VITE_BASE_URL='/'
|
||||||
|
API_PROXY_SCHEME='python'
|
||||||
@ -49,66 +49,18 @@ export default defineConfig(({ mode }) => {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
hybrid: {
|
hybrid: {
|
||||||
'/v1/system/token_list': {
|
'^(/api/v1/memories)|^(/v1/user/info)|^(/v1/user/tenant_info)|^(/v1/tenant/list)|^(/v1/system/config)|^(/v1/user/login)|^(/v1/user/logout)':
|
||||||
target: 'http://127.0.0.1:9384/',
|
{
|
||||||
changeOrigin: true,
|
target: 'http://127.0.0.1:9384/',
|
||||||
ws: true,
|
changeOrigin: true,
|
||||||
},
|
ws: true,
|
||||||
'/v1/system/new_token': {
|
},
|
||||||
target: 'http://127.0.0.1:9384/',
|
'^(/api/v1/admin/sandbox)|^(/api/v1/admin/roles)|^(/api/v1/admin/roles/owner/permission)|^(/api/v1/admin/roles_with_permission)|^(/api/v1/admin/whitelist)|^(/api/v1/admin/variables)':
|
||||||
changeOrigin: true,
|
{
|
||||||
ws: true,
|
target: 'http://127.0.0.1:9381/',
|
||||||
},
|
changeOrigin: true,
|
||||||
'/v1/system/token': {
|
ws: true,
|
||||||
target: 'http://127.0.0.1:9384/',
|
},
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/v1/system/config': {
|
|
||||||
target: 'http://127.0.0.1:9384/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/v1/user/login': {
|
|
||||||
target: 'http://127.0.0.1:9384/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/v1/user/logout': {
|
|
||||||
target: 'http://127.0.0.1:9384/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/api/v1/admin/sandbox': {
|
|
||||||
target: 'http://127.0.0.1:9381/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/api/v1/admin/roles': {
|
|
||||||
target: 'http://127.0.0.1:9381/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/api/v1/admin/roles/owner/permission': {
|
|
||||||
target: 'http://127.0.0.1:9381/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/api/v1/admin/roles_with_permission': {
|
|
||||||
target: 'http://127.0.0.1:9381/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/api/v1/admin/whitelist': {
|
|
||||||
target: 'http://127.0.0.1:9381/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/api/v1/admin/variables': {
|
|
||||||
target: 'http://127.0.0.1:9381/',
|
|
||||||
changeOrigin: true,
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
'/api/v1/admin': {
|
'/api/v1/admin': {
|
||||||
target: 'http://127.0.0.1:9383/',
|
target: 'http://127.0.0.1:9383/',
|
||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
|
|||||||
Reference in New Issue
Block a user