mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-03-30 18:59:59 +08:00
Feat: add memory function by go (#13754)
### What problem does this PR solve? Feat: Add Memory function by go ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
This commit is contained in:
@ -175,6 +175,7 @@ func startServer(config *server.Config) {
|
||||
connectorService := service.NewConnectorService()
|
||||
searchService := service.NewSearchService()
|
||||
fileService := service.NewFileService()
|
||||
memoryService := service.NewMemoryService()
|
||||
|
||||
// Initialize handler layer
|
||||
authHandler := handler.NewAuthHandler()
|
||||
@ -191,9 +192,10 @@ func startServer(config *server.Config) {
|
||||
connectorHandler := handler.NewConnectorHandler(connectorService, userService)
|
||||
searchHandler := handler.NewSearchHandler(searchService, userService)
|
||||
fileHandler := handler.NewFileHandler(fileService, userService)
|
||||
memoryHandler := handler.NewMemoryHandler(memoryService)
|
||||
|
||||
// Initialize router
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler)
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler)
|
||||
|
||||
// Create Gin engine
|
||||
ginEngine := gin.New()
|
||||
|
||||
370
internal/dao/memory.go
Normal file
370
internal/dao/memory.go
Normal file
@ -0,0 +1,370 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package dao implements the data access layer
|
||||
// This file implements Memory-related database operations
|
||||
// Consistent with Python memory_service.py
|
||||
package dao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
// Memory type bit flag constants, consistent with Python MemoryType enum
|
||||
const (
|
||||
MemoryTypeRaw = 0b0001 // Raw memory (binary: 0001)
|
||||
MemoryTypeSemantic = 0b0010 // Semantic memory (binary: 0010)
|
||||
MemoryTypeEpisodic = 0b0100 // Episodic memory (binary: 0100)
|
||||
MemoryTypeProcedural = 0b1000 // Procedural memory (binary: 1000)
|
||||
)
|
||||
|
||||
// MemoryTypeMap maps memory type names to bit flags
|
||||
// Exported for use by service package
|
||||
var MemoryTypeMap = map[string]int{
|
||||
"raw": MemoryTypeRaw,
|
||||
"semantic": MemoryTypeSemantic,
|
||||
"episodic": MemoryTypeEpisodic,
|
||||
"procedural": MemoryTypeProcedural,
|
||||
}
|
||||
|
||||
// CalculateMemoryType converts memory type names array to bit flags integer
|
||||
//
|
||||
// Parameters:
|
||||
// - memoryTypeNames: Memory type names array
|
||||
//
|
||||
// Returns:
|
||||
// - int64: Bit flags integer
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// CalculateMemoryType([]string{"raw", "semantic"}) returns 3 (0b0011)
|
||||
func CalculateMemoryType(memoryTypeNames []string) int64 {
|
||||
memoryType := 0
|
||||
for _, name := range memoryTypeNames {
|
||||
lowerName := strings.ToLower(name)
|
||||
if mt, ok := MemoryTypeMap[lowerName]; ok {
|
||||
memoryType |= mt
|
||||
}
|
||||
}
|
||||
return int64(memoryType)
|
||||
}
|
||||
|
||||
// GetMemoryTypeHuman converts memory type bit flags to human-readable names
|
||||
//
|
||||
// Parameters:
|
||||
// - memoryType: Bit flags integer representing memory types
|
||||
//
|
||||
// Returns:
|
||||
// - []string: Array of human-readable memory type names
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// GetMemoryTypeHuman(3) returns ["raw", "semantic"]
|
||||
func GetMemoryTypeHuman(memoryType int64) []string {
|
||||
var result []string
|
||||
if memoryType&int64(MemoryTypeRaw) != 0 {
|
||||
result = append(result, "raw")
|
||||
}
|
||||
if memoryType&int64(MemoryTypeSemantic) != 0 {
|
||||
result = append(result, "semantic")
|
||||
}
|
||||
if memoryType&int64(MemoryTypeEpisodic) != 0 {
|
||||
result = append(result, "episodic")
|
||||
}
|
||||
if memoryType&int64(MemoryTypeProcedural) != 0 {
|
||||
result = append(result, "procedural")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MemoryDAO handles all Memory-related database operations
|
||||
type MemoryDAO struct{}
|
||||
|
||||
// NewMemoryDAO creates a new MemoryDAO instance
|
||||
//
|
||||
// Returns:
|
||||
// - *MemoryDAO: Initialized DAO instance
|
||||
func NewMemoryDAO() *MemoryDAO {
|
||||
return &MemoryDAO{}
|
||||
}
|
||||
|
||||
// Create inserts a new memory record into the database
|
||||
//
|
||||
// Parameters:
|
||||
// - memory: Memory model pointer
|
||||
//
|
||||
// Returns:
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) Create(memory *model.Memory) error {
|
||||
return DB.Create(memory).Error
|
||||
}
|
||||
|
||||
// GetByID retrieves a memory record by ID from database
|
||||
//
|
||||
// Parameters:
|
||||
// - id: Memory ID
|
||||
//
|
||||
// Returns:
|
||||
// - *model.Memory: Memory model pointer
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByID(id string) (*model.Memory, error) {
|
||||
var memory model.Memory
|
||||
err := DB.Where("id = ?", id).First(&memory).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &memory, nil
|
||||
}
|
||||
|
||||
// GetByTenantID retrieves all memories for a tenant
|
||||
//
|
||||
// Parameters:
|
||||
// - tenantID: Tenant ID
|
||||
//
|
||||
// Returns:
|
||||
// - []*model.Memory: Memory model pointer array
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*model.Memory, error) {
|
||||
var memories []*model.Memory
|
||||
err := DB.Where("tenant_id = ?", tenantID).Find(&memories).Error
|
||||
return memories, err
|
||||
}
|
||||
|
||||
// GetByNameAndTenant checks if memory exists by name and tenant ID
|
||||
// Used for duplicate name deduplication
|
||||
//
|
||||
// Parameters:
|
||||
// - name: Memory name
|
||||
// - tenantID: Tenant ID
|
||||
//
|
||||
// Returns:
|
||||
// - []*model.Memory: Matching memory list (for existence check)
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*model.Memory, error) {
|
||||
var memories []*model.Memory
|
||||
err := DB.Where("name = ? AND tenant_id = ?", name, tenantID).Find(&memories).Error
|
||||
return memories, err
|
||||
}
|
||||
|
||||
// GetByIDs retrieves memories by multiple IDs
|
||||
//
|
||||
// Parameters:
|
||||
// - ids: Memory ID list
|
||||
//
|
||||
// Returns:
|
||||
// - []*model.Memory: Memory model pointer array
|
||||
// - error: Database operation error
|
||||
func (dao *MemoryDAO) GetByIDs(ids []string) ([]*model.Memory, error) {
|
||||
var memories []*model.Memory
|
||||
err := DB.Where("id IN ?", ids).Find(&memories).Error
|
||||
return memories, err
|
||||
}
|
||||
|
||||
// UpdateByID updates a memory by ID
|
||||
// Supports partial updates - only updates passed fields
|
||||
// Automatically handles field type conversions
|
||||
//
|
||||
// Parameters:
|
||||
// - id: Memory ID
|
||||
// - updates: Fields to update map
|
||||
//
|
||||
// Returns:
|
||||
// - error: Database operation error
|
||||
//
|
||||
// Field type handling:
|
||||
// - memory_type: []string converts to bit flags integer
|
||||
// - temperature: string converts to float64
|
||||
// - name: Uses string value directly
|
||||
// - permissions, forgetting_policy: Uses string value directly
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// updates := map[string]interface{}{"name": "NewName", "memory_type": []string{"semantic"}}
|
||||
// err := dao.UpdateByID("memory123", updates)
|
||||
func (dao *MemoryDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
if updates == nil || len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for key, value := range updates {
|
||||
switch key {
|
||||
case "memory_type":
|
||||
if types, ok := value.([]string); ok {
|
||||
updates[key] = CalculateMemoryType(types)
|
||||
}
|
||||
case "temperature":
|
||||
if tempStr, ok := value.(string); ok {
|
||||
var temp float64
|
||||
fmt.Sscanf(tempStr, "%f", &temp)
|
||||
updates[key] = temp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return DB.Model(&model.Memory{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteByID deletes a memory by ID
|
||||
//
|
||||
// Parameters:
|
||||
// - id: Memory ID
|
||||
//
|
||||
// Returns:
|
||||
// - error: Database operation error
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := dao.DeleteByID("memory123")
|
||||
func (dao *MemoryDAO) DeleteByID(id string) error {
|
||||
return DB.Where("id = ?", id).Delete(&model.Memory{}).Error
|
||||
}
|
||||
|
||||
// GetWithOwnerNameByID retrieves a memory with owner name by ID
|
||||
// Joins with User table to get owner's nickname
|
||||
//
|
||||
// Parameters:
|
||||
// - id: Memory ID
|
||||
//
|
||||
// Returns:
|
||||
// - *model.MemoryListItem: Memory detail with owner name populated
|
||||
// - error: Database operation error
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// memory, err := dao.GetWithOwnerNameByID("memory123")
|
||||
func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, error) {
|
||||
querySQL := `
|
||||
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
|
||||
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
|
||||
m.permissions, m.description, m.memory_size, m.forgetting_policy,
|
||||
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
|
||||
m.update_time, m.update_date,
|
||||
u.nickname as owner_name
|
||||
FROM memory m
|
||||
LEFT JOIN user u ON m.tenant_id = u.id
|
||||
WHERE m.id = ?
|
||||
`
|
||||
|
||||
var rawResult struct {
|
||||
model.Memory
|
||||
OwnerName *string `gorm:"column:owner_name"`
|
||||
}
|
||||
|
||||
if err := DB.Raw(querySQL, id).Scan(&rawResult).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.MemoryListItem{
|
||||
Memory: rawResult.Memory,
|
||||
OwnerName: rawResult.OwnerName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetByFilter retrieves memories with optional filters
|
||||
// Supports filtering by tenant_id, memory_type, storage_type, and keywords
|
||||
// Returns paginated results with owner_name from user table JOIN
|
||||
//
|
||||
// Parameters:
|
||||
// - tenantIDs: Array of tenant IDs to filter by (empty means all tenants)
|
||||
// - memoryTypes: Array of memory type names to filter by (empty means all types)
|
||||
// - storageType: Storage type to filter by (empty means all types)
|
||||
// - keywords: Keywords to search in memory names (empty means no keyword filter)
|
||||
// - page: Page number (1-based)
|
||||
// - pageSize: Number of items per page
|
||||
//
|
||||
// Returns:
|
||||
// - []*model.MemoryListItem: Memory list items with owner name populated
|
||||
// - int64: Total count of matching memories
|
||||
// - error: Database operation error
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// memories, total, err := dao.GetByFilter([]string{"tenant1"}, []string{"semantic"}, "table", "test", 1, 10)
|
||||
func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*model.MemoryListItem, int64, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
if len(tenantIDs) > 0 {
|
||||
conditions = append(conditions, "m.tenant_id IN ?")
|
||||
args = append(args, tenantIDs)
|
||||
}
|
||||
|
||||
if len(memoryTypes) > 0 {
|
||||
memoryTypeInt := CalculateMemoryType(memoryTypes)
|
||||
conditions = append(conditions, "m.memory_type & ? > 0")
|
||||
args = append(args, memoryTypeInt)
|
||||
}
|
||||
|
||||
if storageType != "" {
|
||||
conditions = append(conditions, "m.storage_type = ?")
|
||||
args = append(args, storageType)
|
||||
}
|
||||
|
||||
if keywords != "" {
|
||||
conditions = append(conditions, "m.name LIKE ?")
|
||||
args = append(args, "%"+keywords+"%")
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
countSQL := fmt.Sprintf("SELECT COUNT(*) FROM memory m %s", whereClause)
|
||||
var total int64
|
||||
if err := DB.Raw(countSQL, args...).Scan(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
querySQL := fmt.Sprintf(`
|
||||
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
|
||||
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
|
||||
m.permissions, m.description, m.memory_size, m.forgetting_policy,
|
||||
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
|
||||
m.update_time, m.update_date,
|
||||
u.nickname as owner_name
|
||||
FROM memory m
|
||||
LEFT JOIN user u ON m.tenant_id = u.id
|
||||
%s
|
||||
ORDER BY m.update_time DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`, whereClause)
|
||||
|
||||
queryArgs := append(args, pageSize, offset)
|
||||
|
||||
var rawResults []struct {
|
||||
model.Memory
|
||||
OwnerName *string `gorm:"column:owner_name"`
|
||||
}
|
||||
|
||||
if err := DB.Raw(querySQL, queryArgs...).Scan(&rawResults).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
memories := make([]*model.MemoryListItem, len(rawResults))
|
||||
for i, r := range rawResults {
|
||||
memories[i] = &model.MemoryListItem{
|
||||
Memory: r.Memory,
|
||||
OwnerName: r.OwnerName,
|
||||
}
|
||||
}
|
||||
|
||||
return memories, total, nil
|
||||
}
|
||||
@ -141,3 +141,130 @@ func (dao *TenantLLMDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&model.TenantLLM{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// splitModelNameAndFactory splits model name and factory from combined format
|
||||
// This matches Python's split_model_name_and_factory logic
|
||||
//
|
||||
// Parameters:
|
||||
// - modelName: The model name which can be in format "ModelName" or "ModelName@Factory"
|
||||
//
|
||||
// Returns:
|
||||
// - string: The model name without factory prefix
|
||||
// - string: The factory name (empty string if not specified)
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// modelName, factory := splitModelNameAndFactory("gpt-4")
|
||||
// // Returns: "gpt-4", ""
|
||||
//
|
||||
// modelName, factory := splitModelNameAndFactory("gpt-4@OpenAI")
|
||||
// // Returns: "gpt-4", "OpenAI"
|
||||
func splitModelNameAndFactory(modelName string) (string, string) {
|
||||
// Split by "@" separator
|
||||
// Handle cases like "model@factory" or "model@sub@factory"
|
||||
lastAtIndex := -1
|
||||
for i := len(modelName) - 1; i >= 0; i-- {
|
||||
if modelName[i] == '@' {
|
||||
lastAtIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// No "@" found, return original name
|
||||
if lastAtIndex == -1 {
|
||||
return modelName, ""
|
||||
}
|
||||
|
||||
// Split into model name and potential factory
|
||||
modelNamePart := modelName[:lastAtIndex]
|
||||
factory := modelName[lastAtIndex+1:]
|
||||
|
||||
// Validate if factory exists in llm_factories table
|
||||
// This matches Python's logic of checking against model providers
|
||||
var factoryCount int64
|
||||
DB.Model(&model.LLMFactories{}).Where("name = ?", factory).Count(&factoryCount)
|
||||
|
||||
// If factory doesn't exist in database, treat the whole string as model name
|
||||
if factoryCount == 0 {
|
||||
return modelName, ""
|
||||
}
|
||||
|
||||
return modelNamePart, factory
|
||||
}
|
||||
|
||||
// GetByTenantIDAndLLMName gets tenant LLM by tenant ID and LLM name
|
||||
// This is used to resolve tenant_llm_id from llm_id
|
||||
// It supports both simple model names and factory-prefixed names (e.g., "gpt-4@OpenAI")
|
||||
//
|
||||
// Parameters:
|
||||
// - tenantID: The tenant identifier
|
||||
// - llmName: The LLM model name (can include factory prefix like "OpenAI@gpt-4")
|
||||
//
|
||||
// Returns:
|
||||
// - *model.TenantLLM: The tenant LLM record
|
||||
// - error: Error if not found
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Simple model name
|
||||
// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4")
|
||||
//
|
||||
// // Model name with factory prefix
|
||||
// tenantLLM, err := dao.GetByTenantIDAndLLMName("tenant123", "gpt-4@OpenAI")
|
||||
func (dao *TenantLLMDAO) GetByTenantIDAndLLMName(tenantID string, llmName string) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
|
||||
// Split model name and factory from the combined format
|
||||
modelName, factory := splitModelNameAndFactory(llmName)
|
||||
|
||||
// First attempt: try to find with model name only
|
||||
err := DB.Where("tenant_id = ? AND llm_name = ?", tenantID, modelName).First(&tenantLLM).Error
|
||||
if err == nil {
|
||||
return &tenantLLM, nil
|
||||
}
|
||||
|
||||
// Second attempt: if factory is specified, try with both model name and factory
|
||||
if factory != "" {
|
||||
err = DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, modelName, factory).First(&tenantLLM).Error
|
||||
if err == nil {
|
||||
return &tenantLLM, nil
|
||||
}
|
||||
|
||||
// Special handling for LocalAI and HuggingFace (matching Python logic)
|
||||
// These factories append "___FactoryName" to the model name
|
||||
if factory == "LocalAI" || factory == "HuggingFace" || factory == "OpenAI-API-Compatible" {
|
||||
specialModelName := modelName + "___" + factory
|
||||
err = DB.Where("tenant_id = ? AND llm_name = ?", tenantID, specialModelName).First(&tenantLLM).Error
|
||||
if err == nil {
|
||||
return &tenantLLM, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return the last error (record not found)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// GetByTenantIDLLMNameAndFactory gets tenant LLM by tenant ID, LLM name and factory
|
||||
// This is used when model name includes factory suffix (e.g., "model@factory")
|
||||
//
|
||||
// Parameters:
|
||||
// - tenantID: The tenant identifier
|
||||
// - llmName: The LLM model name
|
||||
// - factory: The LLM factory name
|
||||
//
|
||||
// Returns:
|
||||
// - *model.TenantLLM: The tenant LLM record
|
||||
// - error: Error if not found
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// tenantLLM, err := dao.GetByTenantIDLLMNameAndFactory("tenant123", "gpt-4", "OpenAI")
|
||||
func (dao *TenantLLMDAO) GetByTenantIDLLMNameAndFactory(tenantID, llmName, factory string) (*model.TenantLLM, error) {
|
||||
var tenantLLM model.TenantLLM
|
||||
err := DB.Where("tenant_id = ? AND llm_name = ? AND llm_factory = ?", tenantID, llmName, factory).First(&tenantLLM).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tenantLLM, nil
|
||||
}
|
||||
|
||||
687
internal/handler/memory.go
Normal file
687
internal/handler/memory.go
Normal file
@ -0,0 +1,687 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package handler contains all HTTP request handlers
|
||||
// This file implements Memory-related API endpoint handlers
|
||||
// Each method corresponds to an API endpoint in the Python memory_api.py
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
// MemoryHandler handles Memory-related HTTP requests
|
||||
// Responsible for processing all Memory-related HTTP requests
|
||||
// Each method corresponds to an API endpoint, implementing the same logic as Python memory_api.py
|
||||
type MemoryHandler struct {
|
||||
memoryService *service.MemoryService // Reference to Memory business service layer
|
||||
}
|
||||
|
||||
// NewMemoryHandler creates a new MemoryHandler instance
|
||||
//
|
||||
// Parameters:
|
||||
// - memoryService: Pointer to MemoryService business service layer
|
||||
//
|
||||
// Returns:
|
||||
// - *MemoryHandler: Initialized handler instance
|
||||
func NewMemoryHandler(memoryService *service.MemoryService) *MemoryHandler {
|
||||
return &MemoryHandler{
|
||||
memoryService: memoryService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateMemory handles POST request for creating Memory
|
||||
// API Path: POST /api/v1/memories
|
||||
//
|
||||
// Function:
|
||||
// - Creates a new memory record
|
||||
// - Supports automatic system_prompt generation
|
||||
// - Supports name deduplication (if name exists, adds sequence number)
|
||||
//
|
||||
// Request Parameters (JSON Body):
|
||||
// - name (required): Memory name, max 128 characters
|
||||
// - memory_type (required): Memory type array, supports ["raw", "semantic", "episodic", "procedural"]
|
||||
// - embd_id (required): Embedding model ID
|
||||
// - llm_id (required): LLM model ID
|
||||
// - tenant_embd_id (optional): Tenant embedding model ID
|
||||
// - tenant_llm_id (optional): Tenant LLM model ID
|
||||
//
|
||||
// Response Format:
|
||||
// - code: Status code (0=success, other=error)
|
||||
// - message: true on success, error message on failure
|
||||
// - data: Memory object on success
|
||||
//
|
||||
// Business Logic (matching Python create_memory):
|
||||
// 1. Validate user login status
|
||||
// 2. Parse and validate request parameters
|
||||
// 3. Call service layer to create memory
|
||||
// 4. Return creation result
|
||||
func (h *MemoryHandler) CreateMemory(c *gin.Context) {
|
||||
// Check if API timing is enabled
|
||||
// If RAGFLOW_API_TIMING environment variable is set, request processing time will be logged
|
||||
timingEnabled := os.Getenv("RAGFLOW_API_TIMING")
|
||||
var tStart time.Time
|
||||
if timingEnabled != "" {
|
||||
tStart = time.Now()
|
||||
}
|
||||
|
||||
// Get current logged-in user information
|
||||
// GetUser is a context value set by the authentication middleware
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse JSON request body
|
||||
var req service.CreateMemoryRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required field: name
|
||||
if req.Name == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "name is required",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required field: memory_type (must be non-empty array)
|
||||
if len(req.MemoryType) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "memory_type is required and must be a list",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required field: embd_id
|
||||
if req.EmbdID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "embd_id is required",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required field: llm_id
|
||||
if req.LLMID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "llm_id is required",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Record request parsing completion time (for timing)
|
||||
tParsed := time.Now()
|
||||
|
||||
// Call service layer to create memory
|
||||
result, err := h.memoryService.CreateMemory(userID, &req)
|
||||
if err != nil {
|
||||
// Log error if timing is enabled
|
||||
if timingEnabled != "" {
|
||||
totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0
|
||||
parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0
|
||||
_ = parseMs
|
||||
_ = totalMs
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
// Determine if it's an argument error and return appropriate error code
|
||||
if isArgumentError(errMsg) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Other errors return server error
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Log success if timing is enabled
|
||||
if timingEnabled != "" {
|
||||
totalMs := float64(time.Since(tStart).Microseconds()) / 1000.0
|
||||
parseMs := float64(tParsed.Sub(tStart).Microseconds()) / 1000.0
|
||||
validateAndDbMs := totalMs - parseMs
|
||||
_ = parseMs
|
||||
_ = validateAndDbMs
|
||||
_ = totalMs
|
||||
}
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateMemory handles PUT request for updating Memory
|
||||
// API Path: PUT /api/v1/memories/:memory_id
|
||||
//
|
||||
// Function:
|
||||
// - Updates configuration information for the specified memory
|
||||
// - Supports partial updates: only update passed fields
|
||||
//
|
||||
// Request Parameters (JSON Body):
|
||||
// - name (optional): Memory name
|
||||
// - permissions (optional): Permission setting ["me", "team", "all"]
|
||||
// - llm_id (optional): LLM model ID
|
||||
// - embd_id (optional): Embedding model ID
|
||||
// - tenant_llm_id (optional): Tenant LLM model ID
|
||||
// - tenant_embd_id (optional): Tenant embedding model ID
|
||||
// - memory_type (optional): Memory type array
|
||||
// - memory_size (optional): Memory size, range (0, 5242880]
|
||||
// - forgetting_policy (optional): Forgetting policy, default "FIFO"
|
||||
// - temperature (optional): Temperature parameter, range [0, 1]
|
||||
// - avatar (optional): Avatar URL
|
||||
// - description (optional): Description
|
||||
// - system_prompt (optional): System prompt
|
||||
// - user_prompt (optional): User prompt
|
||||
//
|
||||
// Business Rules:
|
||||
// - name length <= 128 characters
|
||||
// - Cannot update tenant_embd_id, embd_id, memory_type when memory_size > 0
|
||||
// - When updating memory_type, system_prompt is automatically regenerated if it's the default
|
||||
func (h *MemoryHandler) UpdateMemory(c *gin.Context) {
|
||||
// Get memory_id from URL path
|
||||
memoryID := c.Param("memory_id")
|
||||
if memoryID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "memory_id is required",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get current logged-in user information
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse JSON request body
|
||||
var req service.UpdateMemoryRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service layer to update memory
|
||||
result, err := h.memoryService.UpdateMemory(userID, memoryID, &req)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
// Check if it's a "not found" error
|
||||
if strings.Contains(errMsg, "not found") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's an argument error
|
||||
if isArgumentError(errMsg) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Other errors return server error
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteMemory handles DELETE request for deleting Memory
|
||||
// API Path: DELETE /api/v1/memories/:memory_id
|
||||
//
|
||||
// Function:
|
||||
// - Deletes the specified memory record
|
||||
// - Also deletes associated message data
|
||||
//
|
||||
// Business Logic:
|
||||
// 1. Check if memory exists
|
||||
// 2. Delete memory record
|
||||
// 3. Delete associated message index
|
||||
func (h *MemoryHandler) DeleteMemory(c *gin.Context) {
|
||||
// Get memory_id from URL path
|
||||
memoryID := c.Param("memory_id")
|
||||
if memoryID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "memory_id is required",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service layer to delete memory
|
||||
err := h.memoryService.DeleteMemory(memoryID)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
// Check if it's a "not found" error
|
||||
if strings.Contains(errMsg, "not found") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Other errors return server error
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ListMemories handles GET request for listing Memories
|
||||
// API Path: GET /api/v1/memories
|
||||
//
|
||||
// Function:
|
||||
// - Lists memories accessible to the current user
|
||||
// - Supports multiple filter conditions
|
||||
// - Supports pagination and keyword search
|
||||
//
|
||||
// Query Parameters:
|
||||
// - memory_type (optional): Memory type filter, supports comma-separated multiple types
|
||||
// - tenant_id (optional): Tenant ID filter
|
||||
// - storage_type (optional): Storage type filter
|
||||
// - keywords (optional): Keyword search (fuzzy match on name)
|
||||
// - page (optional): Page number, default 1
|
||||
// - page_size (optional): Items per page, default 50
|
||||
//
|
||||
// Response Format:
|
||||
// - code: Status code
|
||||
// - message: true
|
||||
// - data.memory_list: Array of Memory objects
|
||||
// - data.total_count: Total record count
|
||||
func (h *MemoryHandler) ListMemories(c *gin.Context) {
|
||||
// Get current logged-in user information
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse query parameters
|
||||
memoryTypesParam := c.Query("memory_type")
|
||||
tenantIDsParam := c.Query("tenant_id")
|
||||
storageType := c.Query("storage_type")
|
||||
keywords := c.Query("keywords")
|
||||
pageStr := c.DefaultQuery("page", "1")
|
||||
pageSizeStr := c.DefaultQuery("page_size", "50")
|
||||
|
||||
// Convert pagination parameters to integers
|
||||
page, _ := strconv.Atoi(pageStr)
|
||||
pageSize, _ := strconv.Atoi(pageSizeStr)
|
||||
|
||||
// Validate pagination parameters
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 50
|
||||
}
|
||||
|
||||
// Parse memory_type parameter (supports comma separation)
|
||||
var memoryTypes []string
|
||||
if memoryTypesParam != "" {
|
||||
if strings.Contains(memoryTypesParam, ",") {
|
||||
memoryTypes = strings.Split(memoryTypesParam, ",")
|
||||
} else {
|
||||
memoryTypes = []string{memoryTypesParam}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tenant_id parameter
|
||||
// If not specified, service will get all tenants associated with the user
|
||||
var tenantIDs []string
|
||||
if tenantIDsParam != "" {
|
||||
if strings.Contains(tenantIDsParam, ",") {
|
||||
tenantIDs = strings.Split(tenantIDsParam, ",")
|
||||
} else {
|
||||
tenantIDs = []string{tenantIDsParam}
|
||||
}
|
||||
}
|
||||
|
||||
// Call service layer to get memory list
|
||||
result, err := h.memoryService.ListMemories(user.ID, tenantIDs, memoryTypes, storageType, keywords, page, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": err.Error(),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
// GetMemoryConfig handles GET request for getting Memory configuration
|
||||
// API Path: GET /api/v1/memories/:memory_id/config
|
||||
//
|
||||
// Function:
|
||||
// - Gets complete configuration information for the specified memory
|
||||
// - Includes owner name (obtained via JOIN with user table)
|
||||
//
|
||||
// Response Format:
|
||||
// - code: Status code
|
||||
// - message: true
|
||||
// - data: Memory object, including owner_name field
|
||||
func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) {
|
||||
// Get memory_id from URL path
|
||||
memoryID := c.Param("memory_id")
|
||||
if memoryID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "memory_id is required",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service layer to get memory configuration
|
||||
result, err := h.memoryService.GetMemoryConfig(memoryID)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
// Check if it's a "not found" error
|
||||
if strings.Contains(errMsg, "not found") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Other errors return server error
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": errMsg,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
// GetMemoryMessages handles GET request for getting Memory messages
|
||||
// API Path: GET /api/v1/memories/:memory_id
|
||||
//
|
||||
// Function:
|
||||
// - Gets message list associated with the specified memory
|
||||
// - Supports filtering by agent_id
|
||||
// - Supports keyword search and pagination
|
||||
//
|
||||
// Query Parameters:
|
||||
// - agent_id (optional): Agent ID filter, supports comma-separated multiple
|
||||
// - keywords (optional): Keyword search
|
||||
// - page (optional): Page number, default 1
|
||||
// - page_size (optional): Items per page, default 50
|
||||
//
|
||||
// Response Format:
|
||||
// - code: Status code
|
||||
// - message: true
|
||||
// - data.messages: Array of message objects
|
||||
// - data.storage_type: Storage type
|
||||
//
|
||||
// TODO: Implementation pending - depends on CanvasService and TaskService
|
||||
func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "GetMemoryMessages not implemented - pending CanvasService and TaskService dependencies",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// AddMessage handles POST request for adding messages
|
||||
// API Path: POST /api/v1/messages
|
||||
//
|
||||
// Function:
|
||||
// - Adds messages to one or more memories
|
||||
// - Messages will be embedded and saved to vector database
|
||||
// - Creates asynchronous task for processing
|
||||
//
|
||||
// Request Parameters (JSON Body):
|
||||
// - memory_id (required): Memory ID or ID array
|
||||
// - agent_id (required): Agent ID
|
||||
// - session_id (required): Session ID
|
||||
// - user_input (required): User input
|
||||
// - agent_response (required): Agent response
|
||||
// - user_id (optional): User ID
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) AddMessage(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "AddMessage not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ForgetMessage handles DELETE request for forgetting messages
|
||||
// API Path: DELETE /api/v1/messages/:memory_id/:message_id
|
||||
//
|
||||
// Function:
|
||||
// - Soft-deletes the specified message (sets forget_at timestamp)
|
||||
// - Message is not immediately deleted from database, but marked as "forgotten"
|
||||
//
|
||||
// Parameter Format:
|
||||
// - memory_id: Memory ID
|
||||
// - message_id: Message ID (integer)
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) ForgetMessage(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "ForgetMessage not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateMessage handles PUT request for updating message status
|
||||
// API Path: PUT /api/v1/messages/:memory_id/:message_id
|
||||
//
|
||||
// Function:
|
||||
// - Updates status of the specified message
|
||||
// - status is a boolean, converted to integer for storage (true=1, false=0)
|
||||
//
|
||||
// Request Parameters (JSON Body):
|
||||
// - status (required): Message status, boolean
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "UpdateMessage not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// SearchMessage handles GET request for searching messages
|
||||
// API Path: GET /api/v1/messages/search
|
||||
//
|
||||
// Function:
|
||||
// - Searches messages across multiple memories
|
||||
// - Supports vector similarity search and keyword search
|
||||
// - Fuses results from both search methods
|
||||
//
|
||||
// Query Parameters:
|
||||
// - memory_id (optional): Memory ID list, supports comma separation
|
||||
// - query (optional): Search query text
|
||||
// - similarity_threshold (optional): Similarity threshold, default 0.2
|
||||
// - keywords_similarity_weight (optional): Keyword weight, default 0.7
|
||||
// - top_n (optional): Number of results to return, default 5
|
||||
// - agent_id (optional): Agent ID filter
|
||||
// - session_id (optional): Session ID filter
|
||||
// - user_id (optional): User ID filter
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) SearchMessage(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "SearchMessage not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// GetMessages handles GET request for getting message list
|
||||
// API Path: GET /api/v1/messages
|
||||
//
|
||||
// Function:
|
||||
// - Gets recent messages from specified memories
|
||||
// - Supports filtering by agent_id and session_id
|
||||
//
|
||||
// Query Parameters:
|
||||
// - memory_id (required): Memory ID list, supports comma separation
|
||||
// - agent_id (optional): Agent ID filter
|
||||
// - session_id (optional): Session ID filter
|
||||
// - limit (optional): Number of results to return, default 10
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) GetMessages(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "GetMessages not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// GetMessageContent handles GET request for getting message content
|
||||
// API Path: GET /api/v1/messages/:memory_id/:message_id/content
|
||||
//
|
||||
// Function:
|
||||
// - Gets complete content of the specified message
|
||||
// - doc_id format: memory_id + "_" + message_id
|
||||
//
|
||||
// Parameter Format:
|
||||
// - memory_id: Memory ID
|
||||
// - message_id: Message ID (integer)
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) GetMessageContent(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "GetMessageContent not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// isArgumentError determines if an error message is an argument error
|
||||
//
|
||||
// Function:
|
||||
// - Checks if the error message contains any argument validation-related prefixes
|
||||
// - Used to distinguish argument errors from server errors
|
||||
//
|
||||
// Parameters:
|
||||
// - msg: Error message string
|
||||
//
|
||||
// Returns:
|
||||
// - bool: true if it's an argument error, false otherwise
|
||||
func isArgumentError(msg string) bool {
|
||||
// Define list of argument error prefixes
|
||||
// Matches Python ArgumentException error messages
|
||||
argumentErrorPrefixes := []string{
|
||||
"memory name cannot be empty", // Memory name cannot be empty
|
||||
"memory name exceeds limit", // Memory name exceeds limit
|
||||
"memory type must be a list", // memory_type must be a list
|
||||
"memory type is not supported", // Unsupported memory_type
|
||||
}
|
||||
// Check if error message starts with any prefix
|
||||
for _, prefix := range argumentErrorPrefixes {
|
||||
if len(msg) >= len(prefix) && msg[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -42,3 +42,12 @@ type Memory struct {
|
||||
func (Memory) TableName() string {
|
||||
return "memory"
|
||||
}
|
||||
|
||||
// MemoryListItem represents a memory record with owner name from JOIN query.
|
||||
// Uses struct embedding to extend Memory struct with owner_name from user table JOIN.
|
||||
// Note: MemoryType is kept as int64 from Memory embedding; conversion to []string
|
||||
// happens in the Service layer via CreateMemoryResponse.
|
||||
type MemoryListItem struct {
|
||||
Memory
|
||||
OwnerName *string `json:"owner_name,omitempty"`
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ type Router struct {
|
||||
connectorHandler *handler.ConnectorHandler
|
||||
searchHandler *handler.SearchHandler
|
||||
fileHandler *handler.FileHandler
|
||||
memoryHandler *handler.MemoryHandler
|
||||
}
|
||||
|
||||
// NewRouter create router
|
||||
@ -56,6 +57,7 @@ func NewRouter(
|
||||
connectorHandler *handler.ConnectorHandler,
|
||||
searchHandler *handler.SearchHandler,
|
||||
fileHandler *handler.FileHandler,
|
||||
memoryHandler *handler.MemoryHandler,
|
||||
) *Router {
|
||||
return &Router{
|
||||
authHandler: authHandler,
|
||||
@ -72,6 +74,7 @@ func NewRouter(
|
||||
connectorHandler: connectorHandler,
|
||||
searchHandler: searchHandler,
|
||||
fileHandler: fileHandler,
|
||||
memoryHandler: memoryHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,6 +166,28 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
{
|
||||
authors.GET("/:author_id/documents", r.documentHandler.GetDocumentsByAuthorID)
|
||||
}
|
||||
|
||||
// Memory routes
|
||||
memory := v1.Group("/memories")
|
||||
{
|
||||
memory.POST("", r.memoryHandler.CreateMemory)
|
||||
memory.PUT("/:memory_id", r.memoryHandler.UpdateMemory)
|
||||
memory.DELETE("/:memory_id", r.memoryHandler.DeleteMemory)
|
||||
memory.GET("", r.memoryHandler.ListMemories)
|
||||
memory.GET("/:memory_id/config", r.memoryHandler.GetMemoryConfig)
|
||||
memory.GET("/:memory_id", r.memoryHandler.GetMemoryMessages)
|
||||
}
|
||||
|
||||
// TODO: Message routes - Implementation pending - depends on CanvasService, TaskService and embedding engine
|
||||
// message := v1.Group("/messages")
|
||||
// {
|
||||
// message.POST("", r.memoryHandler.AddMessage)
|
||||
// message.DELETE("/:memory_id/:message_id", r.memoryHandler.ForgetMessage)
|
||||
// message.PUT("/:memory_id/:message_id", r.memoryHandler.UpdateMessage)
|
||||
// message.GET("/search", r.memoryHandler.SearchMessage)
|
||||
// message.GET("", r.memoryHandler.GetMessages)
|
||||
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
|
||||
// }
|
||||
}
|
||||
|
||||
// Knowledge base routes
|
||||
@ -260,6 +285,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
file.GET("/parent_folder", r.fileHandler.GetParentFolder)
|
||||
file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Handle undefined routes
|
||||
|
||||
892
internal/service/memory.go
Normal file
892
internal/service/memory.go
Normal file
@ -0,0 +1,892 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// MemoryNameLimit is the maximum length allowed for memory names
|
||||
MemoryNameLimit = 128
|
||||
// MemorySizeLimit is the maximum memory size in bytes (5MB)
|
||||
MemorySizeLimit = 5242880
|
||||
)
|
||||
|
||||
// Note: MemoryType, MemoryTypeRaw, MemoryTypeSemantic, MemoryTypeEpisodic,
|
||||
// MemoryTypeProcedural, and CalculateMemoryType are defined in the dao package
|
||||
// and imported as dao.MemoryType, dao.MemoryTypeRaw, etc.
|
||||
|
||||
// TenantPermission defines the access permission levels for memory resources
|
||||
// Note: This type is specific to the service layer
|
||||
type TenantPermission string
|
||||
|
||||
const (
|
||||
// TenantPermissionMe restricts access to the owner only
|
||||
TenantPermissionMe TenantPermission = "me"
|
||||
// TenantPermissionTeam allows access within the same team
|
||||
TenantPermissionTeam TenantPermission = "team"
|
||||
// TenantPermissionAll allows access to all tenants
|
||||
TenantPermissionAll TenantPermission = "all"
|
||||
)
|
||||
|
||||
// validPermissions defines which permission values are valid
|
||||
var validPermissions = map[TenantPermission]bool{
|
||||
TenantPermissionMe: true,
|
||||
TenantPermissionTeam: true,
|
||||
TenantPermissionAll: true,
|
||||
}
|
||||
|
||||
// ForgettingPolicy defines the strategy for forgetting old memory entries
|
||||
type ForgettingPolicy string
|
||||
|
||||
const (
|
||||
// ForgettingPolicyFIFO uses First-In-First-Out strategy for forgetting
|
||||
ForgettingPolicyFIFO ForgettingPolicy = "FIFO"
|
||||
)
|
||||
|
||||
// validForgettingPolicies defines which forgetting policies are valid
|
||||
var validForgettingPolicies = map[ForgettingPolicy]bool{
|
||||
ForgettingPolicyFIFO: true,
|
||||
}
|
||||
|
||||
//
|
||||
// Note: CalculateMemoryType and GetMemoryTypeHuman functions have been moved to dao package
|
||||
// Use dao.CalculateMemoryType() and dao.GetMemoryTypeHuman() instead
|
||||
|
||||
// PromptAssembler handles the assembly of system prompts for memory extraction
|
||||
type PromptAssembler struct{}
|
||||
|
||||
// SYSTEM_BASE_TEMPLATE is the base template for the system prompt used in memory extraction
|
||||
// It includes placeholders for type-specific instructions, timestamp format, and max items
|
||||
var SYSTEM_BASE_TEMPLATE = `**Memory Extraction Specialist**
|
||||
You are an expert at analyzing conversations to extract structured memory.
|
||||
|
||||
{type_specific_instructions}
|
||||
|
||||
|
||||
**OUTPUT REQUIREMENTS:**
|
||||
1. Output MUST be valid JSON
|
||||
2. Follow the specified output format exactly
|
||||
3. Each extracted item MUST have: content, valid_at, invalid_at
|
||||
4. Timestamps in {timestamp_format} format
|
||||
5. Only extract memory types specified above
|
||||
6. Maximum {max_items} items per type
|
||||
`
|
||||
|
||||
// TYPE_INSTRUCTIONS contains specific instructions for each memory type extraction
|
||||
var TYPE_INSTRUCTIONS = map[string]string{
|
||||
"semantic": `
|
||||
**EXTRACT SEMANTIC KNOWLEDGE:**
|
||||
- Universal facts, definitions, concepts, relationships
|
||||
- Time-invariant, generally true information
|
||||
|
||||
**Timestamp Rules:**
|
||||
- valid_at: When the fact became true
|
||||
- invalid_at: When it becomes false or empty if still true
|
||||
`,
|
||||
"episodic": `
|
||||
**EXTRACT EPISODIC KNOWLEDGE:**
|
||||
- Specific experiences, events, personal stories
|
||||
- Time-bound, person-specific, contextual
|
||||
|
||||
**Timestamp Rules:**
|
||||
- valid_at: Event start/occurrence time
|
||||
- invalid_at: Event end time or empty if instantaneous
|
||||
`,
|
||||
"procedural": `
|
||||
**EXTRACT PROCEDURAL KNOWLEDGE:**
|
||||
- Processes, methods, step-by-step instructions
|
||||
- Goal-oriented, actionable, often includes conditions
|
||||
|
||||
**Timestamp Rules:**
|
||||
- valid_at: When procedure becomes valid/effective
|
||||
- invalid_at: When it expires/becomes obsolete or empty if current
|
||||
`,
|
||||
}
|
||||
|
||||
// OUTPUT_TEMPLATES defines the output format for each memory type
|
||||
var OUTPUT_TEMPLATES = map[string]string{
|
||||
"semantic": `"semantic": [{"content": "Clear factual statement", "valid_at": "timestamp or empty", "invalid_at": "timestamp or empty"}]`,
|
||||
"episodic": `"episodic": [{"content": "Narrative event description", "valid_at": "event start timestamp", "invalid_at": "event end timestamp or empty"}]`,
|
||||
"procedural": `"procedural": [{"content": "Actionable instructions", "valid_at": "procedure effective timestamp", "invalid_at": "procedure expiration timestamp or empty"}]`,
|
||||
}
|
||||
|
||||
// AssembleSystemPrompt generates a complete system prompt for memory extraction
|
||||
//
|
||||
// Parameters:
|
||||
// - memoryTypes: Array of memory type names to extract (e.g., ["semantic", "episodic"])
|
||||
//
|
||||
// Returns:
|
||||
// - string: Complete system prompt with type-specific instructions and output format
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// AssembleSystemPrompt([]string{"semantic", "episodic"}) returns a prompt with instructions
|
||||
// for both semantic and episodic memory extraction
|
||||
func (PromptAssembler) AssembleSystemPrompt(memoryTypes []string) string {
|
||||
typesToExtract := getTypesToExtract(memoryTypes)
|
||||
if len(typesToExtract) == 0 {
|
||||
typesToExtract = []string{"raw"}
|
||||
}
|
||||
|
||||
typeInstructions := generateTypeInstructions(typesToExtract)
|
||||
outputFormat := generateOutputFormat(typesToExtract)
|
||||
|
||||
fullPrompt := strings.Replace(SYSTEM_BASE_TEMPLATE, "{type_specific_instructions}", typeInstructions, 1)
|
||||
fullPrompt = strings.Replace(fullPrompt, "{timestamp_format}", "ISO 8601", 1)
|
||||
fullPrompt = strings.Replace(fullPrompt, "{max_items}", "5", 1)
|
||||
|
||||
fullPrompt += fmt.Sprintf("\n**REQUIRED OUTPUT FORMAT (JSON):\n```json\n{\n%s\n}\n```\n", outputFormat)
|
||||
|
||||
return fullPrompt
|
||||
}
|
||||
|
||||
// getTypesToExtract filters out "raw" type and returns valid memory types
|
||||
//
|
||||
// Parameters:
|
||||
// - requestedTypes: Array of requested memory type names
|
||||
//
|
||||
// Returns:
|
||||
// - []string: Filtered array of memory type names (excluding "raw")
|
||||
func getTypesToExtract(requestedTypes []string) []string {
|
||||
types := make(map[string]bool)
|
||||
for _, rt := range requestedTypes {
|
||||
lowerRT := strings.ToLower(rt)
|
||||
if lowerRT != "raw" {
|
||||
if _, ok := dao.MemoryTypeMap[lowerRT]; ok {
|
||||
types[lowerRT] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]string, 0, len(types))
|
||||
for t := range types {
|
||||
result = append(result, t)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// generateTypeInstructions concatenates type-specific instructions
|
||||
//
|
||||
// Parameters:
|
||||
// - typesToExtract: Array of memory type names
|
||||
//
|
||||
// Returns:
|
||||
// - string: Concatenated instructions for all specified types
|
||||
func generateTypeInstructions(typesToExtract []string) string {
|
||||
var instructions []string
|
||||
for _, mt := range typesToExtract {
|
||||
if instr, ok := TYPE_INSTRUCTIONS[mt]; ok {
|
||||
instructions = append(instructions, instr)
|
||||
}
|
||||
}
|
||||
return strings.Join(instructions, "\n")
|
||||
}
|
||||
|
||||
// generateOutputFormat concatenates output format templates
|
||||
//
|
||||
// Parameters:
|
||||
// - typesToExtract: Array of memory type names
|
||||
//
|
||||
// Returns:
|
||||
// - string: Concatenated output format templates
|
||||
func generateOutputFormat(typesToExtract []string) string {
|
||||
var outputParts []string
|
||||
for _, mt := range typesToExtract {
|
||||
if tmpl, ok := OUTPUT_TEMPLATES[mt]; ok {
|
||||
outputParts = append(outputParts, tmpl)
|
||||
}
|
||||
}
|
||||
return strings.Join(outputParts, ",\n")
|
||||
}
|
||||
|
||||
// MemoryService handles business logic for memory operations
|
||||
// It provides methods for creating, updating, deleting, and querying memories
|
||||
type MemoryService struct {
|
||||
memoryDAO *dao.MemoryDAO
|
||||
}
|
||||
|
||||
// NewMemoryService creates a new MemoryService instance
|
||||
//
|
||||
// Returns:
|
||||
// - *MemoryService: Initialized service instance with DAO
|
||||
func NewMemoryService() *MemoryService {
|
||||
return &MemoryService{
|
||||
memoryDAO: dao.NewMemoryDAO(),
|
||||
}
|
||||
}
|
||||
|
||||
// splitNameCounter splits a filename into base name and counter
|
||||
// Handles names in format "filename(123)" pattern
|
||||
//
|
||||
// Parameters:
|
||||
// - filename: The filename to split
|
||||
//
|
||||
// Returns:
|
||||
// - string: The base name without counter
|
||||
// - *int: The counter value, or nil if no counter exists
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// splitNameCounter("test(5)") returns ("test", 5)
|
||||
// splitNameCounter("test") returns ("test", nil)
|
||||
func splitNameCounter(filename string) (string, *int) {
|
||||
re := regexp.MustCompile(`^(.+)\((\d+)\)$`)
|
||||
matches := re.FindStringSubmatch(filename)
|
||||
if len(matches) >= 3 {
|
||||
counter := -1
|
||||
fmt.Sscanf(matches[2], "%d", &counter)
|
||||
stem := strings.TrimRight(matches[1], " ")
|
||||
return stem, &counter
|
||||
}
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
// duplicateName generates a unique name by appending a counter if the name already exists
|
||||
// It tries up to 1000 times to generate a unique name
|
||||
//
|
||||
// Parameters:
|
||||
// - queryFunc: Function to check if a name already exists (returns true if exists)
|
||||
// - name: The original name
|
||||
// - tenantID: The tenant ID for name uniqueness check
|
||||
//
|
||||
// Returns:
|
||||
// - string: A unique name (either original or with counter appended)
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// duplicateName(func(name string, tid string) bool { return false }, "test", "tenant1") returns "test"
|
||||
// duplicateName(func(name string, tid string) bool { return true }, "test", "tenant1") returns "test(1)"
|
||||
func duplicateName(queryFunc func(name string, tenantID string) bool, name string, tenantID string) string {
|
||||
const maxRetries = 1000
|
||||
|
||||
originalName := name
|
||||
currentName := name
|
||||
retries := 0
|
||||
|
||||
for retries < maxRetries {
|
||||
if !queryFunc(currentName, tenantID) {
|
||||
return currentName
|
||||
}
|
||||
|
||||
stem, counter := splitNameCounter(currentName)
|
||||
ext := path.Ext(stem)
|
||||
stemBase := strings.TrimSuffix(stem, ext)
|
||||
|
||||
newCounter := 1
|
||||
if counter != nil {
|
||||
newCounter = *counter + 1
|
||||
}
|
||||
|
||||
currentName = fmt.Sprintf("%s(%d)%s", stemBase, newCounter, ext)
|
||||
retries++
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("Failed to generate unique name within %d attempts. Original: %s", maxRetries, originalName))
|
||||
}
|
||||
|
||||
// CreateMemoryRequest defines the request structure for creating a memory
|
||||
type CreateMemoryRequest struct {
|
||||
// Name is the memory name (required, max 128 characters)
|
||||
Name string `json:"name" binding:"required"`
|
||||
// MemoryType is the array of memory type names (required)
|
||||
MemoryType []string `json:"memory_type" binding:"required"`
|
||||
// EmbdID is the embedding model ID (required)
|
||||
EmbdID string `json:"embd_id" binding:"required"`
|
||||
// LLMID is the language model ID (required)
|
||||
LLMID string `json:"llm_id" binding:"required"`
|
||||
// TenantEmbdID is the tenant-specific embedding model ID (optional)
|
||||
TenantEmbdID *string `json:"tenant_embd_id"`
|
||||
// TenantLLMID is the tenant-specific language model ID (optional)
|
||||
TenantLLMID *string `json:"tenant_llm_id"`
|
||||
}
|
||||
|
||||
// UpdateMemoryRequest defines the request structure for updating a memory
|
||||
// All fields are optional, only provided fields will be updated
|
||||
type UpdateMemoryRequest struct {
|
||||
// Name is the new memory name (optional)
|
||||
Name *string `json:"name"`
|
||||
// Permissions is the new permission level (optional)
|
||||
Permissions *string `json:"permissions"`
|
||||
// LLMID is the new language model ID (optional)
|
||||
LLMID *string `json:"llm_id"`
|
||||
// EmbdID is the new embedding model ID (optional)
|
||||
EmbdID *string `json:"embd_id"`
|
||||
// TenantLLMID is the new tenant-specific language model ID (optional)
|
||||
TenantLLMID *string `json:"tenant_llm_id"`
|
||||
// TenantEmbdID is the new tenant-specific embedding model ID (optional)
|
||||
TenantEmbdID *string `json:"tenant_embd_id"`
|
||||
// MemoryType is the new array of memory type names (optional)
|
||||
MemoryType []string `json:"memory_type"`
|
||||
// MemorySize is the new memory size in bytes (optional, max 5MB)
|
||||
MemorySize *int64 `json:"memory_size"`
|
||||
// ForgettingPolicy is the new forgetting policy (optional)
|
||||
ForgettingPolicy *string `json:"forgetting_policy"`
|
||||
// Temperature is the new temperature value (optional, range [0, 1])
|
||||
Temperature *float64 `json:"temperature"`
|
||||
// Avatar is the new avatar URL (optional)
|
||||
Avatar *string `json:"avatar"`
|
||||
// Description is the new description (optional)
|
||||
Description *string `json:"description"`
|
||||
// SystemPrompt is the new system prompt (optional)
|
||||
SystemPrompt *string `json:"system_prompt"`
|
||||
// UserPrompt is the new user prompt (optional)
|
||||
UserPrompt *string `json:"user_prompt"`
|
||||
}
|
||||
|
||||
// CreateMemoryResponse defines the response structure for memory operations
|
||||
// Uses struct embedding to extend Memory struct with API-specific fields
|
||||
type CreateMemoryResponse struct {
|
||||
model.Memory
|
||||
OwnerName *string `json:"owner_name,omitempty"`
|
||||
MemoryType []string `json:"memory_type"`
|
||||
}
|
||||
|
||||
// ListMemoryResponse defines the response structure for listing memories
|
||||
type ListMemoryResponse struct {
|
||||
// MemoryList is the array of memory objects
|
||||
MemoryList []map[string]interface{} `json:"memory_list"`
|
||||
// TotalCount is the total number of memories
|
||||
TotalCount int64 `json:"total_count"`
|
||||
}
|
||||
|
||||
// CreateMemory creates a new memory with the given parameters
|
||||
// It validates the request, generates a unique name if needed, and creates the memory record
|
||||
//
|
||||
// Parameters:
|
||||
// - tenantID: The tenant ID for which to create the memory
|
||||
// - req: The memory creation request containing name, memory_type, embd_id, llm_id, etc.
|
||||
//
|
||||
// Returns:
|
||||
// - *CreateMemoryResponse: The created memory details
|
||||
// - error: Error if validation fails or creation fails
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// req := &CreateMemoryRequest{Name: "MyMemory", MemoryType: []string{"semantic"}, EmbdID: "embd1", LLMID: "llm1"}
|
||||
// resp, err := service.CreateMemory("tenant123", req)
|
||||
func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest) (*CreateMemoryResponse, error) {
|
||||
// Ensure tenant model IDs are populated for LLM and embedding model parameters
|
||||
// This automatically fills tenant_llm_id and tenant_embd_id based on llm_id and embd_id
|
||||
tenantLLMService := NewTenantLLMService()
|
||||
params := map[string]interface{}{
|
||||
"llm_id": req.LLMID,
|
||||
"embd_id": req.EmbdID,
|
||||
}
|
||||
params = tenantLLMService.EnsureTenantModelIDForParams(tenantID, params)
|
||||
|
||||
// Update request with tenant model IDs from the processed params
|
||||
if tenantLLMID, ok := params["tenant_llm_id"].(int64); ok {
|
||||
tenantLLMIDStr := strconv.FormatInt(tenantLLMID, 10)
|
||||
req.TenantLLMID = &tenantLLMIDStr
|
||||
}
|
||||
if tenantEmbdID, ok := params["tenant_embd_id"].(int64); ok {
|
||||
tenantEmbdIDStr := strconv.FormatInt(tenantEmbdID, 10)
|
||||
req.TenantEmbdID = &tenantEmbdIDStr
|
||||
}
|
||||
|
||||
memoryName := strings.TrimSpace(req.Name)
|
||||
if len(memoryName) == 0 {
|
||||
return nil, errors.New("memory name cannot be empty or whitespace")
|
||||
}
|
||||
if len(memoryName) > MemoryNameLimit {
|
||||
return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||
}
|
||||
|
||||
if !isList(req.MemoryType) {
|
||||
return nil, errors.New("memory type must be a list")
|
||||
}
|
||||
|
||||
memoryTypeSet := make(map[string]bool)
|
||||
for _, mt := range req.MemoryType {
|
||||
lowerMT := strings.ToLower(mt)
|
||||
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
|
||||
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
|
||||
}
|
||||
memoryTypeSet[lowerMT] = true
|
||||
}
|
||||
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
|
||||
for mt := range memoryTypeSet {
|
||||
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
|
||||
}
|
||||
|
||||
memoryName = duplicateName(func(name string, tid string) bool {
|
||||
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
|
||||
return len(existing) > 0
|
||||
}, memoryName, tenantID)
|
||||
|
||||
if len(memoryName) > MemoryNameLimit {
|
||||
return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||
}
|
||||
|
||||
memoryTypeInt := dao.CalculateMemoryType(uniqueMemoryTypes)
|
||||
timestamp := time.Now().UnixMilli()
|
||||
|
||||
systemPrompt := PromptAssembler{}.AssembleSystemPrompt(uniqueMemoryTypes)
|
||||
|
||||
newID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
if len(newID) > 32 {
|
||||
newID = newID[:32]
|
||||
}
|
||||
|
||||
memory := &model.Memory{
|
||||
ID: newID,
|
||||
Name: memoryName,
|
||||
TenantID: tenantID,
|
||||
MemoryType: memoryTypeInt,
|
||||
StorageType: "table",
|
||||
EmbdID: req.EmbdID,
|
||||
LLMID: req.LLMID,
|
||||
Permissions: "me",
|
||||
MemorySize: MemorySizeLimit,
|
||||
ForgettingPolicy: string(ForgettingPolicyFIFO),
|
||||
Temperature: 0.5,
|
||||
SystemPrompt: &systemPrompt,
|
||||
}
|
||||
|
||||
// Convert tenant model IDs from string to int64 for database
|
||||
if req.TenantEmbdID != nil {
|
||||
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
|
||||
memory.TenantEmbdID = &embdID
|
||||
}
|
||||
}
|
||||
if req.TenantLLMID != nil {
|
||||
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
|
||||
memory.TenantLLMID = &llmID
|
||||
}
|
||||
}
|
||||
memory.CreateTime = ×tamp
|
||||
memory.UpdateTime = ×tamp
|
||||
|
||||
if err := s.memoryDAO.Create(memory); err != nil {
|
||||
return nil, errors.New("could not create new memory")
|
||||
}
|
||||
|
||||
createdMemory, err := s.memoryDAO.GetByID(newID)
|
||||
if err != nil {
|
||||
return nil, errors.New("could not create new memory")
|
||||
}
|
||||
|
||||
return formatRetDataFromMemory(createdMemory), nil
|
||||
}
|
||||
|
||||
// UpdateMemory updates an existing memory with the provided fields
|
||||
// Only the fields specified in the request will be updated (partial update)
|
||||
//
|
||||
// Parameters:
|
||||
// - tenantID: The tenant ID for ownership verification
|
||||
// - memoryID: The ID of the memory to update
|
||||
// - req: The update request with optional fields to update
|
||||
//
|
||||
// Returns:
|
||||
// - *CreateMemoryResponse: The updated memory details
|
||||
// - error: Error if validation fails or update fails
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// req := &UpdateMemoryRequest{Name: ptr("NewName"), MemorySize: ptr(int64(1000000))}
|
||||
// resp, err := service.UpdateMemory("tenant123", "memory456", req)
|
||||
func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *UpdateMemoryRequest) (*CreateMemoryResponse, error) {
|
||||
updateDict := make(map[string]interface{})
|
||||
|
||||
if req.Name != nil {
|
||||
memoryName := strings.TrimSpace(*req.Name)
|
||||
if len(memoryName) == 0 {
|
||||
return nil, errors.New("memory name cannot be empty or whitespace")
|
||||
}
|
||||
if len(memoryName) > MemoryNameLimit {
|
||||
return nil, fmt.Errorf("memory name '%s' exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||
}
|
||||
memoryName = duplicateName(func(name string, tid string) bool {
|
||||
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
|
||||
return len(existing) > 0
|
||||
}, memoryName, tenantID)
|
||||
if len(memoryName) > MemoryNameLimit {
|
||||
return nil, fmt.Errorf("memory name %s exceeds limit of %d", memoryName, MemoryNameLimit)
|
||||
}
|
||||
updateDict["name"] = memoryName
|
||||
}
|
||||
|
||||
if req.Permissions != nil {
|
||||
perm := TenantPermission(strings.ToLower(*req.Permissions))
|
||||
if !validPermissions[perm] {
|
||||
return nil, fmt.Errorf("unknown permission '%s'", *req.Permissions)
|
||||
}
|
||||
updateDict["permissions"] = perm
|
||||
}
|
||||
|
||||
if req.LLMID != nil {
|
||||
updateDict["llm_id"] = *req.LLMID
|
||||
}
|
||||
|
||||
if req.EmbdID != nil {
|
||||
updateDict["embd_id"] = *req.EmbdID
|
||||
}
|
||||
|
||||
if req.TenantLLMID != nil {
|
||||
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
|
||||
updateDict["tenant_llm_id"] = llmID
|
||||
}
|
||||
}
|
||||
|
||||
if req.TenantEmbdID != nil {
|
||||
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
|
||||
updateDict["tenant_embd_id"] = embdID
|
||||
}
|
||||
}
|
||||
|
||||
if req.MemoryType != nil && len(req.MemoryType) > 0 {
|
||||
memoryTypeSet := make(map[string]bool)
|
||||
for _, mt := range req.MemoryType {
|
||||
lowerMT := strings.ToLower(mt)
|
||||
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
|
||||
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
|
||||
}
|
||||
memoryTypeSet[lowerMT] = true
|
||||
}
|
||||
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
|
||||
for mt := range memoryTypeSet {
|
||||
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
|
||||
}
|
||||
updateDict["memory_type"] = uniqueMemoryTypes
|
||||
}
|
||||
|
||||
if req.MemorySize != nil {
|
||||
memorySize := *req.MemorySize
|
||||
if !(memorySize > 0 && memorySize <= MemorySizeLimit) {
|
||||
return nil, fmt.Errorf("memory size should be in range (0, %d] Bytes", MemorySizeLimit)
|
||||
}
|
||||
updateDict["memory_size"] = memorySize
|
||||
}
|
||||
|
||||
if req.ForgettingPolicy != nil {
|
||||
fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy))
|
||||
if !validForgettingPolicies[fp] {
|
||||
return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy)
|
||||
}
|
||||
updateDict["forgetting_policy"] = fp
|
||||
}
|
||||
|
||||
if req.Temperature != nil {
|
||||
temp := *req.Temperature
|
||||
if !(temp >= 0 && temp <= 1) {
|
||||
return nil, errors.New("temperature should be in range [0, 1]")
|
||||
}
|
||||
updateDict["temperature"] = temp
|
||||
}
|
||||
|
||||
for _, field := range []string{"avatar", "description", "system_prompt", "user_prompt"} {
|
||||
switch field {
|
||||
case "avatar":
|
||||
if req.Avatar != nil {
|
||||
updateDict["avatar"] = *req.Avatar
|
||||
}
|
||||
case "description":
|
||||
if req.Description != nil {
|
||||
updateDict["description"] = *req.Description
|
||||
}
|
||||
case "system_prompt":
|
||||
if req.SystemPrompt != nil {
|
||||
updateDict["system_prompt"] = *req.SystemPrompt
|
||||
}
|
||||
case "user_prompt":
|
||||
if req.UserPrompt != nil {
|
||||
updateDict["user_prompt"] = *req.UserPrompt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentMemory, err := s.memoryDAO.GetByID(memoryID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("memory '%s' not found", memoryID)
|
||||
}
|
||||
|
||||
if len(updateDict) == 0 {
|
||||
return formatRetDataFromMemory(currentMemory), nil
|
||||
}
|
||||
|
||||
memorySize := currentMemory.MemorySize
|
||||
notAllowedUpdate := []string{}
|
||||
for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} {
|
||||
if _, ok := updateDict[f]; ok && memorySize > 0 {
|
||||
notAllowedUpdate = append(notAllowedUpdate, f)
|
||||
}
|
||||
}
|
||||
if len(notAllowedUpdate) > 0 {
|
||||
return nil, fmt.Errorf("can't update %v when memory isn't empty", notAllowedUpdate)
|
||||
}
|
||||
|
||||
if _, ok := updateDict["memory_type"]; ok {
|
||||
if _, ok := updateDict["system_prompt"]; !ok {
|
||||
memoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType)
|
||||
if len(memoryTypes) > 0 && currentMemory.SystemPrompt != nil {
|
||||
defaultPrompt := PromptAssembler{}.AssembleSystemPrompt(memoryTypes)
|
||||
if *currentMemory.SystemPrompt == defaultPrompt {
|
||||
if types, ok := updateDict["memory_type"].([]string); ok {
|
||||
updateDict["system_prompt"] = PromptAssembler{}.AssembleSystemPrompt(types)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.memoryDAO.UpdateByID(memoryID, updateDict); err != nil {
|
||||
return nil, errors.New("failed to update memory")
|
||||
}
|
||||
|
||||
updatedMemory, err := s.memoryDAO.GetByID(memoryID)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get updated memory")
|
||||
}
|
||||
|
||||
return formatRetDataFromMemory(updatedMemory), nil
|
||||
}
|
||||
|
||||
// DeleteMemory deletes a memory by ID
|
||||
// It also deletes associated message indexes before removing the memory record
|
||||
//
|
||||
// Parameters:
|
||||
// - memoryID: The ID of the memory to delete
|
||||
//
|
||||
// Returns:
|
||||
// - error: Error if memory not found or deletion fails
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := service.DeleteMemory("memory456")
|
||||
func (s *MemoryService) DeleteMemory(memoryID string) error {
|
||||
_, err := s.memoryDAO.GetByID(memoryID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory '%s' not found", memoryID)
|
||||
}
|
||||
|
||||
// TODO: Delete associated message index - Implementation pending MessageService
|
||||
// messageService := NewMessageService()
|
||||
// hasIndex, _ := messageService.HasIndex(memory.TenantID, memoryID)
|
||||
// if hasIndex {
|
||||
// messageService.DeleteMessage(nil, memory.TenantID, memoryID)
|
||||
// }
|
||||
|
||||
// Delete memory record
|
||||
if err := s.memoryDAO.DeleteByID(memoryID); err != nil {
|
||||
return errors.New("failed to delete memory")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListMemories retrieves a paginated list of memories with optional filters
|
||||
// When tenantIDs is empty, it retrieves all tenants associated with the user
|
||||
//
|
||||
// Parameters:
|
||||
// - userID: The user ID for tenant filtering when tenantIDs is empty
|
||||
// - tenantIDs: Array of tenant IDs to filter by (empty means all user's tenants)
|
||||
// - memoryTypes: Array of memory type names to filter by (empty means all types)
|
||||
// - storageType: Storage type to filter by (empty means all types)
|
||||
// - keywords: Keywords to search in memory names (empty means no keyword filter)
|
||||
// - page: Page number (1-based)
|
||||
// - pageSize: Number of items per page
|
||||
//
|
||||
// Returns:
|
||||
// - *ListMemoryResponse: Contains memory list and total count
|
||||
// - error: Error if query fails
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resp, err := service.ListMemories("user123", []string{}, []string{"semantic"}, "table", "test", 1, 10)
|
||||
func (s *MemoryService) ListMemories(userID string, tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) (*ListMemoryResponse, error) {
|
||||
// If tenantIDs is empty, get all tenants associated with the user
|
||||
if len(tenantIDs) == 0 {
|
||||
userTenantService := NewUserTenantService()
|
||||
userTenants, err := userTenantService.GetUserTenantRelationByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user tenants: %w", err)
|
||||
}
|
||||
tenantIDs = make([]string, len(userTenants))
|
||||
for i, tenant := range userTenants {
|
||||
tenantIDs[i] = tenant.TenantID
|
||||
}
|
||||
}
|
||||
|
||||
memories, total, err := s.memoryDAO.GetByFilter(tenantIDs, memoryTypes, storageType, keywords, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoryList := make([]map[string]interface{}, 0, len(memories))
|
||||
for _, m := range memories {
|
||||
resp := formatRetDataFromMemoryListItem(m)
|
||||
var createDateStr *string
|
||||
if resp.CreateTime != nil {
|
||||
createDateStr = formatDateToString(*resp.CreateTime)
|
||||
}
|
||||
memoryMap := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"name": resp.Name,
|
||||
"avatar": resp.Avatar,
|
||||
"tenant_id": resp.TenantID,
|
||||
"owner_name": resp.OwnerName,
|
||||
"memory_type": resp.MemoryType,
|
||||
"storage_type": resp.StorageType,
|
||||
"permissions": resp.Permissions,
|
||||
"description": resp.Description,
|
||||
"create_time": resp.CreateTime,
|
||||
"create_date": createDateStr,
|
||||
}
|
||||
memoryList = append(memoryList, memoryMap)
|
||||
}
|
||||
|
||||
return &ListMemoryResponse{
|
||||
MemoryList: memoryList,
|
||||
TotalCount: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetMemoryConfig retrieves the full configuration of a memory by ID
|
||||
//
|
||||
// Parameters:
|
||||
// - memoryID: The ID of the memory to retrieve
|
||||
//
|
||||
// Returns:
|
||||
// - *CreateMemoryResponse: The memory configuration details
|
||||
// - error: Error if memory not found
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resp, err := service.GetMemoryConfig("memory456")
|
||||
func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse, error) {
|
||||
memory, err := s.memoryDAO.GetWithOwnerNameByID(memoryID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("memory '%s' not found", memoryID)
|
||||
}
|
||||
return formatRetDataFromMemoryListItem(memory), nil
|
||||
}
|
||||
|
||||
// TODO: GetMemoryMessages - Implementation pending - depends on CanvasService and TaskService
|
||||
// func (s *MemoryService) GetMemoryMessages(memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { ... }
|
||||
|
||||
// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService
|
||||
// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... }
|
||||
|
||||
// TODO: AddMessage - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) AddMessage(memoryIDs []string, messageDict map[string]interface{}) (bool, string, error) { ... }
|
||||
|
||||
// TODO: ForgetMessage - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) ForgetMessage(memoryID string, messageID int) (bool, error) { ... }
|
||||
|
||||
// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... }
|
||||
|
||||
// TODO: SearchMessage - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) SearchMessage(filterDict map[string]interface{}, params map[string]interface{}) ([]map[string]interface{}, error) { ... }
|
||||
|
||||
// TODO: GetMessages - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) GetMessages(memoryIDs []string, agentID string, sessionID string, limit int) ([]map[string]interface{}, error) { ... }
|
||||
|
||||
// TODO: GetMessageContent - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) GetMessageContent(memoryID string, messageID int) (map[string]interface{}, error) { ... }
|
||||
|
||||
// isList checks if a value is a list or array type
|
||||
// This is a utility function for type validation
|
||||
//
|
||||
// Parameters:
|
||||
// - v: The value to check
|
||||
//
|
||||
// Returns:
|
||||
// - bool: true if v is []interface{} or []string, false otherwise
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// isList([]string{"a", "b"}) returns true
|
||||
// isList("test") returns false
|
||||
func isList(v interface{}) bool {
|
||||
switch v.(type) {
|
||||
case []interface{}, []string:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// formatRetDataFromMemory converts a Memory model to CreateMemoryResponse format
|
||||
// This is a utility function for formatting memory data for API responses
|
||||
//
|
||||
// Parameters:
|
||||
// - memory: The Memory model to format
|
||||
//
|
||||
// Returns:
|
||||
// - *CreateMemoryResponse: Formatted memory response with human-readable types and dates
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resp := formatRetDataFromMemory(memoryModel)
|
||||
func formatRetDataFromMemory(memory *model.Memory) *CreateMemoryResponse {
|
||||
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
||||
|
||||
resp := &CreateMemoryResponse{
|
||||
Memory: *memory,
|
||||
OwnerName: nil,
|
||||
MemoryType: memoryTypes,
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func formatDateToString(t int64) *string {
|
||||
if t == 0 {
|
||||
return nil
|
||||
}
|
||||
// Database stores timestamps in milliseconds, convert to seconds
|
||||
if t > 1e10 {
|
||||
t = t / 1000
|
||||
}
|
||||
timeObj := time.Unix(t, 0)
|
||||
s := timeObj.Format("2006-01-02 15:04:05")
|
||||
return &s
|
||||
}
|
||||
|
||||
// formatRetDataFromMemoryListItem converts a MemoryListItem to CreateMemoryResponse
|
||||
// This function is used for both list and detail memory responses where owner_name is from JOIN query
|
||||
//
|
||||
// Parameters:
|
||||
// - memory: MemoryListItem pointer with owner_name from JOIN
|
||||
//
|
||||
// Returns:
|
||||
// - *CreateMemoryResponse: Formatted response with owner_name populated
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resp := formatRetDataFromMemoryListItem(memoryItem)
|
||||
func formatRetDataFromMemoryListItem(memory *model.MemoryListItem) *CreateMemoryResponse {
|
||||
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
||||
resp := &CreateMemoryResponse{
|
||||
Memory: memory.Memory,
|
||||
OwnerName: memory.OwnerName,
|
||||
MemoryType: memoryTypes,
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@ -17,12 +17,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/model"
|
||||
"ragflow/internal/engine"
|
||||
)
|
||||
|
||||
@ -92,6 +94,136 @@ type TenantListItem struct {
|
||||
DeltaSeconds float64 `json:"delta_seconds"`
|
||||
}
|
||||
|
||||
// TenantLLMService tenant LLM service
|
||||
// This service handles operations related to tenant-specific LLM configurations
|
||||
type TenantLLMService struct {
|
||||
tenantLLMDAO *dao.TenantLLMDAO
|
||||
}
|
||||
|
||||
// NewTenantLLMService creates a new TenantLLMService instance
|
||||
func NewTenantLLMService() *TenantLLMService {
|
||||
return &TenantLLMService{
|
||||
tenantLLMDAO: dao.NewTenantLLMDAO(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAPIKey retrieves the tenant LLM record by tenant ID and model name
|
||||
/**
|
||||
* This method splits the model name into name and factory parts using the "@" separator,
|
||||
* then queries the database for the matching tenant LLM configuration.
|
||||
*
|
||||
* Parameters:
|
||||
* - tenantID: the unique identifier of the tenant
|
||||
* - modelName: the model name, optionally including factory suffix (e.g., "gpt-4@OpenAI")
|
||||
*
|
||||
* Returns:
|
||||
* - *model.TenantLLM: the tenant LLM record if found, nil otherwise
|
||||
* - error: an error if the query fails, nil otherwise
|
||||
*
|
||||
* Example:
|
||||
*
|
||||
* service := NewTenantLLMService()
|
||||
*
|
||||
* // Get API key for model with factory
|
||||
* tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4@OpenAI")
|
||||
* if err != nil {
|
||||
* log.Printf("Error: %v", err)
|
||||
* }
|
||||
*
|
||||
* // Get API key for model without factory
|
||||
* tenantLLM, err := service.GetAPIKey("tenant-123", "gpt-4")
|
||||
*/
|
||||
func (s *TenantLLMService) GetAPIKey(tenantID, modelName string) (*model.TenantLLM, error) {
|
||||
modelName, factory := s.SplitModelNameAndFactory(modelName)
|
||||
|
||||
var tenantLLM *model.TenantLLM
|
||||
var err error
|
||||
|
||||
if factory == "" {
|
||||
tenantLLM, err = s.tenantLLMDAO.GetByTenantIDAndLLMName(tenantID, modelName)
|
||||
} else {
|
||||
tenantLLM, err = s.tenantLLMDAO.GetByTenantIDLLMNameAndFactory(tenantID, modelName, factory)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tenantLLM, nil
|
||||
}
|
||||
|
||||
// SplitModelNameAndFactory splits a model name into name and factory parts
|
||||
func (s *TenantLLMService) SplitModelNameAndFactory(modelName string) (string, string) {
|
||||
arr := strings.Split(modelName, "@")
|
||||
if len(arr) < 2 {
|
||||
return modelName, ""
|
||||
}
|
||||
if len(arr) > 2 {
|
||||
return strings.Join(arr[0:len(arr)-1], "@"), arr[len(arr)-1]
|
||||
}
|
||||
return arr[0], arr[1]
|
||||
}
|
||||
|
||||
// EnsureTenantModelIDForParams ensures tenant model IDs are populated for LLM-related parameters
|
||||
/**
|
||||
* This method iterates through a predefined list of LLM-related parameter keys (llm_id, embd_id,
|
||||
* asr_id, img2txt_id, rerank_id, tts_id) and automatically populates the corresponding tenant_*
|
||||
* fields (tenant_llm_id, tenant_embd_id, etc.) with the tenant LLM record IDs.
|
||||
*
|
||||
* If a parameter key exists and its corresponding tenant_* key doesn't exist, this method will:
|
||||
* 1. Query the tenant LLM record using GetAPIKey
|
||||
* 2. If found, set the tenant_* key to the record's ID
|
||||
* 3. If not found, set the tenant_* key to 0
|
||||
*
|
||||
* Parameters:
|
||||
* - tenantID: the unique identifier of the tenant
|
||||
* - params: a map of parameters to be updated (will be modified in place)
|
||||
*
|
||||
* Returns:
|
||||
* - map[string]interface{}: the updated parameters map (same as input, modified in place)
|
||||
*
|
||||
* Example:
|
||||
*
|
||||
* service := NewTenantLLMService()
|
||||
* params := map[string]interface{}{
|
||||
* "llm_id": "gpt-4@OpenAI",
|
||||
* "embd_id": "text-embedding-3-small@OpenAI",
|
||||
* }
|
||||
* result := service.EnsureTenantModelIDForParams("tenant-123", params)
|
||||
* // result will contain:
|
||||
* // {
|
||||
* // "llm_id": "gpt-4@OpenAI",
|
||||
* // "embd_id": "text-embedding-3-small@OpenAI",
|
||||
* // "tenant_llm_id": 123, // ID from tenant_llm table
|
||||
* // "tenant_embd_id": 456, // ID from tenant_llm table
|
||||
* // }
|
||||
*/
|
||||
func (s *TenantLLMService) EnsureTenantModelIDForParams(tenantID string, params map[string]interface{}) map[string]interface{} {
|
||||
paramKeys := []string{"llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"}
|
||||
|
||||
for _, key := range paramKeys {
|
||||
tenantKey := "tenant_" + key
|
||||
|
||||
if value, exists := params[key]; exists && value != nil && value != "" {
|
||||
if _, tenantExists := params[tenantKey]; !tenantExists {
|
||||
modelName, ok := value.(string)
|
||||
if !ok || modelName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
tenantLLM, err := s.GetAPIKey(tenantID, modelName)
|
||||
if err == nil && tenantLLM != nil {
|
||||
params[tenantKey] = tenantLLM.ID
|
||||
} else {
|
||||
params[tenantKey] = int64(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// GetTenantList get tenant list for a user
|
||||
func (s *TenantService) GetTenantList(userID string) ([]*TenantListItem, error) {
|
||||
tenants, err := s.userTenantDAO.GetTenantsByUserID(userID)
|
||||
|
||||
@ -924,6 +924,95 @@ func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserTenantService user tenant service
|
||||
// Provides business logic for user-tenant relationship management
|
||||
type UserTenantService struct {
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
}
|
||||
|
||||
// NewUserTenantService creates a new UserTenantService instance
|
||||
/**
|
||||
* Returns:
|
||||
* - *UserTenantService: a new UserTenantService instance
|
||||
*
|
||||
* Example:
|
||||
*
|
||||
* service := NewUserTenantService()
|
||||
* relations, err := service.GetUserTenantRelationByUserID("user123")
|
||||
*/
|
||||
func NewUserTenantService() *UserTenantService {
|
||||
return &UserTenantService{
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
}
|
||||
}
|
||||
|
||||
// UserTenantRelation represents a user-tenant relationship response
|
||||
// This structure matches the Python implementation's return format
|
||||
type UserTenantRelation struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// GetUserTenantRelationByUserID retrieves all user-tenant relationships for a given user ID
|
||||
/**
|
||||
* This method returns a list of user-tenant relationships with selected fields:
|
||||
* - id: the relationship ID
|
||||
* - user_id: the user ID
|
||||
* - tenant_id: the tenant ID
|
||||
* - role: the user's role in the tenant
|
||||
*
|
||||
* Parameters:
|
||||
* - userID: the unique identifier of the user
|
||||
*
|
||||
* Returns:
|
||||
* - []*UserTenantRelation: list of user-tenant relationships
|
||||
* - error: error if the operation fails, nil otherwise
|
||||
*
|
||||
* Example:
|
||||
*
|
||||
* service := NewUserTenantService()
|
||||
* relations, err := service.GetUserTenantRelationByUserID("user123")
|
||||
* if err != nil {
|
||||
* log.Printf("Failed to get user tenant relations: %v", err)
|
||||
* return
|
||||
* }
|
||||
* for _, rel := range relations {
|
||||
* fmt.Printf("User %s has role %s in tenant %s\n", rel.UserID, rel.Role, rel.TenantID)
|
||||
* }
|
||||
*/
|
||||
func (s *UserTenantService) GetUserTenantRelationByUserID(userID string) ([]*UserTenantRelation, error) {
|
||||
relations, err := s.userTenantDAO.GetByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*UserTenantRelation, len(relations))
|
||||
for i, rel := range relations {
|
||||
result[i] = convertToUserTenantRelation(rel)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// convertToUserTenantRelation converts model.UserTenant to UserTenantRelation
|
||||
/**
|
||||
* Parameters:
|
||||
* - userTenant: the model.UserTenant to convert
|
||||
*
|
||||
* Returns:
|
||||
* - *UserTenantRelation: the converted UserTenantRelation
|
||||
*/
|
||||
func convertToUserTenantRelation(userTenant *model.UserTenant) *UserTenantRelation {
|
||||
return &UserTenantRelation{
|
||||
ID: userTenant.ID,
|
||||
UserID: userTenant.UserID,
|
||||
TenantID: userTenant.TenantID,
|
||||
Role: userTenant.Role,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserByAPIToken gets user by access key from Authorization header
|
||||
// This is used for API token authentication
|
||||
// The authorization parameter should be in format: "Bearer <token>" or just "<token>"
|
||||
@ -963,4 +1052,5 @@ func (s *UserService) GetUserByAPIToken(authorization string) (*model.User, comm
|
||||
}
|
||||
|
||||
return user, common.CodeSuccess, nil
|
||||
|
||||
}
|
||||
|
||||
@ -1 +1,2 @@
|
||||
VITE_BASE_URL='/'
|
||||
VITE_BASE_URL='/'
|
||||
API_PROXY_SCHEME='python'
|
||||
@ -49,66 +49,18 @@ export default defineConfig(({ mode }) => {
|
||||
},
|
||||
},
|
||||
hybrid: {
|
||||
'/v1/system/token_list': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/system/new_token': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/system/token': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/system/config': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/user/login': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/v1/user/logout': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/admin/sandbox': {
|
||||
target: 'http://127.0.0.1:9381/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/admin/roles': {
|
||||
target: 'http://127.0.0.1:9381/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/admin/roles/owner/permission': {
|
||||
target: 'http://127.0.0.1:9381/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/admin/roles_with_permission': {
|
||||
target: 'http://127.0.0.1:9381/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/admin/whitelist': {
|
||||
target: 'http://127.0.0.1:9381/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/admin/variables': {
|
||||
target: 'http://127.0.0.1:9381/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'^(/api/v1/memories)|^(/v1/user/info)|^(/v1/user/tenant_info)|^(/v1/tenant/list)|^(/v1/system/config)|^(/v1/user/login)|^(/v1/user/logout)':
|
||||
{
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'^(/api/v1/admin/sandbox)|^(/api/v1/admin/roles)|^(/api/v1/admin/roles/owner/permission)|^(/api/v1/admin/roles_with_permission)|^(/api/v1/admin/whitelist)|^(/api/v1/admin/variables)':
|
||||
{
|
||||
target: 'http://127.0.0.1:9381/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api/v1/admin': {
|
||||
target: 'http://127.0.0.1:9383/',
|
||||
changeOrigin: true,
|
||||
|
||||
Reference in New Issue
Block a user