feat: workflow cross domain support conversation & other domain add related called methods

This commit is contained in:
zhuangjie.1125
2025-08-12 22:35:36 +08:00
parent 16d9d5bceb
commit 6d47bf37d2
10 changed files with 684 additions and 157 deletions

View File

@ -21,17 +21,31 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
)
// Requests and responses must not reference domain entities and can only use models under api/model/crossdomain.
type SingleAgent interface {
StreamExecute(ctx context.Context, historyMsg []*message.Message, query *message.Message,
agentRuntime *singleagent.AgentRuntime) (*schema.StreamReader[*singleagent.AgentEvent], error)
StreamExecute(ctx context.Context,
agentRuntime *AgentRuntime) (*schema.StreamReader[*singleagent.AgentEvent], error)
ObtainAgentByIdentity(ctx context.Context, identity *singleagent.AgentIdentity) (*singleagent.SingleAgent, error)
}
type AgentRuntime struct {
AgentVersion string
UserID string
AgentID int64
IsDraft bool
SpaceID int64
ConnectorID int64
PreRetrieveTools []*agentrun.Tool
HistoryMsg []*schema.Message
Input *schema.Message
ResumeInfo *ResumeInfo
}
type ResumeInfo = singleagent.InterruptInfo
type AgentEvent = singleagent.AgentEvent

View File

@ -20,10 +20,14 @@ import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
)
type Conversation interface {
GetCurrentConversation(ctx context.Context, req *conversation.GetCurrent) (*conversation.Conversation, error)
Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error)
NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error)
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
}
var defaultSVC Conversation

View File

@ -20,13 +20,16 @@ import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
)
type Message interface {
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error)
PreCreate(ctx context.Context, msg *message.Message) (*message.Message, error)
Create(ctx context.Context, msg *message.Message) (*message.Message, error)
List(ctx context.Context, meta *entity.ListMeta) (*entity.ListResult, error)
Edit(ctx context.Context, msg *message.Message) (*message.Message, error)
Delete(ctx context.Context, req *entity.DeleteMeta) error
}
var defaultSVC Message

View File

@ -39,18 +39,30 @@ type Workflow interface {
ReleaseApplicationWorkflows(ctx context.Context, appID int64, config *ReleaseWorkflowConfig) ([]*vo.ValidateIssue, error)
GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error)
SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
StreamExecute(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*schema.StreamReader[*workflowEntity.Message], error)
WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message])
InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error
}
type ExecuteConfig = vo.ExecuteConfig
type WorkflowMessage = workflowEntity.Message
type ExecuteMode = vo.ExecuteMode
type NodeType = entity.NodeType
type MessageType = entity.MessageType
type InterruptEvent = workflowEntity.InterruptEvent
type EventType = workflowEntity.InterruptEventType
type WorkflowMessage = entity.Message
const (
Answer MessageType = "answer"
FunctionCall MessageType = "function_call"
ToolResponse MessageType = "tool_response"
)
const (
NodeTypeOutputEmitter NodeType = "OutputEmitter"
NodeTypeInputReceiver NodeType = "InputReceiver"
NodeTypeQuestion NodeType = "Question"
)
const (
@ -59,6 +71,14 @@ const (
ExecuteModeNodeDebug ExecuteMode = "node_debug"
)
type SyncPattern = vo.SyncPattern
const (
SyncPatternSync SyncPattern = "sync"
SyncPatternAsync SyncPattern = "async"
SyncPatternStream SyncPattern = "stream"
)
type TaskType = vo.TaskType
const (

View File

@ -21,6 +21,7 @@ import (
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
conversation "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
)
@ -40,3 +41,15 @@ func InitDomainService(c conversation.Conversation) crossconversation.Conversati
func (s *impl) GetCurrentConversation(ctx context.Context, req *model.GetCurrent) (*model.Conversation, error) {
return s.DomainSVC.GetCurrentConversation(ctx, req)
}
func (s *impl) Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error) {
return s.DomainSVC.Create(ctx, req)
}
func (s *impl) NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error) {
return s.DomainSVC.NewConversationCtx(ctx, req)
}
func (s *impl) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
return s.DomainSVC.GetByID(ctx, id)
}

View File

@ -21,6 +21,8 @@ import (
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
)
@ -53,3 +55,11 @@ func (c *impl) Edit(ctx context.Context, msg *model.Message) (*model.Message, er
func (c *impl) PreCreate(ctx context.Context, msg *model.Message) (*model.Message, error) {
return c.DomainSVC.PreCreate(ctx, msg)
}
func (c *impl) List(ctx context.Context, lm *entity.ListMeta) (*entity.ListResult, error) {
return c.DomainSVC.List(ctx, lm)
}
func (c *impl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
return c.DomainSVC.Delete(ctx, req)
}

View File

@ -18,17 +18,13 @@ package agent
import (
"context"
"encoding/json"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
@ -38,49 +34,34 @@ var defaultSVC crossagent.SingleAgent
type impl struct {
DomainSVC singleagent.SingleAgent
ImagexSVC imagex.ImageX
}
func InitDomainService(c singleagent.SingleAgent, imagexClient imagex.ImageX) crossagent.SingleAgent {
func InitDomainService(c singleagent.SingleAgent) crossagent.SingleAgent {
defaultSVC = &impl{
DomainSVC: c,
ImagexSVC: imagexClient,
}
return defaultSVC
}
func (c *impl) StreamExecute(ctx context.Context, historyMsg []*message.Message,
query *message.Message, agentRuntime *model.AgentRuntime,
func (c *impl) StreamExecute(ctx context.Context, agentRuntime *crossagent.AgentRuntime,
) (*schema.StreamReader[*model.AgentEvent], error) {
historyMsg = c.historyPairs(historyMsg)
singleAgentStreamExecReq := c.buildSingleAgentStreamExecuteReq(ctx, historyMsg, query, agentRuntime)
singleAgentStreamExecReq := c.buildSingleAgentStreamExecuteReq(ctx, agentRuntime)
streamEvent, err := c.DomainSVC.StreamExecute(ctx, singleAgentStreamExecReq)
logs.CtxInfof(ctx, "agent StreamExecute req:%v, streamEvent:%v, err:%v", conv.DebugJsonToStr(singleAgentStreamExecReq), streamEvent, err)
return streamEvent, err
}
func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, historyMsg []*message.Message,
input *message.Message, agentRuntime *model.AgentRuntime,
func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, agentRuntime *crossagent.AgentRuntime,
) *model.ExecuteRequest {
identity := c.buildIdentity(input, agentRuntime)
inputBuild := c.buildSchemaMessage(ctx, []*message.Message{input})
var inputSM *schema.Message
if len(inputBuild) > 0 {
inputSM = inputBuild[0]
}
history := c.buildSchemaMessage(ctx, historyMsg)
resumeInfo := c.checkResumeInfo(ctx, historyMsg)
return &model.ExecuteRequest{
Identity: identity,
Input: inputSM,
History: history,
UserID: input.UserID,
Identity: c.buildIdentity(agentRuntime),
Input: agentRuntime.Input,
History: agentRuntime.HistoryMsg,
UserID: agentRuntime.UserID,
PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever {
return &agentrun.ToolsRetriever{
PluginID: tool.PluginID,
@ -98,141 +79,19 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, historyMsg
}(tool.Type),
}
}),
ResumeInfo: resumeInfo,
ResumeInfo: agentRuntime.ResumeInfo,
}
}
func (c *impl) historyPairs(historyMsg []*message.Message) []*message.Message {
fcMsgPairs := make(map[int64][]*message.Message)
for _, one := range historyMsg {
if one.MessageType != message.MessageTypeFunctionCall && one.MessageType != message.MessageTypeToolResponse {
continue
}
if _, ok := fcMsgPairs[one.RunID]; !ok {
fcMsgPairs[one.RunID] = []*message.Message{one}
} else {
fcMsgPairs[one.RunID] = append(fcMsgPairs[one.RunID], one)
}
}
var historyAfterPairs []*message.Message
for _, value := range historyMsg {
if value.MessageType == message.MessageTypeFunctionCall {
if len(fcMsgPairs[value.RunID])%2 == 0 {
historyAfterPairs = append(historyAfterPairs, value)
}
} else {
historyAfterPairs = append(historyAfterPairs, value)
}
}
return historyAfterPairs
}
func (c *impl) checkResumeInfo(_ context.Context, historyMsg []*message.Message) *crossagent.ResumeInfo {
var resumeInfo *crossagent.ResumeInfo
for i := len(historyMsg) - 1; i >= 0; i-- {
if historyMsg[i].MessageType == message.MessageTypeQuestion {
break
}
if historyMsg[i].MessageType == message.MessageTypeVerbose {
if historyMsg[i].Ext[string(entity.ExtKeyResumeInfo)] != "" {
err := json.Unmarshal([]byte(historyMsg[i].Ext[string(entity.ExtKeyResumeInfo)]), &resumeInfo)
if err != nil {
return nil
}
}
}
}
return resumeInfo
}
func (c *impl) buildSchemaMessage(ctx context.Context, msgs []*message.Message) []*schema.Message {
schemaMessage := make([]*schema.Message, 0, len(msgs))
for _, msgOne := range msgs {
if msgOne.ModelContent == "" {
continue
}
if msgOne.MessageType == message.MessageTypeVerbose || msgOne.MessageType == message.MessageTypeFlowUp {
continue
}
var sm *schema.Message
err := json.Unmarshal([]byte(msgOne.ModelContent), &sm)
if err != nil {
continue
}
if len(sm.ReasoningContent) > 0 {
sm.ReasoningContent = ""
}
schemaMessage = append(schemaMessage, c.parseMessageURI(ctx, sm))
}
return schemaMessage
}
func (c *impl) parseMessageURI(ctx context.Context, mcMsg *schema.Message) *schema.Message {
if mcMsg.MultiContent == nil {
return mcMsg
}
for k, one := range mcMsg.MultiContent {
switch one.Type {
case schema.ChatMessagePartTypeImageURL:
if one.ImageURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.ImageURL.URI)
if err == nil {
mcMsg.MultiContent[k].ImageURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeFileURL:
if one.FileURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.FileURL.URI)
if err == nil {
mcMsg.MultiContent[k].FileURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeAudioURL:
if one.AudioURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.AudioURL.URI)
if err == nil {
mcMsg.MultiContent[k].AudioURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeVideoURL:
if one.VideoURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.VideoURL.URI)
if err == nil {
mcMsg.MultiContent[k].VideoURL.URL = url.URL
}
}
}
}
return mcMsg
}
func (c *impl) buildIdentity(input *message.Message, agentRuntime *model.AgentRuntime) *model.AgentIdentity {
func (c *impl) buildIdentity(agentRuntime *crossagent.AgentRuntime) *model.AgentIdentity {
return &model.AgentIdentity{
AgentID: input.AgentID,
AgentID: agentRuntime.AgentID,
Version: agentRuntime.AgentVersion,
IsDraft: agentRuntime.IsDraft,
ConnectorID: agentRuntime.ConnectorID,
}
}
func (c *impl) GetSingleAgent(ctx context.Context, agentID int64, version string) (agent *model.SingleAgent, err error) {
agentInfo, err := c.DomainSVC.GetSingleAgent(ctx, agentID, version)
if err != nil {
return nil, err
}
return agentInfo.SingleAgent, nil
}
func (c *impl) ObtainAgentByIdentity(ctx context.Context, identity *model.AgentIdentity) (*model.SingleAgent, error) {
agentInfo, err := c.DomainSVC.ObtainAgentByIdentity(ctx, identity)
if err != nil {

View File

@ -70,11 +70,17 @@ func (i *impl) WithResumeToolWorkflow(resumingEvent *workflowEntity.ToolInterrup
func (i *impl) SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) {
return i.DomainSVC.SyncExecute(ctx, config, input)
}
func (i *impl) StreamExecute(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*schema.StreamReader[*workflowEntity.Message], error) {
return i.DomainSVC.StreamExecute(ctx, config, input)
}
func (i *impl) WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option {
return i.DomainSVC.WithExecuteConfig(cfg)
}
func (i *impl) InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error {
return i.DomainSVC.InitApplicationDefaultConversationTemplate(ctx, spaceID, appID, userID)
}
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
return i.DomainSVC.WithMessagePipe()
}

View File

@ -0,0 +1,202 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"strconv"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
msgentity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type ConversationRepository struct {
}
func NewConversationRepository() *ConversationRepository {
return &ConversationRepository{}
}
func (c *ConversationRepository) CreateConversation(ctx context.Context, req *conversation.CreateConversationRequest) (int64, error) {
ret, err := crossconversation.DefaultSVC().Create(ctx, &entity.CreateMeta{
AgentID: req.AppID,
UserID: req.UserID,
ConnectorID: req.ConnectorID,
Scene: common.Scene_SceneWorkflow,
})
if err != nil {
return 0, err
}
return ret.ID, nil
}
func (c *ConversationRepository) CreateMessage(ctx context.Context, req *conversation.CreateMessageRequest) (int64, error) {
msg := &message.Message{
ConversationID: req.ConversationID,
Role: schema.RoleType(req.Role),
Content: req.Content,
ContentType: message.ContentType(req.ContentType),
UserID: strconv.FormatInt(req.UserID, 10),
AgentID: req.AppID,
RunID: req.RunID,
}
if msg.Role == "user" {
msg.MessageType = message.MessageTypeQuestion
} else {
msg.MessageType = message.MessageTypeAnswer
}
ret, err := crossmessage.DefaultSVC().Create(ctx, msg)
if err != nil {
return 0, err
}
return ret.ID, nil
}
func (c *ConversationRepository) MessageList(ctx context.Context, req *conversation.MessageListRequest) (*conversation.MessageListResponse, error) {
lm := &msgentity.ListMeta{
ConversationID: req.ConversationID,
Limit: int(req.Limit), // Since the value of limit is checked inside the node, the type cast here is safe
UserID: strconv.FormatInt(req.UserID, 10),
AgentID: req.AppID,
OrderBy: req.OrderBy,
}
if req.BeforeID != nil {
lm.Cursor, _ = strconv.ParseInt(*req.BeforeID, 10, 64)
lm.Direction = msgentity.ScrollPageDirectionPrev
}
if req.AfterID != nil {
lm.Cursor, _ = strconv.ParseInt(*req.AfterID, 10, 64)
lm.Direction = msgentity.ScrollPageDirectionNext
}
lm.Direction = msgentity.ScrollPageDirectionNext
lr, err := crossmessage.DefaultSVC().List(ctx, lm)
if err != nil {
return nil, err
}
response := &conversation.MessageListResponse{}
if lr.PrevCursor > 0 {
response.FirstID = strconv.FormatInt(lr.PrevCursor, 10)
}
if lr.NextCursor > 0 {
response.LastID = strconv.FormatInt(lr.NextCursor, 10)
}
if len(lr.Messages) == 0 {
return response, nil
}
messages, err := convertMessage(lr.Messages)
if err != nil {
return nil, err
}
response.Messages = messages
return response, nil
}
func (c *ConversationRepository) ClearConversationHistory(ctx context.Context, req *conversation.ClearConversationHistoryReq) error {
_, err := crossconversation.DefaultSVC().NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
ID: req.ConversationID,
})
if err != nil {
return err
}
return nil
}
func (c *ConversationRepository) DeleteMessage(ctx context.Context, req *conversation.DeleteMessageRequest) error {
return crossmessage.DefaultSVC().Delete(ctx, &msgentity.DeleteMeta{
MessageIDs: []int64{req.MessageID},
})
}
func (c *ConversationRepository) EditMessage(ctx context.Context, req *conversation.EditMessageRequest) error {
_, err := crossmessage.DefaultSVC().Edit(ctx, &msgentity.Message{
ID: req.MessageID,
ConversationID: req.ConversationID,
Content: req.Content,
})
if err != nil {
return err
}
return nil
}
func (c *ConversationRepository) GetLatestRunIDs(ctx context.Context, req *conversation.GetLatestRunIDsRequest) ([]int64, error) {
return []int64{0}, nil
}
func (c *ConversationRepository) GetMessagesByRunIDs(ctx context.Context, req *conversation.GetMessagesByRunIDsRequest) (*conversation.GetMessagesByRunIDsResponse, error) {
messages, err := crossmessage.DefaultSVC().GetByRunIDs(ctx, req.ConversationID, req.RunIDs)
if err != nil {
return nil, err
}
msgs, err := convertMessage(messages)
if err != nil {
return nil, err
}
return &conversation.GetMessagesByRunIDsResponse{
Messages: msgs,
}, nil
}
func convertMessage(msgs []*msgentity.Message) ([]*conversation.Message, error) {
messages := make([]*conversation.Message, 0, len(msgs))
for _, m := range msgs {
msg := &conversation.Message{
ID: m.ID,
Role: m.Role,
ContentType: string(m.ContentType)}
if m.MultiContent != nil {
var mcs []*conversation.Content
for _, c := range m.MultiContent {
if c.FileData != nil {
for _, fd := range c.FileData {
mcs = append(mcs, &conversation.Content{
Type: c.Type,
Uri: ptr.Of(fd.URI),
})
}
} else {
mcs = append(mcs, &conversation.Content{
Type: c.Type,
Text: ptr.Of(c.Text),
})
}
}
msg.MultiContent = mcs
} else {
msg.Text = ptr.Of(m.Content)
}
messages = append(messages, msg)
}
return messages, nil
}

View File

@ -0,0 +1,396 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"testing"
"github.com/cloudwego/eino/schema"
apimessage "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/stretchr/testify/assert"
)
func Test_convertMessage(t *testing.T) {
type args struct {
lr *entity.ListResult
}
tests := []struct {
name string
args args
want *conversation.MessageListResponse
wantErr bool
}{
{
name: "pure text",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 1,
Role: schema.User,
ContentType: "text",
MultiContent: []*apimessage.InputMetaData{
{
Type: "text",
Text: "hello",
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 1,
Role: schema.User,
ContentType: "text",
MultiContent: []*conversation.Content{
{Type: "text", Text: ptr.Of("hello")},
},
},
},
},
},
{
name: "pure file",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 2,
Role: schema.User,
ContentType: "file",
MultiContent: []*apimessage.InputMetaData{
{
Type: "file",
FileData: []*apimessage.FileData{
{
URI: "f_uri_1",
},
},
},
{
Type: "text",
Text: "",
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 2,
Role: schema.User,
ContentType: "file",
MultiContent: []*conversation.Content{
{Type: "file", Uri: ptr.Of("f_uri_1")},
{Type: "text", Text: ptr.Of("")},
},
},
},
},
},
{
name: "text and file",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 3,
Role: schema.User,
ContentType: "text_file",
MultiContent: []*apimessage.InputMetaData{
{
Type: "text",
Text: "hello",
},
{
Type: "file",
FileData: []*apimessage.FileData{
{
URI: "f_uri_2",
},
},
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 3,
Role: schema.User,
ContentType: "text_file",
MultiContent: []*conversation.Content{
{Type: "text", Text: ptr.Of("hello")},
{Type: "file", Uri: ptr.Of("f_uri_2")},
},
},
},
},
},
{
name: "multiple files",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 4,
Role: schema.User,
ContentType: "file",
MultiContent: []*apimessage.InputMetaData{
{
Type: "file",
FileData: []*apimessage.FileData{
{
URI: "f_uri_3",
},
{
URI: "f_uri_4",
},
},
},
{
Type: "text",
Text: "",
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 4,
Role: schema.User,
ContentType: "file",
MultiContent: []*conversation.Content{
{Type: "file", Uri: ptr.Of("f_uri_3")},
{Type: "file", Uri: ptr.Of("f_uri_4")},
{Type: "text", Text: ptr.Of("")},
},
},
},
},
},
{
name: "empty text",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 5,
Role: schema.User,
ContentType: "text",
MultiContent: []*apimessage.InputMetaData{
{
Type: "text",
Text: "",
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 5,
Role: schema.User,
ContentType: "text",
MultiContent: []*conversation.Content{
{Type: "text", Text: ptr.Of("")},
},
},
},
},
},
{
name: "pure image",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 6,
Role: schema.User,
ContentType: "image",
MultiContent: []*apimessage.InputMetaData{
{
Type: "image",
FileData: []*apimessage.FileData{
{
URI: "image_uri_5",
},
},
},
{
Type: "text",
Text: "",
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 6,
Role: schema.User,
ContentType: "image",
MultiContent: []*conversation.Content{
{Type: "image", Uri: ptr.Of("image_uri_5")},
{Type: "text", Text: ptr.Of("")},
},
},
},
},
},
{
name: "multiple images",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 7,
Role: schema.User,
ContentType: "image",
MultiContent: []*apimessage.InputMetaData{
{
Type: "image",
FileData: []*apimessage.FileData{
{
URI: "file_id_6",
},
{
URI: "file_id_7",
},
},
},
{
Type: "text",
Text: "",
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 7,
Role: schema.User,
ContentType: "image",
MultiContent: []*conversation.Content{
{Type: "image", Uri: ptr.Of("file_id_6")},
{Type: "image", Uri: ptr.Of("file_id_7")},
{Type: "text", Text: ptr.Of("")},
},
},
},
},
},
{
name: "mixed content",
args: args{
lr: &entity.ListResult{
Messages: []*entity.Message{
{
ID: 8,
Role: schema.User,
ContentType: "mix",
MultiContent: []*apimessage.InputMetaData{
{
Type: "text",
Text: "hello",
},
{
Type: "image",
FileData: []*apimessage.FileData{
{
URI: "file_id_8",
},
},
},
{
Type: "file",
FileData: []*apimessage.FileData{
{
URI: "file_id_9",
},
},
},
},
},
},
},
},
want: &conversation.MessageListResponse{
Messages: []*conversation.Message{
{
ID: 8,
Role: schema.User,
ContentType: "mix",
MultiContent: []*conversation.Content{
{Type: "text", Text: ptr.Of("hello")},
{Type: "image", Uri: ptr.Of("file_id_8")},
{Type: "file", Uri: ptr.Of("file_id_9")},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs, err := convertMessage(tt.args.lr.Messages)
if (err != nil) != tt.wantErr {
t.Errorf("convertMessage() error = %v, wantErr %v", err, tt.wantErr)
return
}
for i, msg := range msgs {
assert.Equal(t, msg.MultiContent, tt.want.Messages[i].MultiContent)
}
})
}
}