Compare commits

...

11 Commits

41 changed files with 2580 additions and 202 deletions

View File

@ -435,13 +435,17 @@ func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
milvusAddr := os.Getenv("MILVUS_ADDR")
user := os.Getenv("MILVUS_USER")
password := os.Getenv("MILVUS_PASSWORD")
var (
milvusAddr = os.Getenv("MILVUS_ADDR")
user = os.Getenv("MILVUS_USER")
password = os.Getenv("MILVUS_PASSWORD")
milvusToken = os.Getenv("MILVUS_TOKEN")
)
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
Address: milvusAddr,
Username: user,
Password: password,
APIKey: milvusToken,
})
if err != nil {
return nil, fmt.Errorf("init milvus client failed, err=%w", err)

View File

@ -21,6 +21,7 @@ import (
"encoding/json"
"errors"
"io"
"slices"
"strconv"
"github.com/cloudwego/eino/schema"
@ -102,7 +103,7 @@ func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *
return nil, err
}
if conData == nil {
return nil, errors.New("conversation data is nil")
return nil, errorx.New(errno.ErrConversationNotFound)
}
conversationData = conData
@ -110,7 +111,7 @@ func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *
}
if conversationData.CreatorID != userID {
return nil, errors.New("conversation data not match")
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg","user not match"))
}
return conversationData, nil
@ -138,26 +139,31 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a
if err != nil {
return nil, err
}
multiContent, contentType, err := a.buildMultiContent(ctx, ar)
multiAdditionalMessages, err := a.parseAdditionalMessages(ctx, ar)
if err != nil {
return nil, err
}
filterMultiAdditionalMessages, multiContent, contentType, err := a.parseQueryContent(ctx, multiAdditionalMessages)
if err != nil {
return nil, err
}
displayContent := a.buildDisplayContent(ctx, ar)
arm := &entity.AgentRunMeta{
ConversationID: ptr.From(ar.ConversationID),
AgentID: ar.BotID,
Content: multiContent,
DisplayContent: displayContent,
SpaceID: spaceID,
UserID: ar.User,
SectionID: conversationData.SectionID,
PreRetrieveTools: shortcutCMDData,
IsDraft: false,
ConnectorID: connectorID,
ContentType: contentType,
Ext: ar.ExtraParams,
CustomVariables: ar.CustomVariables,
CozeUID: conversationData.CreatorID,
ConversationID: ptr.From(ar.ConversationID),
AgentID: ar.BotID,
Content: multiContent,
DisplayContent: displayContent,
SpaceID: spaceID,
UserID: ar.User,
SectionID: conversationData.SectionID,
PreRetrieveTools: shortcutCMDData,
IsDraft: false,
ConnectorID: connectorID,
ContentType: contentType,
Ext: ar.ExtraParams,
CustomVariables: ar.CustomVariables,
CozeUID: conversationData.CreatorID,
AdditionalMessages: filterMultiAdditionalMessages,
}
return arm, nil
}
@ -200,29 +206,68 @@ func (a *OpenapiAgentRunApplication) buildDisplayContent(_ context.Context, ar *
return ""
}
func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *run.ChatV3Request) ([]*message.InputMetaData, message.ContentType, error) {
var multiContents []*message.InputMetaData
contentType := message.ContentTypeText
func (a *OpenapiAgentRunApplication) parseQueryContent(ctx context.Context, multiAdditionalMessages []*entity.AdditionalMessage) ([]*entity.AdditionalMessage, []*message.InputMetaData, message.ContentType, error) {
var multiContent []*message.InputMetaData
var contentType message.ContentType
var filterMultiAdditionalMessages []*entity.AdditionalMessage
filterMultiAdditionalMessages = multiAdditionalMessages
if len(multiAdditionalMessages) > 0 {
lastMessage := multiAdditionalMessages[len(multiAdditionalMessages)-1]
if lastMessage != nil && lastMessage.Role == schema.User {
multiContent = lastMessage.Content
contentType = lastMessage.ContentType
filterMultiAdditionalMessages = multiAdditionalMessages[:len(multiAdditionalMessages)-1]
}
}
return filterMultiAdditionalMessages, multiContent, contentType, nil
}
func (a *OpenapiAgentRunApplication) parseAdditionalMessages(ctx context.Context, ar *run.ChatV3Request) ([]*entity.AdditionalMessage, error) {
additionalMessages := make([]*entity.AdditionalMessage, 0, len(ar.AdditionalMessages))
for _, item := range ar.AdditionalMessages {
if item == nil {
continue
}
if item.Role != string(schema.User) {
return nil, contentType, errors.New("role not match")
if item.Role != string(schema.User) && item.Role != string(schema.Assistant) {
return nil, errors.New("additional message role only support user and assistant")
}
if item.Type != nil && !slices.Contains([]message.MessageType{message.MessageTypeQuestion, message.MessageTypeAnswer}, message.MessageType(*item.Type)) {
return nil, errors.New("additional message type only support question and answer now")
}
addOne := entity.AdditionalMessage{
Role: schema.RoleType(item.Role),
}
if item.Type != nil {
addOne.Type = message.MessageType(*item.Type)
} else {
addOne.Type = message.MessageTypeQuestion
}
if item.ContentType == run.ContentTypeText {
if item.Content == "" {
continue
}
multiContents = append(multiContents, &message.InputMetaData{
addOne.ContentType = message.ContentTypeText
addOne.Content = []*message.InputMetaData{{
Type: message.InputTypeText,
Text: item.Content,
})
}}
}
if item.ContentType == run.ContentTypeMixApi {
contentType = message.ContentTypeMix
if ptr.From(item.Type) == string(message.MessageTypeAnswer) {
return nil, errors.New(" answer messages only support text content")
}
addOne.ContentType = message.ContentTypeMix
var inputs []*run.AdditionalContent
err := json.Unmarshal([]byte(item.Content), &inputs)
@ -236,7 +281,8 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *
}
switch message.InputType(one.Type) {
case message.InputTypeText:
multiContents = append(multiContents, &message.InputMetaData{
addOne.Content = append(addOne.Content, &message.InputMetaData{
Type: message.InputTypeText,
Text: ptr.From(one.Text),
})
@ -250,12 +296,12 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *
ID: one.GetFileID(),
})
if err != nil {
return nil, contentType, err
return nil, err
}
fileUrl = fileInfo.File.Url
fileURI = fileInfo.File.TosURI
}
multiContents = append(multiContents, &message.InputMetaData{
addOne.Content = append(addOne.Content, &message.InputMetaData{
Type: message.InputType(one.Type),
FileData: []*message.FileData{
{
@ -269,10 +315,10 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *
}
}
}
additionalMessages = append(additionalMessages, &addOne)
}
return multiContents, contentType, nil
return additionalMessages, nil
}
func (a *OpenapiAgentRunApplication) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, streamer *schema.StreamReader[*entity.AgentRunResponse]) {

View File

@ -0,0 +1,903 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
saEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
openapiEntity "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
cmdEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
uploadEntity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
uploadService "github.com/coze-dev/coze-studio/backend/domain/upload/service"
sseImpl "github.com/coze-dev/coze-studio/backend/infra/impl/sse"
mockSingleAgent "github.com/coze-dev/coze-studio/backend/internal/mock/domain/agent/singleagent"
mockAgentRun "github.com/coze-dev/coze-studio/backend/internal/mock/domain/conversation/agentrun"
mockConversation "github.com/coze-dev/coze-studio/backend/internal/mock/domain/conversation/conversation"
mockShortcut "github.com/coze-dev/coze-studio/backend/internal/mock/domain/shortcutcmd"
mockUpload "github.com/coze-dev/coze-studio/backend/internal/mock/domain/upload"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
func setupMocks(t *testing.T) (*OpenapiAgentRunApplication, *mockShortcut.MockShortcutCmd, *mockUpload.MockUploadService, *mockAgentRun.MockRun, *mockConversation.MockConversation, *mockSingleAgent.MockSingleAgent, *gomock.Controller) {
ctrl := gomock.NewController(t)
mockShortcutSvc := mockShortcut.NewMockShortcutCmd(ctrl)
mockUploadSvc := mockUpload.NewMockUploadService(ctrl)
mockAgentRunSvc := mockAgentRun.NewMockRun(ctrl)
mockConversationSvc := mockConversation.NewMockConversation(ctrl)
mockSingleAgentSvc := mockSingleAgent.NewMockSingleAgent(ctrl)
app := &OpenapiAgentRunApplication{
ShortcutDomainSVC: mockShortcutSvc,
UploaodDomainSVC: mockUploadSvc,
}
// Setup ConversationSVC mocks
originalConversationSVC := ConversationSVC
ConversationSVC = &ConversationApplicationService{
AgentRunDomainSVC: mockAgentRunSvc,
ConversationDomainSVC: mockConversationSvc,
appContext: &ServiceComponents{
SingleAgentDomainSVC: mockSingleAgentSvc,
},
}
t.Cleanup(func() {
ConversationSVC = originalConversationSVC
ctrl.Finish()
})
return app, mockShortcutSvc, mockUploadSvc, mockAgentRunSvc, mockConversationSvc, mockSingleAgentSvc, ctrl
}
func createTestContext() context.Context {
ctx := context.Background()
ctx = ctxcache.Init(ctx)
apiKey := &openapiEntity.ApiKey{
UserID: 12345,
ConnectorID: consts.CozeConnectorID,
}
ctxcache.Store(ctx, consts.OpenapiAuthKeyInCtx, apiKey)
return ctx
}
func createTestRequest() *run.ChatV3Request {
return &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: "Hello, world!",
ContentType: run.ContentTypeText,
},
},
}
}
func createTestRequestWithMultipleMessages() *run.ChatV3Request {
return &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: "Hello, I need help with something.",
ContentType: run.ContentTypeText,
},
{
Role: "assistant",
Content: "Sure, I'd be happy to help! What do you need assistance with?",
ContentType: run.ContentTypeText,
},
{
Role: "user",
Content: `{"type": "image", "url": "https://example.com/image.jpg"}`,
ContentType: run.ContentTypeImage,
},
{
Role: "user",
Content: `{"type": "file", "name": "document.pdf", "url": "https://example.com/doc.pdf"}`,
ContentType: run.ContentTypeFile,
},
},
}
}
func createTestRequestWithAssistantOnly() *run.ChatV3Request {
return &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "assistant",
Content: "I'm here to help you with any questions you might have.",
ContentType: run.ContentTypeText, // assistant role only supports text content type
},
},
}
}
func TestOpenapiAgentRun_Success(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
// Mock agent check
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
}
func TestOpenapiAgentRun_CheckAgentError(t *testing.T) {
app, _, _, _, _, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
// Mock agent check failure
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(nil, errors.New("agent not found"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "agent not found")
}
func TestOpenapiAgentRun_AgentNotExists(t *testing.T) {
app, _, _, _, _, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
// Mock agent check returns nil (agent not exists)
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(nil, nil)
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
}
func TestOpenapiAgentRun_CheckConversationError(t *testing.T) {
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check failure
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(nil, errors.New("conversation not found"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "conversation not found")
}
func TestOpenapiAgentRun_ConversationPermissionError(t *testing.T) {
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation with different creator
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 99999, // Different from user ID (12345)
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
}
func TestOpenapiAgentRun_CreateNewConversation(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
req.ConversationID = ptr.Of(int64(0)) // No conversation ID
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock create new conversation
mockConv := &convEntity.Conversation{
ID: 22222,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().Create(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, meta *convEntity.CreateMeta) (*convEntity.Conversation, error) {
assert.Equal(t, int64(67890), meta.AgentID)
assert.Equal(t, int64(12345), meta.UserID)
assert.Equal(t, common.Scene_SceneOpenApi, meta.Scene)
return mockConv, nil
})
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Equal(t, int64(22222), *req.ConversationID) // Should be updated
}
func TestOpenapiAgentRun_AgentRunError(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("agent run failed"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "agent run failed")
}
func TestOpenapiAgentRun_WithShortcutCommand(t *testing.T) {
app, mockShortcut, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequest()
req.ShortcutCommand = &run.ShortcutCommandDetail{
CommandID: 123,
Parameters: map[string]string{"param1": "value1"},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock shortcut command
mockCmd := &cmdEntity.ShortcutCmd{
ID: 123,
PluginID: 456,
PluginToolName: "test-tool",
PluginToolID: 789,
ToolType: 1,
}
mockShortcut.EXPECT().GetByCmdID(ctx, int64(123), int32(0)).Return(mockCmd, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
}
func TestOpenapiAgentRun_WithMultipleMessages(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequestWithMultipleMessages()
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
// Verify that the request contains multiple messages with different roles and content types
assert.Len(t, req.AdditionalMessages, 4)
assert.Equal(t, "user", req.AdditionalMessages[0].Role)
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[0].ContentType)
assert.Equal(t, "assistant", req.AdditionalMessages[1].Role)
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[1].ContentType)
assert.Equal(t, "user", req.AdditionalMessages[2].Role)
assert.Equal(t, run.ContentTypeImage, req.AdditionalMessages[2].ContentType)
assert.Equal(t, "user", req.AdditionalMessages[3].Role)
assert.Equal(t, run.ContentTypeFile, req.AdditionalMessages[3].ContentType)
}
func TestOpenapiAgentRun_WithAssistantMessage(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
req := createTestRequestWithAssistantOnly()
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
// Verify that the assistant message only supports text content type
assert.Len(t, req.AdditionalMessages, 1)
assert.Equal(t, "assistant", req.AdditionalMessages[0].Role)
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[0].ContentType)
}
func TestOpenapiAgentRun_WithMixedContentTypes(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with various content types for user role
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: "Here's a text message",
ContentType: run.ContentTypeText,
},
{
Role: "user",
Content: `{"type": "audio", "url": "https://example.com/audio.mp3"}`,
ContentType: run.ContentTypeAudio,
},
{
Role: "user",
Content: `{"type": "video", "url": "https://example.com/video.mp4"}`,
ContentType: run.ContentTypeVideo,
},
{
Role: "assistant",
Content: "I can only respond with text content.",
ContentType: run.ContentTypeText, // assistant must use text
},
{
Role: "user",
Content: `{"type": "link", "url": "https://example.com"}`,
ContentType: run.ContentTypeLink,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
// Verify various content types are preserved
assert.Len(t, req.AdditionalMessages, 5)
// Check user messages with different content types
assert.Equal(t, "user", req.AdditionalMessages[0].Role)
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[0].ContentType)
assert.Equal(t, "user", req.AdditionalMessages[1].Role)
assert.Equal(t, run.ContentTypeAudio, req.AdditionalMessages[1].ContentType)
assert.Equal(t, "user", req.AdditionalMessages[2].Role)
assert.Equal(t, run.ContentTypeVideo, req.AdditionalMessages[2].ContentType)
// Check assistant message (must be text)
assert.Equal(t, "assistant", req.AdditionalMessages[3].Role)
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[3].ContentType)
assert.Equal(t, "user", req.AdditionalMessages[4].Role)
assert.Equal(t, run.ContentTypeLink, req.AdditionalMessages[4].ContentType)
}
func TestOpenapiAgentRun_ParseAdditionalMessages_InvalidRole(t *testing.T) {
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with invalid role
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "system", // Invalid role
Content: "System message",
ContentType: run.ContentTypeText,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success to reach parseAdditionalMessages
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "additional message role only support user and assistant")
}
func TestOpenapiAgentRun_ParseAdditionalMessages_InvalidType(t *testing.T) {
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with invalid message type
invalidType := "invalid_type"
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: "Test message",
ContentType: run.ContentTypeText,
Type: &invalidType, // Invalid type
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success to reach parseAdditionalMessages
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "additional message type only support question and answer now")
}
func TestOpenapiAgentRun_ParseAdditionalMessages_AnswerWithNonTextContent(t *testing.T) {
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with answer type but non-text content
answerType := "answer"
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "assistant",
Content: `[{"type": "image", "file_url": "https://example.com/image.jpg"}]`,
ContentType: run.ContentTypeMixApi, // object_string
Type: &answerType,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success to reach parseAdditionalMessages
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "answer messages only support text content")
}
func TestOpenapiAgentRun_ParseAdditionalMessages_MixApiWithFileURL(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with object_string content type and file URL
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: `[{"type": "text", "text": "Here's an image:"}, {"type": "image", "file_url": "https://example.com/image.jpg"}]`,
ContentType: run.ContentTypeMixApi,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
}
func TestOpenapiAgentRun_ParseAdditionalMessages_MixApiWithFileID(t *testing.T) {
app, _, mockUpload, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with object_string content type and file ID
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: `[{"type": "file", "file_id": "12345"}]`,
ContentType: run.ContentTypeMixApi,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock upload service to return file info
mockUpload.EXPECT().GetFile(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, req *uploadService.GetFileRequest) (*uploadService.GetFileResponse, error) {
assert.Equal(t, int64(12345), req.ID)
return &uploadService.GetFileResponse{
File: &uploadEntity.File{
Url: "https://example.com/file.pdf",
TosURI: "tos://bucket/file.pdf",
},
}, nil
})
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
}
func TestOpenapiAgentRun_ParseAdditionalMessages_FileIDError(t *testing.T) {
app, _, mockUpload, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with object_string content type and file ID that will fail
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: `[{"type": "file", "file_id": "99999"}]`,
ContentType: run.ContentTypeMixApi,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock upload service to return error
mockUpload.EXPECT().GetFile(ctx, gomock.Any()).Return(nil, errors.New("file not found"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "file not found")
}
func TestOpenapiAgentRun_ParseAdditionalMessages_EmptyContent(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with empty text content (should be skipped)
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: "", // Empty content
ContentType: run.ContentTypeText,
},
{
Role: "user",
Content: "Valid content",
ContentType: run.ContentTypeText,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
// Verify that only the non-empty message is included
assert.Len(t, req.AdditionalMessages, 2) // Original request still has 2 messages
}
func TestOpenapiAgentRun_ParseAdditionalMessages_NilMessage(t *testing.T) {
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
ctx := createTestContext()
// Create request with empty content message (should be skipped)
req := &run.ChatV3Request{
BotID: 67890,
ConversationID: ptr.Of(int64(11111)),
User: "test-user",
AdditionalMessages: []*run.EnterMessage{
{
Role: "user",
Content: "", // Empty content message
ContentType: run.ContentTypeText,
},
{
Role: "user",
Content: "Valid content",
ContentType: run.ContentTypeText,
},
},
}
// Mock agent check success
mockAgent := &saEntity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
AgentID: 67890,
SpaceID: 54321,
},
}
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
// Mock conversation check success
mockConv := &convEntity.Conversation{
ID: 11111,
CreatorID: 12345,
SectionID: 98765,
}
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
// Mock agent run failure to avoid pullStream complexity
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock stream error")
}

View File

@ -1322,10 +1322,10 @@ func mergeBatchModeNodes(parent, inner *workflow.NodeResult) *workflow.NodeResul
type StreamRunEventType string
const (
DoneEvent StreamRunEventType = "done"
MessageEvent StreamRunEventType = "message"
ErrEvent StreamRunEventType = "error"
InterruptEvent StreamRunEventType = "interrupt"
DoneEvent StreamRunEventType = "Done"
MessageEvent StreamRunEventType = "Message"
ErrEvent StreamRunEventType = "Error"
InterruptEvent StreamRunEventType = "Interrupt"
)
func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *workflow.OpenAPIStreamRunFlowResponse, err error) {

View File

@ -30,6 +30,7 @@ type Message interface {
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error)
PreCreate(ctx context.Context, msg *message.Message) (*message.Message, error)
Create(ctx context.Context, msg *message.Message) (*message.Message, error)
BatchCreate(ctx context.Context, msg []*message.Message) ([]*message.Message, error)
List(ctx context.Context, meta *entity.ListMeta) (*entity.ListResult, error)
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
Edit(ctx context.Context, msg *message.Message) (*message.Message, error)

View File

@ -59,6 +59,21 @@ func (m *MockMessage) EXPECT() *MockMessageMockRecorder {
return m.recorder
}
// BatchCreate mocks base method.
func (m *MockMessage) BatchCreate(ctx context.Context, msg []*message.Message) ([]*message.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchCreate", ctx, msg)
ret0, _ := ret[0].([]*message.Message)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BatchCreate indicates an expected call of BatchCreate.
func (mr *MockMessageMockRecorder) BatchCreate(ctx, msg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreate", reflect.TypeOf((*MockMessage)(nil).BatchCreate), ctx, msg)
}
// Create mocks base method.
func (m *MockMessage) Create(ctx context.Context, msg *message.Message) (*message.Message, error) {
m.ctrl.T.Helper()

View File

@ -170,6 +170,9 @@ func (c *impl) GetMessageByID(ctx context.Context, id int64) (*entity.Message, e
func (c *impl) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
return c.DomainSVC.ListWithoutPair(ctx, req)
}
func (c *impl) BatchCreate(ctx context.Context, msgs []*entity.Message) ([]*entity.Message, error) {
return c.DomainSVC.BatchCreate(ctx, msgs)
}
func convertToConvAndSchemaMessage(ctx context.Context, msgs []*entity.Message) ([]*crossmessage.WfMessage, []*schema.Message, error) {
messages := make([]*schema.Message, 0)

View File

@ -25,6 +25,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
)
//go:generate mockgen -destination ../../../../internal/mock/domain/agent/singleagent/single_agent_mock.go --package singleagent -source single_agent.go
type SingleAgent interface {
// draft agent
CreateSingleAgentDraft(ctx context.Context, creatorID int64, draft *entity.SingleAgent) (agentID int64, err error)

View File

@ -106,24 +106,34 @@ type MetaInfo struct {
}
type AgentRunMeta struct {
ConversationID int64 `json:"conversation_id"`
ConnectorID int64 `json:"connector_id"`
SpaceID int64 `json:"space_id"`
Scene common.Scene `json:"scene"`
SectionID int64 `json:"section_id"`
Name string `json:"name"`
UserID string `json:"user_id"`
CozeUID int64 `json:"coze_uid"`
AgentID int64 `json:"agent_id"`
ContentType message.ContentType `json:"content_type"`
Content []*message.InputMetaData `json:"content"`
PreRetrieveTools []*Tool `json:"tools"`
IsDraft bool `json:"is_draft"`
CustomerConfig *CustomerConfig `json:"customer_config"`
DisplayContent string `json:"display_content"`
CustomVariables map[string]string `json:"custom_variables"`
Version string `json:"version"`
Ext map[string]string `json:"ext"`
ConversationID int64 `json:"conversation_id"`
ConnectorID int64 `json:"connector_id"`
SpaceID int64 `json:"space_id"`
Scene common.Scene `json:"scene"`
SectionID int64 `json:"section_id"`
Name string `json:"name"`
UserID string `json:"user_id"`
CozeUID int64 `json:"coze_uid"`
AgentID int64 `json:"agent_id"`
ContentType message.ContentType `json:"content_type"`
Content []*message.InputMetaData `json:"content"`
PreRetrieveTools []*Tool `json:"tools"`
IsDraft bool `json:"is_draft"`
CustomerConfig *CustomerConfig `json:"customer_config"`
DisplayContent string `json:"display_content"`
CustomVariables map[string]string `json:"custom_variables"`
Version string `json:"version"`
Ext map[string]string `json:"ext"`
AdditionalMessages []*AdditionalMessage `json:"additional_messages"`
}
type AdditionalMessage struct {
Role schema.RoleType `json:"role"`
Type message.MessageType `json:"type"`
Content []*message.InputMetaData `json:"content"`
ContentType message.ContentType `json:"content_type"`
Name *string `json:"name"`
Meta map[string]string `json:"meta"`
}
type UpdateMeta struct {

View File

@ -41,7 +41,7 @@ import (
func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX) (err error) {
mh := &MesssageEventHanlder{
mh := &MessageEventHandler{
sw: art.SW,
messageEvent: art.MessageEvent,
}
@ -110,7 +110,7 @@ func concatWfInput(rtDependence *AgentRuntime) string {
return strings.Trim(input, ",")
}
func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MesssageEventHanlder) {
func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MessageEventHandler) {
fullAnswerContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt

View File

@ -221,6 +221,51 @@ func preCreateAnswer(ctx context.Context, rtDependence *AgentRuntime) (*msgEntit
return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
}
func buildAdditionalMessage2Create(ctx context.Context, runRecord *entity.RunRecordMeta, additionalMessage *entity.AdditionalMessage, userID string) *message.Message {
msg := &msgEntity.Message{
ConversationID: runRecord.ConversationID,
RunID: runRecord.ID,
AgentID: runRecord.AgentID,
SectionID: runRecord.SectionID,
UserID: userID,
MessageType: additionalMessage.Type,
}
switch additionalMessage.Type {
case message.MessageTypeQuestion:
msg.Role = schema.User
msg.ContentType = additionalMessage.ContentType
for _, content := range additionalMessage.Content {
if content.Type == message.InputTypeText {
msg.Content = content.Text
break
}
}
msg.MultiContent = additionalMessage.Content
case message.MessageTypeAnswer:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
for _, content := range additionalMessage.Content {
if content.Type == message.InputTypeText {
msg.Content = content.Text
break
}
}
modelContent := &schema.Message{
Role: schema.Assistant,
Content: msg.Content,
}
jsonContent, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(jsonContent)
}
}
return msg
}
func buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *AgentRuntime) *message.Message {
arm := rtDependence.GetRunMeta()
msg := &msgEntity.Message{

View File

@ -98,12 +98,12 @@ func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResp
sw.Send(resp, nil)
}
type MesssageEventHanlder struct {
type MessageEventHandler struct {
messageEvent *Event
sw *schema.StreamWriter[*entity.AgentRunResponse]
}
func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) {
func (mh *MessageEventHandler) handlerErr(_ context.Context, err error) {
var errMsg string
var statusErr errorx.StatusError
@ -123,7 +123,7 @@ func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) {
})
}
func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgEntity.Message) error {
func (mh *MessageEventHandler) handlerAckMessage(_ context.Context, input *msgEntity.Message) error {
sendMsg := &entity.ChunkMessageItem{
ID: input.ID,
ConversationID: input.ConversationID,
@ -142,7 +142,7 @@ func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgE
return nil
}
func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
func (mh *MessageEventHandler) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
@ -156,7 +156,7 @@ func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *
return nil
}
func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
func (mh *MessageEventHandler) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
@ -184,7 +184,7 @@ func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *e
return nil
}
func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
func (mh *MessageEventHandler) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
@ -199,7 +199,7 @@ func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entit
return nil
}
func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
func (mh *MessageEventHandler) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
@ -212,7 +212,7 @@ func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *ent
return nil
}
func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error {
func (mh *MessageEventHandler) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error {
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
return nil
@ -265,7 +265,7 @@ func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.C
return nil
}
func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error {
func (mh *MessageEventHandler) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
@ -278,7 +278,7 @@ func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rt
return nil
}
func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
func (mh *MessageEventHandler) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
@ -291,7 +291,7 @@ func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chu
return nil
}
func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error {
func (mh *MessageEventHandler) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error {
if msg.Ext == nil {
msg.Ext = map[string]string{}
@ -314,7 +314,7 @@ func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.
return nil
}
func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error {
func (mh *MessageEventHandler) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error {
interruptData, cType, err := parseInterruptData(ctx, chunk.Interrupt)
if err != nil {
return err
@ -366,7 +366,7 @@ func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *ent
return nil
}
func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) {
func (mh *MessageEventHandler) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) {
interruptData, cType, err := handlerWfInterruptEvent(ctx, stateMsg.InterruptEvent)
if err != nil {
return
@ -412,7 +412,7 @@ func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, state
}
}
func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
func (mh *MessageEventHandler) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
msgMeta := buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence)
cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta)
@ -426,3 +426,21 @@ func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *
}
return cm, nil
}
func (mh *MessageEventHandler) ParseAdditionalMessages(ctx context.Context, rtDependence *AgentRuntime, runRecord *entity.RunRecordMeta) error {
if len(rtDependence.GetRunMeta().AdditionalMessages) == 0 {
return nil
}
additionalMessages := make([]*message.Message, 0, len(rtDependence.GetRunMeta().AdditionalMessages))
for _, msg := range rtDependence.GetRunMeta().AdditionalMessages {
cm := buildAdditionalMessage2Create(ctx, runRecord, msg, rtDependence.GetRunMeta().UserID)
additionalMessages = append(additionalMessages, cm)
}
_, err := crossmessage.DefaultSVC().BatchCreate(ctx, additionalMessages)
return err
}

View File

@ -107,6 +107,11 @@ func (rd *AgentRuntime) GetHistory() []*msgEntity.Message {
func (art *AgentRuntime) Run(ctx context.Context) (err error) {
mh := &MessageEventHandler{
messageEvent: art.MessageEvent,
sw: art.SW,
}
agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft, art.GetRunMeta().ConnectorID)
if err != nil {
return
@ -114,6 +119,18 @@ func (art *AgentRuntime) Run(ctx context.Context) (err error) {
art.SetAgentInfo(agentInfo)
if len(art.GetRunMeta().AdditionalMessages) > 0 {
var additionalRunRecord *entity.RunRecordMeta
additionalRunRecord, err = art.RunRecordRepo.Create(ctx, art.GetRunMeta())
if err != nil {
return
}
err = mh.ParseAdditionalMessages(ctx, art, additionalRunRecord)
if err != nil {
return
}
}
history, err := art.getHistory(ctx)
if err != nil {
return
@ -140,10 +157,7 @@ func (art *AgentRuntime) Run(ctx context.Context) (err error) {
}
art.RunProcess.StepToComplete(ctx, srRecord, art.SW, art.GetUsage())
}()
mh := &MesssageEventHanlder{
messageEvent: art.MessageEvent,
sw: art.SW,
}
input, err := mh.HandlerInput(ctx, art)
if err != nil {
return

View File

@ -80,7 +80,7 @@ func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.I
func (art *AgentRuntime) push(ctx context.Context, mainChan chan *entity.AgentRespEvent) {
mh := &MesssageEventHanlder{
mh := &MessageEventHandler{
sw: art.SW,
messageEvent: art.MessageEvent,
}

View File

@ -24,6 +24,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
)
//go:generate mockgen -destination ../../../../internal/mock/domain/conversation/agentrun/agent_run_mock.go --package agentrun -source agent_run.go
type Run interface {
AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error)
Delete(ctx context.Context, runID []int64) error

View File

@ -22,6 +22,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
)
//go:generate mockgen -destination ../../../../internal/mock/domain/conversation/conversation/conversation_mock.go --package conversation -source conversation.go
type Conversation interface {
Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error)
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)

View File

@ -72,6 +72,25 @@ func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity
return dao.messagePO2DO(poData), nil
}
func (dao *MessageDAO) BatchCreate(ctx context.Context, msg []*entity.Message) ([]*entity.Message, error) {
poList := make([]*model.Message, 0, len(msg))
for _, m := range msg {
po, err := dao.messageDO2PO(ctx, m)
if err != nil {
return nil, err
}
poList = append(poList, po)
}
do := dao.query.Message.WithContext(ctx).Debug()
cErr := do.CreateInBatches(poList, len(poList))
if cErr != nil {
return nil, cErr
}
return dao.batchMessagePO2DO(poList), nil
}
func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) {
m := dao.query.Message
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))

View File

@ -34,6 +34,7 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
type MessageRepo interface {
PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error)
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
BatchCreate(ctx context.Context, msg []*entity.Message) ([]*entity.Message, error)
List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error)
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)

View File

@ -27,6 +27,7 @@ type Message interface {
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
BatchCreate(ctx context.Context, req []*entity.Message) ([]*entity.Message, error)
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
GetByID(ctx context.Context, id int64) (*entity.Message, error)
Edit(ctx context.Context, req *entity.Message) (*entity.Message, error)

View File

@ -124,6 +124,10 @@ func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, e
return m.MessageRepo.GetByID(ctx, id)
}
func (m *messageImpl) BatchCreate(ctx context.Context, req []*entity.Message) ([]*entity.Message, error) {
return m.MessageRepo.BatchCreate(ctx, req)
}
func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error {
_, err := m.MessageRepo.Edit(ctx, req.ID, &message.Message{

View File

@ -494,3 +494,55 @@ func TestListWithoutPair(t *testing.T) {
assert.Equal(t, "Answer message", resp.Messages[0].Content)
})
}
func TestBatchCreate(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{})
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
t.Run("success_single_message", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
// 准备测试数据
inputMsgs := []*entity.Message{
{
ID: 1,
ConversationID: 100,
RunID: 200,
AgentID: 300,
UserID: "user123",
Content: "Hello World",
Role: schema.User,
ContentType: message.ContentTypeText,
MessageType: message.MessageTypeQuestion,
Status: message.MessageStatusAvailable,
},
{
ID: 2,
ConversationID: 100,
RunID: 200,
AgentID: 300,
UserID: "user123",
Content: "Hello World",
Role: schema.Assistant,
ContentType: message.ContentTypeText,
MessageType: message.MessageTypeQuestion,
Status: message.MessageStatusAvailable,
},
}
result, err := NewService(components).BatchCreate(ctx, inputMsgs)
assert.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, inputMsgs[1].ID, result[1].ID)
})
}

View File

@ -1209,32 +1209,10 @@ func (d databaseService) executeSelectSQL(ctx context.Context, req *ExecuteSQLRe
selectReq.Fields = fields
}
var complexCond *rdb.ComplexCondition
var err error
if req.Condition != nil {
complexCond, err = convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
if err != nil {
return nil, fmt.Errorf("convert condition failed: %v", err)
}
complexCond, err := generateComplexCond(ctx, req, tableInfo.RwMode, fieldNameToPhysical)
if err != nil {
return nil, err
}
// add rw mode
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
cond := &rdb.Condition{
Field: database.DefaultUidColName,
Operator: entity3.OperatorEqual,
Value: req.UserID,
}
if complexCond == nil {
complexCond = &rdb.ComplexCondition{
Conditions: []*rdb.Condition{cond},
}
} else {
complexCond.Conditions = append(complexCond.Conditions, cond)
}
}
if complexCond != nil {
selectReq.Where = complexCond
}
@ -1376,27 +1354,10 @@ func (d databaseService) executeUpdateSQL(ctx context.Context, req *ExecuteSQLRe
}
}
condParams := req.SQLParams[index:]
complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, condParams)
req.SQLParams = req.SQLParams[index:]
complexCond, err := generateComplexCond(ctx, req, tableInfo.RwMode, fieldNameToPhysical)
if err != nil {
return -1, fmt.Errorf("convert condition failed: %v", err)
}
// add rw mode
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
cond := &rdb.Condition{
Field: database.DefaultUidColName,
Operator: entity3.OperatorEqual,
Value: req.UserID,
}
if complexCond == nil {
complexCond = &rdb.ComplexCondition{
Conditions: []*rdb.Condition{cond},
}
} else {
complexCond.Conditions = append(complexCond.Conditions, cond)
}
return -1, err
}
updateResp, err := d.rdb.UpdateData(ctx, &rdb.UpdateDataRequest{
@ -1417,26 +1378,9 @@ func (d databaseService) executeDeleteSQL(ctx context.Context, req *ExecuteSQLRe
return -1, fmt.Errorf("missing delete condition")
}
complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
complexCond, err := generateComplexCond(ctx, req, tableInfo.RwMode, fieldNameToPhysical)
if err != nil {
return -1, fmt.Errorf("convert condition failed: %v", err)
}
// add rw mode
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
cond := &rdb.Condition{
Field: database.DefaultUidColName,
Operator: entity3.OperatorEqual,
Value: req.UserID,
}
if complexCond == nil {
complexCond = &rdb.ComplexCondition{
Conditions: []*rdb.Condition{cond},
}
} else {
complexCond.Conditions = append(complexCond.Conditions, cond)
}
return -1, err
}
deleteResp, err := d.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
@ -1538,13 +1482,6 @@ func convertCondition(ctx context.Context, cond *database.ComplexCondition, fiel
}
result.Conditions = conditions
}
// if cond.NestedConditions != nil {
// nested, err := convertCondition(cond.NestedConditions, fieldMap, params)
// if err != nil {
// return nil, err
// }
// result.NestedConditions = []*rdb.ComplexCondition{nested}
// }
return result, nil
}
@ -2204,3 +2141,50 @@ func (d databaseService) GetAllDatabaseByAppID(ctx context.Context, req *GetAllD
Databases: onlineDBs,
}, nil
}
func generateComplexCond(ctx context.Context, req *ExecuteSQLRequest, mode table.BotTableRWMode, fieldNameToPhysical map[string]string) (*rdb.ComplexCondition, error) {
var (
err error
complexCond *rdb.ComplexCondition
extraCondition *rdb.ComplexCondition
)
if req.Condition != nil {
complexCond, err = convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
if err != nil {
return nil, fmt.Errorf("convert condition failed: %v", err)
}
}
if mode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
cond := &rdb.Condition{
Field: database.DefaultUidColName,
Operator: entity3.OperatorEqual,
Value: req.UserID,
}
extraCondition = &rdb.ComplexCondition{
Conditions: []*rdb.Condition{cond},
}
}
if complexCond != nil && extraCondition != nil {
return &rdb.ComplexCondition{
NestedConditions: []*rdb.ComplexCondition{
complexCond,
extraCondition,
},
Operator: entity3.AND,
}, nil
}
if complexCond != nil {
return complexCond, nil
}
if extraCondition != nil {
return extraCondition, nil
}
return nil, nil
}

View File

@ -634,6 +634,52 @@ func TestExecuteSQLWithOperations(t *testing.T) {
assert.NotNil(t, selectINResp)
assert.True(t, len(selectINResp.Records) == 2)
executeSelectWithOrOperationReq := &ExecuteSQLRequest{
DatabaseID: resp.Database.ID,
TableType: table.TableType_OnlineTable,
OperateType: database.OperateType_Select,
SelectFieldList: selectFields,
Limit: &limit,
UserID: "1001",
SpaceID: 1,
OrderByList: []database.OrderBy{
{
Field: "id_custom",
Direction: table.SortDirection_Desc,
},
},
SQLParams: []*database.SQLParamVal{
{
Value: ptr.Of("Alice"),
},
{
Value: ptr.Of("100"),
},
},
Condition: &database.ComplexCondition{
Conditions: []*database.Condition{
{
Left: "name",
Operation: database.Operation_EQUAL,
Right: "?",
},
{
Left: "score",
Operation: database.Operation_EQUAL,
Right: "?",
},
},
Logic: database.Logic_Or,
},
}
selectWithOrOperationResp, err := dbService.ExecuteSQL(context.Background(), executeSelectWithOrOperationReq)
assert.NoError(t, err)
assert.NotNil(t, selectWithOrOperationResp)
assert.Equal(t, string(selectWithOrOperationResp.Records[0]["name"].([]uint8)), "Alice")
assert.True(t, len(selectWithOrOperationResp.Records) == 1)
updateRows := []*database.UpsertRow{
{
Records: []*database.Record{

View File

@ -22,6 +22,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
)
//go:generate mockgen -destination ../../../internal/mock/domain/shortcutcmd/shortcut_cmd_mock.go --package shortcutcmd -source shortcut_cmd.go
type ShortcutCmd interface {
ListCMD(ctx context.Context, lm *entity.ListMeta) ([]*entity.ShortcutCmd, error)
CreateCMD(ctx context.Context, shortcut *entity.ShortcutCmd) (*entity.ShortcutCmd, error)

View File

@ -22,6 +22,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/upload/entity"
)
//go:generate mockgen -destination ../../../internal/mock/domain/upload/upload_service_mock.go --package upload -source interface.go
type UploadService interface {
UploadFile(ctx context.Context, req *UploadFileRequest) (resp *UploadFileResponse, err error)
UploadFiles(ctx context.Context, req *UploadFilesRequest) (resp *UploadFilesResponse, err error)

View File

@ -368,23 +368,7 @@ func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.Composite
continue
}
if _, ok := carryOvers[fromNodeKey]; !ok {
carryOvers[fromNodeKey] = make([]*compose.FieldMapping, 0)
}
for _, fm := range fieldMappings {
duplicate := false
for _, existing := range carryOvers[fromNodeKey] {
if fm.Equals(existing) {
duplicate = true
break
}
}
if !duplicate {
carryOvers[fromNodeKey] = append(carryOvers[fromNodeKey], fieldMappings...)
}
}
addFieldMappingsWithDeduplication(carryOvers, fromNodeKey, fieldMappings)
}
}
@ -882,3 +866,29 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
variableInfos: variableInfos,
}, nil
}
// addFieldMappingsWithDeduplication adds field mappings to carryOvers while avoiding duplicates
func addFieldMappingsWithDeduplication(
carryOvers map[vo.NodeKey][]*compose.FieldMapping,
fromNodeKey vo.NodeKey,
fieldMappings []*compose.FieldMapping,
) {
if _, ok := carryOvers[fromNodeKey]; !ok {
carryOvers[fromNodeKey] = make([]*compose.FieldMapping, 0)
}
for i := range fieldMappings {
fm := fieldMappings[i]
duplicate := false
for _, existing := range carryOvers[fromNodeKey] {
if fm.Equals(existing) {
duplicate = true
break
}
}
if !duplicate {
carryOvers[fromNodeKey] = append(carryOvers[fromNodeKey], fm)
}
}
}

View File

@ -0,0 +1,174 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"testing"
"github.com/cloudwego/eino/compose"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
func TestAddFieldMappingsWithDeduplication(t *testing.T) {
tests := []struct {
name string
initialCarryOvers map[vo.NodeKey][]*compose.FieldMapping
fromNodeKey vo.NodeKey
fieldMappings []*compose.FieldMapping
expectedCount int
description string
}{
{
name: "empty_carry_overs",
initialCarryOvers: make(map[vo.NodeKey][]*compose.FieldMapping),
fromNodeKey: "node1",
fieldMappings: []*compose.FieldMapping{
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
compose.MapFieldPaths(compose.FieldPath{"input2"}, compose.FieldPath{"output2"}),
},
expectedCount: 2,
description: "should add all mappings when carryOvers is empty",
},
{
name: "no_duplicates",
initialCarryOvers: map[vo.NodeKey][]*compose.FieldMapping{
"node1": {
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
},
},
fromNodeKey: "node1",
fieldMappings: []*compose.FieldMapping{
compose.MapFieldPaths(compose.FieldPath{"input2"}, compose.FieldPath{"output2"}),
compose.MapFieldPaths(compose.FieldPath{"input3"}, compose.FieldPath{"output3"}),
},
expectedCount: 3,
description: "should add new mappings when no duplicates exist",
},
{
name: "with_duplicates",
initialCarryOvers: map[vo.NodeKey][]*compose.FieldMapping{
"node1": {
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
compose.MapFieldPaths(compose.FieldPath{"input2"}, compose.FieldPath{"output2"}),
},
},
fromNodeKey: "node1",
fieldMappings: []*compose.FieldMapping{
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}), // duplicate
compose.MapFieldPaths(compose.FieldPath{"input3"}, compose.FieldPath{"output3"}), // new
compose.MapFieldPaths(compose.FieldPath{"input2"}, compose.FieldPath{"output2"}), // duplicate
},
expectedCount: 3,
description: "should skip duplicates and only add new mappings",
},
{
name: "all_duplicates",
initialCarryOvers: map[vo.NodeKey][]*compose.FieldMapping{
"node1": {
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
compose.MapFieldPaths(compose.FieldPath{"input2"}, compose.FieldPath{"output2"}),
},
},
fromNodeKey: "node1",
fieldMappings: []*compose.FieldMapping{
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
compose.MapFieldPaths(compose.FieldPath{"input2"}, compose.FieldPath{"output2"}),
},
expectedCount: 2,
description: "should not add any mappings when all are duplicates",
},
{
name: "new_node_key",
initialCarryOvers: map[vo.NodeKey][]*compose.FieldMapping{
"node1": {
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
},
},
fromNodeKey: "node2",
fieldMappings: []*compose.FieldMapping{
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
compose.MapFieldPaths(compose.FieldPath{"input2"}, compose.FieldPath{"output2"}),
},
expectedCount: 2,
description: "should add all mappings for new node key",
},
{
name: "empty_field_mappings",
initialCarryOvers: map[vo.NodeKey][]*compose.FieldMapping{
"node1": {
compose.MapFieldPaths(compose.FieldPath{"input1"}, compose.FieldPath{"output1"}),
},
},
fromNodeKey: "node1",
fieldMappings: []*compose.FieldMapping{},
expectedCount: 1,
description: "should not change carryOvers when fieldMappings is empty",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Make a copy of initial carryOvers to avoid modifying the test data
carryOvers := make(map[vo.NodeKey][]*compose.FieldMapping)
for k, v := range tt.initialCarryOvers {
carryOvers[k] = make([]*compose.FieldMapping, len(v))
copy(carryOvers[k], v)
}
// Call the function under test
addFieldMappingsWithDeduplication(carryOvers, tt.fromNodeKey, tt.fieldMappings)
// Verify the result
actualCount := len(carryOvers[tt.fromNodeKey])
assert.Equal(t, tt.expectedCount, actualCount, tt.description)
// Verify no duplicates exist in the result
mappings := carryOvers[tt.fromNodeKey]
for i := 0; i < len(mappings); i++ {
for j := i + 1; j < len(mappings); j++ {
assert.False(t, mappings[i].Equals(mappings[j]),
"found duplicate mappings at indices %d and %d", i, j)
}
}
})
}
}
func TestAddFieldMappingsWithDeduplication_NilSafety(t *testing.T) {
t.Run("nil_field_mappings", func(t *testing.T) {
carryOvers := make(map[vo.NodeKey][]*compose.FieldMapping)
fromNodeKey := vo.NodeKey("node1")
// Should not panic with nil fieldMappings
assert.NotPanics(t, func() {
addFieldMappingsWithDeduplication(carryOvers, fromNodeKey, nil)
})
// Should initialize empty slice for the node
assert.NotNil(t, carryOvers[fromNodeKey])
assert.Equal(t, 0, len(carryOvers[fromNodeKey]))
})
t.Run("nil_carry_overs", func(t *testing.T) {
// Should panic with nil carryOvers - this is expected behavior
assert.Panics(t, func() {
addFieldMappingsWithDeduplication(nil, "node1", []*compose.FieldMapping{})
})
})
}

View File

@ -79,8 +79,8 @@ require (
gorm.io/driver/mysql v1.5.7
gorm.io/driver/sqlite v1.4.3
gorm.io/gen v0.3.26
gorm.io/gorm v1.30.0
gorm.io/plugin/dbresolver v1.6.0
gorm.io/gorm v1.25.11
gorm.io/plugin/dbresolver v1.5.2
)
require (
@ -289,5 +289,6 @@ require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/eino-contrib/jsonschema v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
)

View File

@ -1674,6 +1674,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c h1:jWdr7cHgl8c/ua5vYbR2WhSp+NQmzhsj0xoY3foTzW8=
gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c/go.mod h1:SH2K9R+2RMjuX1CkCONrPwoe9JzVv2hkQvEu4bXGojE=
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314=
@ -1689,12 +1690,12 @@ gorm.io/gorm v1.21.15/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
gorm.io/gorm v1.22.2/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg=
gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gorm.io/hints v1.1.0 h1:Lp4z3rxREufSdxn4qmkK3TLDltrM10FLTHiuqwDPvXw=
gorm.io/hints v1.1.0/go.mod h1:lKQ0JjySsPBj3uslFzY3JhYDtqEwzm+G1hv8rWujB6Y=
gorm.io/plugin/dbresolver v1.6.0 h1:XvKDeOtTn1EIX6s4SrKpEH82q0gXVemhYjbYZFGFVcw=
gorm.io/plugin/dbresolver v1.6.0/go.mod h1:tctw63jdrOezFR9HmrKnPkmig3m5Edem9fdxk9bQSzM=
gorm.io/plugin/dbresolver v1.5.2 h1:Iut7lW4TXNoVs++I+ra3zxjSxTRj4ocIeFEVp4lLhII=
gorm.io/plugin/dbresolver v1.5.2/go.mod h1:jPh59GOQbO7v7v28ZKZPd45tr+u3vyT+8tHdfdfOWcU=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@ -21,12 +21,12 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"github.com/coze-dev/coze-studio/backend/pkg/parsex"
"github.com/elastic/go-elasticsearch/v7"
"github.com/elastic/go-elasticsearch/v7/esapi"
"github.com/elastic/go-elasticsearch/v7/esutil"
"io"
"os"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
@ -39,12 +39,14 @@ type es7Client struct {
}
func newES7() (Client, error) {
esAddr := os.Getenv("ES_ADDR")
addresses, err := parsex.ParseClusterEndpoints(os.Getenv("ES_ADDR"))
if err != nil {
return nil, err
}
esUsername := os.Getenv("ES_USERNAME")
esPassword := os.Getenv("ES_PASSWORD")
esClient, err := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{esAddr},
Addresses: addresses,
Username: esUsername,
Password: esPassword,
})
@ -120,6 +122,10 @@ func (c *es7Client) CreateIndex(ctx context.Context, index string, properties ma
"mappings": map[string]any{
"properties": properties,
},
"settings": map[string]any{
"number_of_shards": parsex.GetEnvDefaultIntSetting("ES_NUMBER_OF_SHARDS", "1"),
"number_of_replicas": parsex.GetEnvDefaultIntSetting("ES_NUMBER_OF_REPLICAS", "1"),
},
}
body, err := json.Marshal(mapping)

View File

@ -19,8 +19,7 @@ package es
import (
"context"
"fmt"
"os"
"github.com/coze-dev/coze-studio/backend/pkg/parsex"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/esutil"
"github.com/elastic/go-elasticsearch/v8/typedapi/core/search"
@ -31,6 +30,7 @@ import (
"github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/operator"
"github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/sortorder"
"github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/textquerytype"
"os"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
@ -51,11 +51,14 @@ type es8BulkIndexer struct {
type es8Types struct{}
func newES8() (Client, error) {
esAddr := os.Getenv("ES_ADDR")
addresses, err := parsex.ParseClusterEndpoints(os.Getenv("ES_ADDR"))
if err != nil {
return nil, err
}
esUsername := os.Getenv("ES_USERNAME")
esPassword := os.Getenv("ES_PASSWORD")
esClient, err := elasticsearch.NewTypedClient(elasticsearch.Config{
Addresses: []string{esAddr},
Addresses: addresses,
Username: esUsername,
Password: esPassword,
})
@ -239,6 +242,10 @@ func (c *es8Client) CreateIndex(ctx context.Context, index string, properties ma
Mappings: &types.TypeMapping{
Properties: propertiesMap,
},
Settings: &types.IndexSettings{
NumberOfShards: parsex.GetEnvDefaultIntSetting("ES_NUMBER_OF_SHARDS", "1"),
NumberOfReplicas: parsex.GetEnvDefaultIntSetting("ES_NUMBER_OF_REPLICAS", "1"),
},
}).Do(ctx); err != nil {
return err
}

View File

@ -18,9 +18,8 @@ package es
import (
"fmt"
"os"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"os"
)
type (

View File

@ -930,14 +930,23 @@ func (m *mysqlService) buildWhereClause(condition *rdb.ComplexCondition) (string
if condition == nil {
return "", nil, nil
}
if condition.Operator == "" {
condition.Operator = entity2.AND
}
if len(condition.NestedConditions) > 0 {
return m.buildNestedConditions(condition)
} else if len(condition.Conditions) > 0 {
whereClauseString, values, err := m.buildWhereCondition(condition)
return " WHERE " + whereClauseString, values, err
} else {
return "", nil, fmt.Errorf("empty condition: no nested or direct conditions found")
}
}
func (m *mysqlService) buildWhereCondition(condition *rdb.ComplexCondition) (string, []interface{}, error) {
var whereClause strings.Builder
values := make([]interface{}, 0)
for i, cond := range condition.Conditions {
if i > 0 {
whereClause.WriteString(fmt.Sprintf(" %s ", condition.Operator))
@ -971,25 +980,35 @@ func (m *mysqlService) buildWhereClause(condition *rdb.ComplexCondition) (string
values = append(values, cond.Value)
}
}
if len(condition.NestedConditions) > 0 {
whereClause.WriteString(" AND (")
for i, nested := range condition.NestedConditions {
if i > 0 {
whereClause.WriteString(fmt.Sprintf(" %s ", nested.Operator))
}
nestedClause, nestedValues, err := m.buildWhereClause(nested)
if err != nil {
return "", nil, err
}
whereClause.WriteString(nestedClause)
values = append(values, nestedValues...)
}
whereClause.WriteString(")")
}
if whereClause.Len() > 0 {
return " WHERE " + whereClause.String(), values, nil
return whereClause.String(), values, nil
}
return "", values, nil
}
func (m *mysqlService) buildNestedConditions(condition *rdb.ComplexCondition) (string, []interface{}, error) {
var whereClause strings.Builder
values := make([]interface{}, 0)
whereClause.WriteString(" WHERE (")
for i, nested := range condition.NestedConditions {
if i > 0 {
whereClause.WriteString(fmt.Sprintf(" %s ", nested.Operator))
}
nestedClause, nestedValues, err := m.buildWhereCondition(nested)
if err != nil {
return "", nil, err
}
whereClause.WriteString(nestedClause)
if i < len(condition.NestedConditions)-1 {
whereClause.WriteString(" " + string(condition.Operator))
}
values = append(values, nestedValues...)
}
whereClause.WriteString(")")
if whereClause.Len() > 0 {
return whereClause.String(), values, nil
}
return "", values, nil
}

View File

@ -0,0 +1,355 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT.
// Source: domain/agent/singleagent/service/single_agent.go
//
// Generated by this command:
//
// mockgen -destination internal/mock/domain/agent/singleagent/single_agent_mock.go --package mock -source domain/agent/singleagent/service/single_agent.go
//
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
schema "github.com/cloudwego/eino/schema"
playground "github.com/coze-dev/coze-studio/backend/api/model/playground"
entity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
gomock "go.uber.org/mock/gomock"
)
// MockSingleAgent is a mock of SingleAgent interface.
type MockSingleAgent struct {
ctrl *gomock.Controller
recorder *MockSingleAgentMockRecorder
isgomock struct{}
}
// MockSingleAgentMockRecorder is the mock recorder for MockSingleAgent.
type MockSingleAgentMockRecorder struct {
mock *MockSingleAgent
}
// NewMockSingleAgent creates a new mock instance.
func NewMockSingleAgent(ctrl *gomock.Controller) *MockSingleAgent {
mock := &MockSingleAgent{ctrl: ctrl}
mock.recorder = &MockSingleAgentMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSingleAgent) EXPECT() *MockSingleAgentMockRecorder {
return m.recorder
}
// CreateSingleAgent mocks base method.
func (m *MockSingleAgent) CreateSingleAgent(ctx context.Context, connectorID int64, version string, e *entity.SingleAgent) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSingleAgent", ctx, connectorID, version, e)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSingleAgent indicates an expected call of CreateSingleAgent.
func (mr *MockSingleAgentMockRecorder) CreateSingleAgent(ctx, connectorID, version, e any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSingleAgent", reflect.TypeOf((*MockSingleAgent)(nil).CreateSingleAgent), ctx, connectorID, version, e)
}
// CreateSingleAgentDraft mocks base method.
func (m *MockSingleAgent) CreateSingleAgentDraft(ctx context.Context, creatorID int64, draft *entity.SingleAgent) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSingleAgentDraft", ctx, creatorID, draft)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSingleAgentDraft indicates an expected call of CreateSingleAgentDraft.
func (mr *MockSingleAgentMockRecorder) CreateSingleAgentDraft(ctx, creatorID, draft any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).CreateSingleAgentDraft), ctx, creatorID, draft)
}
// CreateSingleAgentDraftWithID mocks base method.
func (m *MockSingleAgent) CreateSingleAgentDraftWithID(ctx context.Context, creatorID, agentID int64, draft *entity.SingleAgent) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSingleAgentDraftWithID", ctx, creatorID, agentID, draft)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSingleAgentDraftWithID indicates an expected call of CreateSingleAgentDraftWithID.
func (mr *MockSingleAgentMockRecorder) CreateSingleAgentDraftWithID(ctx, creatorID, agentID, draft any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSingleAgentDraftWithID", reflect.TypeOf((*MockSingleAgent)(nil).CreateSingleAgentDraftWithID), ctx, creatorID, agentID, draft)
}
// DeleteAgentDraft mocks base method.
func (m *MockSingleAgent) DeleteAgentDraft(ctx context.Context, spaceID, agentID int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAgentDraft", ctx, spaceID, agentID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAgentDraft indicates an expected call of DeleteAgentDraft.
func (mr *MockSingleAgentMockRecorder) DeleteAgentDraft(ctx, spaceID, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).DeleteAgentDraft), ctx, spaceID, agentID)
}
// DuplicateInMemory mocks base method.
func (m *MockSingleAgent) DuplicateInMemory(ctx context.Context, req *entity.DuplicateInfo) (*entity.SingleAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DuplicateInMemory", ctx, req)
ret0, _ := ret[0].(*entity.SingleAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DuplicateInMemory indicates an expected call of DuplicateInMemory.
func (mr *MockSingleAgentMockRecorder) DuplicateInMemory(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DuplicateInMemory", reflect.TypeOf((*MockSingleAgent)(nil).DuplicateInMemory), ctx, req)
}
// GetAgentDraftDisplayInfo mocks base method.
func (m *MockSingleAgent) GetAgentDraftDisplayInfo(ctx context.Context, userID, agentID int64) (*entity.AgentDraftDisplayInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAgentDraftDisplayInfo", ctx, userID, agentID)
ret0, _ := ret[0].(*entity.AgentDraftDisplayInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAgentDraftDisplayInfo indicates an expected call of GetAgentDraftDisplayInfo.
func (mr *MockSingleAgentMockRecorder) GetAgentDraftDisplayInfo(ctx, userID, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentDraftDisplayInfo", reflect.TypeOf((*MockSingleAgent)(nil).GetAgentDraftDisplayInfo), ctx, userID, agentID)
}
// GetAgentPopupCount mocks base method.
func (m *MockSingleAgent) GetAgentPopupCount(ctx context.Context, uid, agentID int64, agentPopupType playground.BotPopupType) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAgentPopupCount", ctx, uid, agentID, agentPopupType)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAgentPopupCount indicates an expected call of GetAgentPopupCount.
func (mr *MockSingleAgentMockRecorder) GetAgentPopupCount(ctx, uid, agentID, agentPopupType any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentPopupCount", reflect.TypeOf((*MockSingleAgent)(nil).GetAgentPopupCount), ctx, uid, agentID, agentPopupType)
}
// GetPublishConnectorList mocks base method.
func (m *MockSingleAgent) GetPublishConnectorList(ctx context.Context, agentID int64) (*entity.PublishConnectorData, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPublishConnectorList", ctx, agentID)
ret0, _ := ret[0].(*entity.PublishConnectorData)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPublishConnectorList indicates an expected call of GetPublishConnectorList.
func (mr *MockSingleAgentMockRecorder) GetPublishConnectorList(ctx, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublishConnectorList", reflect.TypeOf((*MockSingleAgent)(nil).GetPublishConnectorList), ctx, agentID)
}
// GetPublishedInfo mocks base method.
func (m *MockSingleAgent) GetPublishedInfo(ctx context.Context, agentID int64) (*entity.PublishInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPublishedInfo", ctx, agentID)
ret0, _ := ret[0].(*entity.PublishInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPublishedInfo indicates an expected call of GetPublishedInfo.
func (mr *MockSingleAgentMockRecorder) GetPublishedInfo(ctx, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublishedInfo", reflect.TypeOf((*MockSingleAgent)(nil).GetPublishedInfo), ctx, agentID)
}
// GetPublishedTime mocks base method.
func (m *MockSingleAgent) GetPublishedTime(ctx context.Context, agentID int64) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPublishedTime", ctx, agentID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPublishedTime indicates an expected call of GetPublishedTime.
func (mr *MockSingleAgentMockRecorder) GetPublishedTime(ctx, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublishedTime", reflect.TypeOf((*MockSingleAgent)(nil).GetPublishedTime), ctx, agentID)
}
// GetSingleAgent mocks base method.
func (m *MockSingleAgent) GetSingleAgent(ctx context.Context, agentID int64, version string) (*entity.SingleAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSingleAgent", ctx, agentID, version)
ret0, _ := ret[0].(*entity.SingleAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSingleAgent indicates an expected call of GetSingleAgent.
func (mr *MockSingleAgentMockRecorder) GetSingleAgent(ctx, agentID, version any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSingleAgent", reflect.TypeOf((*MockSingleAgent)(nil).GetSingleAgent), ctx, agentID, version)
}
// GetSingleAgentDraft mocks base method.
func (m *MockSingleAgent) GetSingleAgentDraft(ctx context.Context, agentID int64) (*entity.SingleAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSingleAgentDraft", ctx, agentID)
ret0, _ := ret[0].(*entity.SingleAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSingleAgentDraft indicates an expected call of GetSingleAgentDraft.
func (mr *MockSingleAgentMockRecorder) GetSingleAgentDraft(ctx, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).GetSingleAgentDraft), ctx, agentID)
}
// IncrAgentPopupCount mocks base method.
func (m *MockSingleAgent) IncrAgentPopupCount(ctx context.Context, uid, agentID int64, agentPopupType playground.BotPopupType) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IncrAgentPopupCount", ctx, uid, agentID, agentPopupType)
ret0, _ := ret[0].(error)
return ret0
}
// IncrAgentPopupCount indicates an expected call of IncrAgentPopupCount.
func (mr *MockSingleAgentMockRecorder) IncrAgentPopupCount(ctx, uid, agentID, agentPopupType any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrAgentPopupCount", reflect.TypeOf((*MockSingleAgent)(nil).IncrAgentPopupCount), ctx, uid, agentID, agentPopupType)
}
// ListAgentPublishHistory mocks base method.
func (m *MockSingleAgent) ListAgentPublishHistory(ctx context.Context, agentID int64, pageIndex, pageSize int32, connectorID *int64) ([]*entity.SingleAgentPublish, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAgentPublishHistory", ctx, agentID, pageIndex, pageSize, connectorID)
ret0, _ := ret[0].([]*entity.SingleAgentPublish)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAgentPublishHistory indicates an expected call of ListAgentPublishHistory.
func (mr *MockSingleAgentMockRecorder) ListAgentPublishHistory(ctx, agentID, pageIndex, pageSize, connectorID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAgentPublishHistory", reflect.TypeOf((*MockSingleAgent)(nil).ListAgentPublishHistory), ctx, agentID, pageIndex, pageSize, connectorID)
}
// MGetSingleAgentDraft mocks base method.
func (m *MockSingleAgent) MGetSingleAgentDraft(ctx context.Context, agentIDs []int64) ([]*entity.SingleAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetSingleAgentDraft", ctx, agentIDs)
ret0, _ := ret[0].([]*entity.SingleAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetSingleAgentDraft indicates an expected call of MGetSingleAgentDraft.
func (mr *MockSingleAgentMockRecorder) MGetSingleAgentDraft(ctx, agentIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).MGetSingleAgentDraft), ctx, agentIDs)
}
// ObtainAgentByIdentity mocks base method.
func (m *MockSingleAgent) ObtainAgentByIdentity(ctx context.Context, identity *entity.AgentIdentity) (*entity.SingleAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ObtainAgentByIdentity", ctx, identity)
ret0, _ := ret[0].(*entity.SingleAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ObtainAgentByIdentity indicates an expected call of ObtainAgentByIdentity.
func (mr *MockSingleAgentMockRecorder) ObtainAgentByIdentity(ctx, identity any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObtainAgentByIdentity", reflect.TypeOf((*MockSingleAgent)(nil).ObtainAgentByIdentity), ctx, identity)
}
// SavePublishRecord mocks base method.
func (m *MockSingleAgent) SavePublishRecord(ctx context.Context, p *entity.SingleAgentPublish, e *entity.SingleAgent) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePublishRecord", ctx, p, e)
ret0, _ := ret[0].(error)
return ret0
}
// SavePublishRecord indicates an expected call of SavePublishRecord.
func (mr *MockSingleAgentMockRecorder) SavePublishRecord(ctx, p, e any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePublishRecord", reflect.TypeOf((*MockSingleAgent)(nil).SavePublishRecord), ctx, p, e)
}
// StreamExecute mocks base method.
func (m *MockSingleAgent) StreamExecute(ctx context.Context, req *entity.ExecuteRequest) (*schema.StreamReader[*entity.AgentEvent], error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StreamExecute", ctx, req)
ret0, _ := ret[0].(*schema.StreamReader[*entity.AgentEvent])
ret1, _ := ret[1].(error)
return ret0, ret1
}
// StreamExecute indicates an expected call of StreamExecute.
func (mr *MockSingleAgentMockRecorder) StreamExecute(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamExecute", reflect.TypeOf((*MockSingleAgent)(nil).StreamExecute), ctx, req)
}
// UpdateAgentDraftDisplayInfo mocks base method.
func (m *MockSingleAgent) UpdateAgentDraftDisplayInfo(ctx context.Context, userID int64, e *entity.AgentDraftDisplayInfo) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAgentDraftDisplayInfo", ctx, userID, e)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAgentDraftDisplayInfo indicates an expected call of UpdateAgentDraftDisplayInfo.
func (mr *MockSingleAgentMockRecorder) UpdateAgentDraftDisplayInfo(ctx, userID, e any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAgentDraftDisplayInfo", reflect.TypeOf((*MockSingleAgent)(nil).UpdateAgentDraftDisplayInfo), ctx, userID, e)
}
// UpdateSingleAgentDraft mocks base method.
func (m *MockSingleAgent) UpdateSingleAgentDraft(ctx context.Context, agentInfo *entity.SingleAgent) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateSingleAgentDraft", ctx, agentInfo)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateSingleAgentDraft indicates an expected call of UpdateSingleAgentDraft.
func (mr *MockSingleAgentMockRecorder) UpdateSingleAgentDraft(ctx, agentInfo any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).UpdateSingleAgentDraft), ctx, agentInfo)
}

View File

@ -0,0 +1,148 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT.
// Source: domain/conversation/agentrun/service/agent_run.go
//
// Generated by this command:
//
// mockgen -destination internal/mock/domain/conversation/agentrun/agent_run_mock.go --package mock_agentrun -source domain/conversation/agentrun/service/agent_run.go
//
// Package mock_agentrun is a generated GoMock package.
package mock_agentrun
import (
context "context"
reflect "reflect"
schema "github.com/cloudwego/eino/schema"
entity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
gomock "go.uber.org/mock/gomock"
)
// MockRun is a mock of Run interface.
type MockRun struct {
ctrl *gomock.Controller
recorder *MockRunMockRecorder
isgomock struct{}
}
// MockRunMockRecorder is the mock recorder for MockRun.
type MockRunMockRecorder struct {
mock *MockRun
}
// NewMockRun creates a new mock instance.
func NewMockRun(ctrl *gomock.Controller) *MockRun {
mock := &MockRun{ctrl: ctrl}
mock.recorder = &MockRunMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRun) EXPECT() *MockRunMockRecorder {
return m.recorder
}
// AgentRun mocks base method.
func (m *MockRun) AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AgentRun", ctx, req)
ret0, _ := ret[0].(*schema.StreamReader[*entity.AgentRunResponse])
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AgentRun indicates an expected call of AgentRun.
func (mr *MockRunMockRecorder) AgentRun(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentRun", reflect.TypeOf((*MockRun)(nil).AgentRun), ctx, req)
}
// Cancel mocks base method.
func (m *MockRun) Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Cancel", ctx, req)
ret0, _ := ret[0].(*entity.RunRecordMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Cancel indicates an expected call of Cancel.
func (mr *MockRunMockRecorder) Cancel(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockRun)(nil).Cancel), ctx, req)
}
// Create mocks base method.
func (m *MockRun) Create(ctx context.Context, runRecord *entity.AgentRunMeta) (*entity.RunRecordMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", ctx, runRecord)
ret0, _ := ret[0].(*entity.RunRecordMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Create indicates an expected call of Create.
func (mr *MockRunMockRecorder) Create(ctx, runRecord any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockRun)(nil).Create), ctx, runRecord)
}
// Delete mocks base method.
func (m *MockRun) Delete(ctx context.Context, runID []int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", ctx, runID)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockRunMockRecorder) Delete(ctx, runID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockRun)(nil).Delete), ctx, runID)
}
// GetByID mocks base method.
func (m *MockRun) GetByID(ctx context.Context, runID int64) (*entity.RunRecordMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByID", ctx, runID)
ret0, _ := ret[0].(*entity.RunRecordMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetByID indicates an expected call of GetByID.
func (mr *MockRunMockRecorder) GetByID(ctx, runID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockRun)(nil).GetByID), ctx, runID)
}
// List mocks base method.
func (m *MockRun) List(ctx context.Context, ListMeta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", ctx, ListMeta)
ret0, _ := ret[0].([]*entity.RunRecordMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockRunMockRecorder) List(ctx, ListMeta any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockRun)(nil).List), ctx, ListMeta)
}

View File

@ -0,0 +1,163 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT.
// Source: domain/conversation/conversation/service/conversation.go
//
// Generated by this command:
//
// mockgen -destination internal/mock/domain/conversation/conversation/conversation_mock.go --package mock_conversation -source domain/conversation/conversation/service/conversation.go
//
// Package mock_conversation is a generated GoMock package.
package mock_conversation
import (
context "context"
reflect "reflect"
entity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
gomock "go.uber.org/mock/gomock"
)
// MockConversation is a mock of Conversation interface.
type MockConversation struct {
ctrl *gomock.Controller
recorder *MockConversationMockRecorder
isgomock struct{}
}
// MockConversationMockRecorder is the mock recorder for MockConversation.
type MockConversationMockRecorder struct {
mock *MockConversation
}
// NewMockConversation creates a new mock instance.
func NewMockConversation(ctrl *gomock.Controller) *MockConversation {
mock := &MockConversation{ctrl: ctrl}
mock.recorder = &MockConversationMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConversation) EXPECT() *MockConversationMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockConversation) Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", ctx, req)
ret0, _ := ret[0].(*entity.Conversation)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Create indicates an expected call of Create.
func (mr *MockConversationMockRecorder) Create(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockConversation)(nil).Create), ctx, req)
}
// Delete mocks base method.
func (m *MockConversation) Delete(ctx context.Context, id int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockConversationMockRecorder) Delete(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockConversation)(nil).Delete), ctx, id)
}
// GetByID mocks base method.
func (m *MockConversation) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByID", ctx, id)
ret0, _ := ret[0].(*entity.Conversation)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetByID indicates an expected call of GetByID.
func (mr *MockConversationMockRecorder) GetByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockConversation)(nil).GetByID), ctx, id)
}
// GetCurrentConversation mocks base method.
func (m *MockConversation) GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCurrentConversation", ctx, req)
ret0, _ := ret[0].(*entity.Conversation)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetCurrentConversation indicates an expected call of GetCurrentConversation.
func (mr *MockConversationMockRecorder) GetCurrentConversation(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentConversation", reflect.TypeOf((*MockConversation)(nil).GetCurrentConversation), ctx, req)
}
// List mocks base method.
func (m *MockConversation) List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", ctx, req)
ret0, _ := ret[0].([]*entity.Conversation)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// List indicates an expected call of List.
func (mr *MockConversationMockRecorder) List(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockConversation)(nil).List), ctx, req)
}
// NewConversationCtx mocks base method.
func (m *MockConversation) NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewConversationCtx", ctx, req)
ret0, _ := ret[0].(*entity.NewConversationCtxResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NewConversationCtx indicates an expected call of NewConversationCtx.
func (mr *MockConversationMockRecorder) NewConversationCtx(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConversationCtx", reflect.TypeOf((*MockConversation)(nil).NewConversationCtx), ctx, req)
}
// Update mocks base method.
func (m *MockConversation) Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", ctx, req)
ret0, _ := ret[0].(*entity.Conversation)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Update indicates an expected call of Update.
func (mr *MockConversationMockRecorder) Update(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockConversation)(nil).Update), ctx, req)
}

View File

@ -0,0 +1,132 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT.
// Source: domain/shortcutcmd/service/shortcut_cmd.go
//
// Generated by this command:
//
// mockgen -destination internal/mock/domain/shortcutcmd/shortcut_cmd_mock.go --package mock_shortcutcmd -source domain/shortcutcmd/service/shortcut_cmd.go
//
// Package mock_shortcutcmd is a generated GoMock package.
package mock_shortcutcmd
import (
context "context"
reflect "reflect"
entity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
gomock "go.uber.org/mock/gomock"
)
// MockShortcutCmd is a mock of ShortcutCmd interface.
type MockShortcutCmd struct {
ctrl *gomock.Controller
recorder *MockShortcutCmdMockRecorder
isgomock struct{}
}
// MockShortcutCmdMockRecorder is the mock recorder for MockShortcutCmd.
type MockShortcutCmdMockRecorder struct {
mock *MockShortcutCmd
}
// NewMockShortcutCmd creates a new mock instance.
func NewMockShortcutCmd(ctrl *gomock.Controller) *MockShortcutCmd {
mock := &MockShortcutCmd{ctrl: ctrl}
mock.recorder = &MockShortcutCmdMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockShortcutCmd) EXPECT() *MockShortcutCmdMockRecorder {
return m.recorder
}
// CreateCMD mocks base method.
func (m *MockShortcutCmd) CreateCMD(ctx context.Context, shortcut *entity.ShortcutCmd) (*entity.ShortcutCmd, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateCMD", ctx, shortcut)
ret0, _ := ret[0].(*entity.ShortcutCmd)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateCMD indicates an expected call of CreateCMD.
func (mr *MockShortcutCmdMockRecorder) CreateCMD(ctx, shortcut any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCMD", reflect.TypeOf((*MockShortcutCmd)(nil).CreateCMD), ctx, shortcut)
}
// GetByCmdID mocks base method.
func (m *MockShortcutCmd) GetByCmdID(ctx context.Context, cmdID int64, isOnline int32) (*entity.ShortcutCmd, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByCmdID", ctx, cmdID, isOnline)
ret0, _ := ret[0].(*entity.ShortcutCmd)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetByCmdID indicates an expected call of GetByCmdID.
func (mr *MockShortcutCmdMockRecorder) GetByCmdID(ctx, cmdID, isOnline any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByCmdID", reflect.TypeOf((*MockShortcutCmd)(nil).GetByCmdID), ctx, cmdID, isOnline)
}
// ListCMD mocks base method.
func (m *MockShortcutCmd) ListCMD(ctx context.Context, lm *entity.ListMeta) ([]*entity.ShortcutCmd, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListCMD", ctx, lm)
ret0, _ := ret[0].([]*entity.ShortcutCmd)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListCMD indicates an expected call of ListCMD.
func (mr *MockShortcutCmdMockRecorder) ListCMD(ctx, lm any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListCMD", reflect.TypeOf((*MockShortcutCmd)(nil).ListCMD), ctx, lm)
}
// PublishCMDs mocks base method.
func (m *MockShortcutCmd) PublishCMDs(ctx context.Context, objID int64, cmdIDs []int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PublishCMDs", ctx, objID, cmdIDs)
ret0, _ := ret[0].(error)
return ret0
}
// PublishCMDs indicates an expected call of PublishCMDs.
func (mr *MockShortcutCmdMockRecorder) PublishCMDs(ctx, objID, cmdIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishCMDs", reflect.TypeOf((*MockShortcutCmd)(nil).PublishCMDs), ctx, objID, cmdIDs)
}
// UpdateCMD mocks base method.
func (m *MockShortcutCmd) UpdateCMD(ctx context.Context, shortcut *entity.ShortcutCmd) (*entity.ShortcutCmd, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateCMD", ctx, shortcut)
ret0, _ := ret[0].(*entity.ShortcutCmd)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateCMD indicates an expected call of UpdateCMD.
func (mr *MockShortcutCmdMockRecorder) UpdateCMD(ctx, shortcut any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCMD", reflect.TypeOf((*MockShortcutCmd)(nil).UpdateCMD), ctx, shortcut)
}

View File

@ -0,0 +1,118 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT.
// Source: domain/upload/service/interface.go
//
// Generated by this command:
//
// mockgen -destination internal/mock/domain/upload/upload_service_mock.go --package mock_upload -source domain/upload/service/interface.go UploadService
//
// Package mock_upload is a generated GoMock package.
package mock_upload
import (
context "context"
reflect "reflect"
service "github.com/coze-dev/coze-studio/backend/domain/upload/service"
gomock "go.uber.org/mock/gomock"
)
// MockUploadService is a mock of UploadService interface.
type MockUploadService struct {
ctrl *gomock.Controller
recorder *MockUploadServiceMockRecorder
isgomock struct{}
}
// MockUploadServiceMockRecorder is the mock recorder for MockUploadService.
type MockUploadServiceMockRecorder struct {
mock *MockUploadService
}
// NewMockUploadService creates a new mock instance.
func NewMockUploadService(ctrl *gomock.Controller) *MockUploadService {
mock := &MockUploadService{ctrl: ctrl}
mock.recorder = &MockUploadServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUploadService) EXPECT() *MockUploadServiceMockRecorder {
return m.recorder
}
// GetFile mocks base method.
func (m *MockUploadService) GetFile(ctx context.Context, req *service.GetFileRequest) (*service.GetFileResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetFile", ctx, req)
ret0, _ := ret[0].(*service.GetFileResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetFile indicates an expected call of GetFile.
func (mr *MockUploadServiceMockRecorder) GetFile(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFile", reflect.TypeOf((*MockUploadService)(nil).GetFile), ctx, req)
}
// GetFiles mocks base method.
func (m *MockUploadService) GetFiles(ctx context.Context, req *service.GetFilesRequest) (*service.GetFilesResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetFiles", ctx, req)
ret0, _ := ret[0].(*service.GetFilesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetFiles indicates an expected call of GetFiles.
func (mr *MockUploadServiceMockRecorder) GetFiles(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFiles", reflect.TypeOf((*MockUploadService)(nil).GetFiles), ctx, req)
}
// UploadFile mocks base method.
func (m *MockUploadService) UploadFile(ctx context.Context, req *service.UploadFileRequest) (*service.UploadFileResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UploadFile", ctx, req)
ret0, _ := ret[0].(*service.UploadFileResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UploadFile indicates an expected call of UploadFile.
func (mr *MockUploadServiceMockRecorder) UploadFile(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFile", reflect.TypeOf((*MockUploadService)(nil).UploadFile), ctx, req)
}
// UploadFiles mocks base method.
func (m *MockUploadService) UploadFiles(ctx context.Context, req *service.UploadFilesRequest) (*service.UploadFilesResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UploadFiles", ctx, req)
ret0, _ := ret[0].(*service.UploadFilesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UploadFiles indicates an expected call of UploadFiles.
func (mr *MockUploadServiceMockRecorder) UploadFiles(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFiles", reflect.TypeOf((*MockUploadService)(nil).UploadFiles), ctx, req)
}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parsex
import (
"fmt"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"os"
"strconv"
"strings"
)
// ParseClusterEndpoints 解析 ES /kafka 地址,多个地址用逗号分隔
func ParseClusterEndpoints(address string) ([]string, error) {
if strings.TrimSpace(address) == "" {
return nil, fmt.Errorf("endpoints environment variable is required")
}
endpoints := strings.Split(address, ",")
var validEndpoints []string
uniqueEndpoints := make(map[string]bool, len(endpoints))
for _, endpoint := range endpoints {
trimmed := strings.TrimSpace(endpoint)
if trimmed == "" {
continue
}
if !uniqueEndpoints[trimmed] {
uniqueEndpoints[trimmed] = true
validEndpoints = append(validEndpoints, trimmed)
}
}
if len(validEndpoints) == 0 {
return nil, fmt.Errorf("no valid endpoints found in: %s", address)
}
return validEndpoints, nil
}
// GetEnvDefaultIntSetting 获取环境变量的值,如果不存在或无效则返回默认值
func GetEnvDefaultIntSetting(envVar, defaultValue string) string {
value := os.Getenv(envVar)
if value == "" {
return defaultValue
}
if num, err := strconv.Atoi(value); err != nil || num <= 0 {
logs.Warnf("Invalid %s value: %s, using default: %s", envVar, value, defaultValue)
return defaultValue
}
return value
}

View File

@ -75,6 +75,8 @@ export ES_ADDR="http://127.0.0.1:9200"
export ES_VERSION="v8"
export ES_USERNAME=""
export ES_PASSWORD=""
export ES_NUMBER_OF_SHARDS = "1"
export ES_NUMBER_OF_REPLICAS = "1"
export COZE_MQ_TYPE="nsq" # nsq / kafka / rmq
@ -89,6 +91,9 @@ export RMQ_SECRET_KEY=""
export VECTOR_STORE_TYPE="milvus"
# milvus vector store
export MILVUS_ADDR="127.0.0.1:19530"
export MILVUS_USER=""
export MILVUS_PASSWORD=""
export MILVUS_TOKEN=""
# vikingdb vector store for Volcengine
export VIKING_DB_HOST=""
export VIKING_DB_REGION=""

View File

@ -71,6 +71,8 @@ export ES_ADDR="http://elasticsearch:9200"
export ES_VERSION="v8"
export ES_USERNAME=""
export ES_PASSWORD=""
export ES_NUMBER_OF_SHARDS = "1"
export ES_NUMBER_OF_REPLICAS = "1"
export COZE_MQ_TYPE="nsq" # nsq / kafka / rmq
@ -87,6 +89,7 @@ export VECTOR_STORE_TYPE="milvus"
export MILVUS_ADDR="milvus:19530"
export MILVUS_USER=""
export MILVUS_PASSWORD=""
export MILVUS_TOKEN=""
# vikingdb vector store for Volcengine
export VIKING_DB_HOST=""
export VIKING_DB_REGION=""