feat: workflow cross domain support conversation & other domain add related called methods
This commit is contained in:
@ -21,17 +21,31 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
)
|
||||
|
||||
// Requests and responses must not reference domain entities and can only use models under api/model/crossdomain.
|
||||
type SingleAgent interface {
|
||||
StreamExecute(ctx context.Context, historyMsg []*message.Message, query *message.Message,
|
||||
agentRuntime *singleagent.AgentRuntime) (*schema.StreamReader[*singleagent.AgentEvent], error)
|
||||
StreamExecute(ctx context.Context,
|
||||
agentRuntime *AgentRuntime) (*schema.StreamReader[*singleagent.AgentEvent], error)
|
||||
ObtainAgentByIdentity(ctx context.Context, identity *singleagent.AgentIdentity) (*singleagent.SingleAgent, error)
|
||||
}
|
||||
|
||||
type AgentRuntime struct {
|
||||
AgentVersion string
|
||||
UserID string
|
||||
AgentID int64
|
||||
IsDraft bool
|
||||
SpaceID int64
|
||||
ConnectorID int64
|
||||
PreRetrieveTools []*agentrun.Tool
|
||||
|
||||
HistoryMsg []*schema.Message
|
||||
Input *schema.Message
|
||||
ResumeInfo *ResumeInfo
|
||||
}
|
||||
|
||||
type ResumeInfo = singleagent.InterruptInfo
|
||||
|
||||
type AgentEvent = singleagent.AgentEvent
|
||||
|
||||
@ -20,10 +20,14 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
)
|
||||
|
||||
type Conversation interface {
|
||||
GetCurrentConversation(ctx context.Context, req *conversation.GetCurrent) (*conversation.Conversation, error)
|
||||
Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error)
|
||||
NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
|
||||
}
|
||||
|
||||
var defaultSVC Conversation
|
||||
|
||||
@ -20,13 +20,16 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
)
|
||||
|
||||
type Message interface {
|
||||
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error)
|
||||
PreCreate(ctx context.Context, msg *message.Message) (*message.Message, error)
|
||||
Create(ctx context.Context, msg *message.Message) (*message.Message, error)
|
||||
List(ctx context.Context, meta *entity.ListMeta) (*entity.ListResult, error)
|
||||
Edit(ctx context.Context, msg *message.Message) (*message.Message, error)
|
||||
Delete(ctx context.Context, req *entity.DeleteMeta) error
|
||||
}
|
||||
|
||||
var defaultSVC Message
|
||||
|
||||
@ -39,18 +39,30 @@ type Workflow interface {
|
||||
ReleaseApplicationWorkflows(ctx context.Context, appID int64, config *ReleaseWorkflowConfig) ([]*vo.ValidateIssue, error)
|
||||
GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error)
|
||||
SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
|
||||
StreamExecute(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*schema.StreamReader[*workflowEntity.Message], error)
|
||||
WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option
|
||||
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message])
|
||||
InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error
|
||||
}
|
||||
|
||||
type ExecuteConfig = vo.ExecuteConfig
|
||||
type WorkflowMessage = workflowEntity.Message
|
||||
type ExecuteMode = vo.ExecuteMode
|
||||
type NodeType = entity.NodeType
|
||||
type MessageType = entity.MessageType
|
||||
type InterruptEvent = workflowEntity.InterruptEvent
|
||||
type EventType = workflowEntity.InterruptEventType
|
||||
|
||||
type WorkflowMessage = entity.Message
|
||||
const (
|
||||
Answer MessageType = "answer"
|
||||
FunctionCall MessageType = "function_call"
|
||||
ToolResponse MessageType = "tool_response"
|
||||
)
|
||||
|
||||
const (
|
||||
NodeTypeOutputEmitter NodeType = "OutputEmitter"
|
||||
NodeTypeInputReceiver NodeType = "InputReceiver"
|
||||
NodeTypeQuestion NodeType = "Question"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -59,6 +71,14 @@ const (
|
||||
ExecuteModeNodeDebug ExecuteMode = "node_debug"
|
||||
)
|
||||
|
||||
type SyncPattern = vo.SyncPattern
|
||||
|
||||
const (
|
||||
SyncPatternSync SyncPattern = "sync"
|
||||
SyncPatternAsync SyncPattern = "async"
|
||||
SyncPatternStream SyncPattern = "stream"
|
||||
)
|
||||
|
||||
type TaskType = vo.TaskType
|
||||
|
||||
const (
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
conversation "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
|
||||
)
|
||||
|
||||
@ -40,3 +41,15 @@ func InitDomainService(c conversation.Conversation) crossconversation.Conversati
|
||||
func (s *impl) GetCurrentConversation(ctx context.Context, req *model.GetCurrent) (*model.Conversation, error) {
|
||||
return s.DomainSVC.GetCurrentConversation(ctx, req)
|
||||
}
|
||||
|
||||
func (s *impl) Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error) {
|
||||
return s.DomainSVC.Create(ctx, req)
|
||||
}
|
||||
|
||||
func (s *impl) NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error) {
|
||||
return s.DomainSVC.NewConversationCtx(ctx, req)
|
||||
}
|
||||
|
||||
func (s *impl) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
|
||||
return s.DomainSVC.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
@ -21,6 +21,8 @@ import (
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
|
||||
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
|
||||
)
|
||||
|
||||
@ -53,3 +55,11 @@ func (c *impl) Edit(ctx context.Context, msg *model.Message) (*model.Message, er
|
||||
func (c *impl) PreCreate(ctx context.Context, msg *model.Message) (*model.Message, error) {
|
||||
return c.DomainSVC.PreCreate(ctx, msg)
|
||||
}
|
||||
|
||||
func (c *impl) List(ctx context.Context, lm *entity.ListMeta) (*entity.ListResult, error) {
|
||||
return c.DomainSVC.List(ctx, lm)
|
||||
}
|
||||
|
||||
func (c *impl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
|
||||
return c.DomainSVC.Delete(ctx, req)
|
||||
}
|
||||
|
||||
@ -18,17 +18,13 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
@ -38,49 +34,34 @@ var defaultSVC crossagent.SingleAgent
|
||||
|
||||
type impl struct {
|
||||
DomainSVC singleagent.SingleAgent
|
||||
ImagexSVC imagex.ImageX
|
||||
}
|
||||
|
||||
func InitDomainService(c singleagent.SingleAgent, imagexClient imagex.ImageX) crossagent.SingleAgent {
|
||||
func InitDomainService(c singleagent.SingleAgent) crossagent.SingleAgent {
|
||||
defaultSVC = &impl{
|
||||
DomainSVC: c,
|
||||
ImagexSVC: imagexClient,
|
||||
}
|
||||
|
||||
return defaultSVC
|
||||
}
|
||||
|
||||
func (c *impl) StreamExecute(ctx context.Context, historyMsg []*message.Message,
|
||||
query *message.Message, agentRuntime *model.AgentRuntime,
|
||||
func (c *impl) StreamExecute(ctx context.Context, agentRuntime *crossagent.AgentRuntime,
|
||||
) (*schema.StreamReader[*model.AgentEvent], error) {
|
||||
|
||||
historyMsg = c.historyPairs(historyMsg)
|
||||
|
||||
singleAgentStreamExecReq := c.buildSingleAgentStreamExecuteReq(ctx, historyMsg, query, agentRuntime)
|
||||
singleAgentStreamExecReq := c.buildSingleAgentStreamExecuteReq(ctx, agentRuntime)
|
||||
|
||||
streamEvent, err := c.DomainSVC.StreamExecute(ctx, singleAgentStreamExecReq)
|
||||
logs.CtxInfof(ctx, "agent StreamExecute req:%v, streamEvent:%v, err:%v", conv.DebugJsonToStr(singleAgentStreamExecReq), streamEvent, err)
|
||||
return streamEvent, err
|
||||
}
|
||||
|
||||
func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, historyMsg []*message.Message,
|
||||
input *message.Message, agentRuntime *model.AgentRuntime,
|
||||
func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, agentRuntime *crossagent.AgentRuntime,
|
||||
) *model.ExecuteRequest {
|
||||
identity := c.buildIdentity(input, agentRuntime)
|
||||
inputBuild := c.buildSchemaMessage(ctx, []*message.Message{input})
|
||||
var inputSM *schema.Message
|
||||
if len(inputBuild) > 0 {
|
||||
inputSM = inputBuild[0]
|
||||
}
|
||||
history := c.buildSchemaMessage(ctx, historyMsg)
|
||||
|
||||
resumeInfo := c.checkResumeInfo(ctx, historyMsg)
|
||||
|
||||
return &model.ExecuteRequest{
|
||||
Identity: identity,
|
||||
Input: inputSM,
|
||||
History: history,
|
||||
UserID: input.UserID,
|
||||
Identity: c.buildIdentity(agentRuntime),
|
||||
Input: agentRuntime.Input,
|
||||
History: agentRuntime.HistoryMsg,
|
||||
UserID: agentRuntime.UserID,
|
||||
PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever {
|
||||
return &agentrun.ToolsRetriever{
|
||||
PluginID: tool.PluginID,
|
||||
@ -98,141 +79,19 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, historyMsg
|
||||
}(tool.Type),
|
||||
}
|
||||
}),
|
||||
|
||||
ResumeInfo: resumeInfo,
|
||||
ResumeInfo: agentRuntime.ResumeInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *impl) historyPairs(historyMsg []*message.Message) []*message.Message {
|
||||
|
||||
fcMsgPairs := make(map[int64][]*message.Message)
|
||||
for _, one := range historyMsg {
|
||||
if one.MessageType != message.MessageTypeFunctionCall && one.MessageType != message.MessageTypeToolResponse {
|
||||
continue
|
||||
}
|
||||
if _, ok := fcMsgPairs[one.RunID]; !ok {
|
||||
fcMsgPairs[one.RunID] = []*message.Message{one}
|
||||
} else {
|
||||
fcMsgPairs[one.RunID] = append(fcMsgPairs[one.RunID], one)
|
||||
}
|
||||
}
|
||||
|
||||
var historyAfterPairs []*message.Message
|
||||
for _, value := range historyMsg {
|
||||
if value.MessageType == message.MessageTypeFunctionCall {
|
||||
if len(fcMsgPairs[value.RunID])%2 == 0 {
|
||||
historyAfterPairs = append(historyAfterPairs, value)
|
||||
}
|
||||
} else {
|
||||
historyAfterPairs = append(historyAfterPairs, value)
|
||||
}
|
||||
}
|
||||
return historyAfterPairs
|
||||
|
||||
}
|
||||
func (c *impl) checkResumeInfo(_ context.Context, historyMsg []*message.Message) *crossagent.ResumeInfo {
|
||||
|
||||
var resumeInfo *crossagent.ResumeInfo
|
||||
for i := len(historyMsg) - 1; i >= 0; i-- {
|
||||
if historyMsg[i].MessageType == message.MessageTypeQuestion {
|
||||
break
|
||||
}
|
||||
if historyMsg[i].MessageType == message.MessageTypeVerbose {
|
||||
if historyMsg[i].Ext[string(entity.ExtKeyResumeInfo)] != "" {
|
||||
err := json.Unmarshal([]byte(historyMsg[i].Ext[string(entity.ExtKeyResumeInfo)]), &resumeInfo)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resumeInfo
|
||||
}
|
||||
|
||||
func (c *impl) buildSchemaMessage(ctx context.Context, msgs []*message.Message) []*schema.Message {
|
||||
schemaMessage := make([]*schema.Message, 0, len(msgs))
|
||||
|
||||
for _, msgOne := range msgs {
|
||||
if msgOne.ModelContent == "" {
|
||||
continue
|
||||
}
|
||||
if msgOne.MessageType == message.MessageTypeVerbose || msgOne.MessageType == message.MessageTypeFlowUp {
|
||||
continue
|
||||
}
|
||||
var sm *schema.Message
|
||||
err := json.Unmarshal([]byte(msgOne.ModelContent), &sm)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(sm.ReasoningContent) > 0 {
|
||||
sm.ReasoningContent = ""
|
||||
}
|
||||
|
||||
schemaMessage = append(schemaMessage, c.parseMessageURI(ctx, sm))
|
||||
}
|
||||
|
||||
return schemaMessage
|
||||
}
|
||||
|
||||
func (c *impl) parseMessageURI(ctx context.Context, mcMsg *schema.Message) *schema.Message {
|
||||
if mcMsg.MultiContent == nil {
|
||||
return mcMsg
|
||||
}
|
||||
for k, one := range mcMsg.MultiContent {
|
||||
switch one.Type {
|
||||
case schema.ChatMessagePartTypeImageURL:
|
||||
|
||||
if one.ImageURL.URI != "" {
|
||||
url, err := c.ImagexSVC.GetResourceURL(ctx, one.ImageURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].ImageURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
case schema.ChatMessagePartTypeFileURL:
|
||||
if one.FileURL.URI != "" {
|
||||
url, err := c.ImagexSVC.GetResourceURL(ctx, one.FileURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].FileURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
case schema.ChatMessagePartTypeAudioURL:
|
||||
if one.AudioURL.URI != "" {
|
||||
url, err := c.ImagexSVC.GetResourceURL(ctx, one.AudioURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].AudioURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
case schema.ChatMessagePartTypeVideoURL:
|
||||
if one.VideoURL.URI != "" {
|
||||
url, err := c.ImagexSVC.GetResourceURL(ctx, one.VideoURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].VideoURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return mcMsg
|
||||
}
|
||||
|
||||
func (c *impl) buildIdentity(input *message.Message, agentRuntime *model.AgentRuntime) *model.AgentIdentity {
|
||||
func (c *impl) buildIdentity(agentRuntime *crossagent.AgentRuntime) *model.AgentIdentity {
|
||||
return &model.AgentIdentity{
|
||||
AgentID: input.AgentID,
|
||||
AgentID: agentRuntime.AgentID,
|
||||
Version: agentRuntime.AgentVersion,
|
||||
IsDraft: agentRuntime.IsDraft,
|
||||
ConnectorID: agentRuntime.ConnectorID,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *impl) GetSingleAgent(ctx context.Context, agentID int64, version string) (agent *model.SingleAgent, err error) {
|
||||
agentInfo, err := c.DomainSVC.GetSingleAgent(ctx, agentID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return agentInfo.SingleAgent, nil
|
||||
}
|
||||
|
||||
func (c *impl) ObtainAgentByIdentity(ctx context.Context, identity *model.AgentIdentity) (*model.SingleAgent, error) {
|
||||
agentInfo, err := c.DomainSVC.ObtainAgentByIdentity(ctx, identity)
|
||||
if err != nil {
|
||||
|
||||
@ -70,11 +70,17 @@ func (i *impl) WithResumeToolWorkflow(resumingEvent *workflowEntity.ToolInterrup
|
||||
func (i *impl) SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) {
|
||||
return i.DomainSVC.SyncExecute(ctx, config, input)
|
||||
}
|
||||
|
||||
func (i *impl) StreamExecute(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*schema.StreamReader[*workflowEntity.Message], error) {
|
||||
return i.DomainSVC.StreamExecute(ctx, config, input)
|
||||
}
|
||||
func (i *impl) WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option {
|
||||
return i.DomainSVC.WithExecuteConfig(cfg)
|
||||
}
|
||||
|
||||
func (i *impl) InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error {
|
||||
return i.DomainSVC.InitApplicationDefaultConversationTemplate(ctx, spaceID, appID, userID)
|
||||
}
|
||||
|
||||
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
|
||||
return i.DomainSVC.WithMessagePipe()
|
||||
}
|
||||
|
||||
202
backend/crossdomain/workflow/conversation/conversation.go
Normal file
202
backend/crossdomain/workflow/conversation/conversation.go
Normal file
@ -0,0 +1,202 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
msgentity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type ConversationRepository struct {
|
||||
}
|
||||
|
||||
func NewConversationRepository() *ConversationRepository {
|
||||
return &ConversationRepository{}
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) CreateConversation(ctx context.Context, req *conversation.CreateConversationRequest) (int64, error) {
|
||||
ret, err := crossconversation.DefaultSVC().Create(ctx, &entity.CreateMeta{
|
||||
AgentID: req.AppID,
|
||||
UserID: req.UserID,
|
||||
ConnectorID: req.ConnectorID,
|
||||
Scene: common.Scene_SceneWorkflow,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return ret.ID, nil
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) CreateMessage(ctx context.Context, req *conversation.CreateMessageRequest) (int64, error) {
|
||||
msg := &message.Message{
|
||||
ConversationID: req.ConversationID,
|
||||
Role: schema.RoleType(req.Role),
|
||||
Content: req.Content,
|
||||
ContentType: message.ContentType(req.ContentType),
|
||||
UserID: strconv.FormatInt(req.UserID, 10),
|
||||
AgentID: req.AppID,
|
||||
RunID: req.RunID,
|
||||
}
|
||||
if msg.Role == "user" {
|
||||
msg.MessageType = message.MessageTypeQuestion
|
||||
} else {
|
||||
msg.MessageType = message.MessageTypeAnswer
|
||||
}
|
||||
ret, err := crossmessage.DefaultSVC().Create(ctx, msg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return ret.ID, nil
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) MessageList(ctx context.Context, req *conversation.MessageListRequest) (*conversation.MessageListResponse, error) {
|
||||
lm := &msgentity.ListMeta{
|
||||
ConversationID: req.ConversationID,
|
||||
Limit: int(req.Limit), // Since the value of limit is checked inside the node, the type cast here is safe
|
||||
UserID: strconv.FormatInt(req.UserID, 10),
|
||||
AgentID: req.AppID,
|
||||
OrderBy: req.OrderBy,
|
||||
}
|
||||
if req.BeforeID != nil {
|
||||
lm.Cursor, _ = strconv.ParseInt(*req.BeforeID, 10, 64)
|
||||
lm.Direction = msgentity.ScrollPageDirectionPrev
|
||||
}
|
||||
if req.AfterID != nil {
|
||||
lm.Cursor, _ = strconv.ParseInt(*req.AfterID, 10, 64)
|
||||
lm.Direction = msgentity.ScrollPageDirectionNext
|
||||
}
|
||||
lm.Direction = msgentity.ScrollPageDirectionNext
|
||||
lr, err := crossmessage.DefaultSVC().List(ctx, lm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := &conversation.MessageListResponse{}
|
||||
|
||||
if lr.PrevCursor > 0 {
|
||||
response.FirstID = strconv.FormatInt(lr.PrevCursor, 10)
|
||||
}
|
||||
if lr.NextCursor > 0 {
|
||||
response.LastID = strconv.FormatInt(lr.NextCursor, 10)
|
||||
}
|
||||
if len(lr.Messages) == 0 {
|
||||
return response, nil
|
||||
}
|
||||
messages, err := convertMessage(lr.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response.Messages = messages
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) ClearConversationHistory(ctx context.Context, req *conversation.ClearConversationHistoryReq) error {
|
||||
_, err := crossconversation.DefaultSVC().NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
|
||||
ID: req.ConversationID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) DeleteMessage(ctx context.Context, req *conversation.DeleteMessageRequest) error {
|
||||
return crossmessage.DefaultSVC().Delete(ctx, &msgentity.DeleteMeta{
|
||||
MessageIDs: []int64{req.MessageID},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) EditMessage(ctx context.Context, req *conversation.EditMessageRequest) error {
|
||||
_, err := crossmessage.DefaultSVC().Edit(ctx, &msgentity.Message{
|
||||
ID: req.MessageID,
|
||||
ConversationID: req.ConversationID,
|
||||
Content: req.Content,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) GetLatestRunIDs(ctx context.Context, req *conversation.GetLatestRunIDsRequest) ([]int64, error) {
|
||||
return []int64{0}, nil
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) GetMessagesByRunIDs(ctx context.Context, req *conversation.GetMessagesByRunIDsRequest) (*conversation.GetMessagesByRunIDsResponse, error) {
|
||||
|
||||
messages, err := crossmessage.DefaultSVC().GetByRunIDs(ctx, req.ConversationID, req.RunIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgs, err := convertMessage(messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conversation.GetMessagesByRunIDsResponse{
|
||||
Messages: msgs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertMessage(msgs []*msgentity.Message) ([]*conversation.Message, error) {
|
||||
messages := make([]*conversation.Message, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
msg := &conversation.Message{
|
||||
ID: m.ID,
|
||||
Role: m.Role,
|
||||
ContentType: string(m.ContentType)}
|
||||
|
||||
if m.MultiContent != nil {
|
||||
var mcs []*conversation.Content
|
||||
for _, c := range m.MultiContent {
|
||||
if c.FileData != nil {
|
||||
for _, fd := range c.FileData {
|
||||
mcs = append(mcs, &conversation.Content{
|
||||
Type: c.Type,
|
||||
Uri: ptr.Of(fd.URI),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
mcs = append(mcs, &conversation.Content{
|
||||
Type: c.Type,
|
||||
Text: ptr.Of(c.Text),
|
||||
})
|
||||
}
|
||||
}
|
||||
msg.MultiContent = mcs
|
||||
} else {
|
||||
msg.Text = ptr.Of(m.Content)
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
396
backend/crossdomain/workflow/conversation/conversation_test.go
Normal file
396
backend/crossdomain/workflow/conversation/conversation_test.go
Normal file
@ -0,0 +1,396 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
apimessage "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_convertMessage(t *testing.T) {
|
||||
type args struct {
|
||||
lr *entity.ListResult
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *conversation.MessageListResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "pure text",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 1,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "text",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "hello",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 1,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "text",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "text", Text: ptr.Of("hello")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pure file",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 2,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "file",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "file",
|
||||
FileData: []*apimessage.FileData{
|
||||
{
|
||||
URI: "f_uri_1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 2,
|
||||
Role: schema.User,
|
||||
ContentType: "file",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "file", Uri: ptr.Of("f_uri_1")},
|
||||
{Type: "text", Text: ptr.Of("")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text and file",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 3,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "text_file",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "hello",
|
||||
},
|
||||
{
|
||||
Type: "file",
|
||||
FileData: []*apimessage.FileData{
|
||||
{
|
||||
URI: "f_uri_2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 3,
|
||||
Role: schema.User,
|
||||
ContentType: "text_file",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "text", Text: ptr.Of("hello")},
|
||||
{Type: "file", Uri: ptr.Of("f_uri_2")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple files",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 4,
|
||||
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "file",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "file",
|
||||
FileData: []*apimessage.FileData{
|
||||
{
|
||||
URI: "f_uri_3",
|
||||
},
|
||||
{
|
||||
URI: "f_uri_4",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 4,
|
||||
Role: schema.User,
|
||||
ContentType: "file",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "file", Uri: ptr.Of("f_uri_3")},
|
||||
{Type: "file", Uri: ptr.Of("f_uri_4")},
|
||||
{Type: "text", Text: ptr.Of("")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty text",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 5,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "text",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 5,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "text",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "text", Text: ptr.Of("")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pure image",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 6,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "image",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "image",
|
||||
FileData: []*apimessage.FileData{
|
||||
{
|
||||
URI: "image_uri_5",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 6,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "image",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "image", Uri: ptr.Of("image_uri_5")},
|
||||
{Type: "text", Text: ptr.Of("")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple images",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 7,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "image",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "image",
|
||||
FileData: []*apimessage.FileData{
|
||||
{
|
||||
URI: "file_id_6",
|
||||
},
|
||||
{
|
||||
URI: "file_id_7",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 7,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "image",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "image", Uri: ptr.Of("file_id_6")},
|
||||
{Type: "image", Uri: ptr.Of("file_id_7")},
|
||||
{Type: "text", Text: ptr.Of("")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed content",
|
||||
args: args{
|
||||
lr: &entity.ListResult{
|
||||
Messages: []*entity.Message{
|
||||
{
|
||||
ID: 8,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "mix",
|
||||
MultiContent: []*apimessage.InputMetaData{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "hello",
|
||||
},
|
||||
{
|
||||
Type: "image",
|
||||
FileData: []*apimessage.FileData{
|
||||
{
|
||||
URI: "file_id_8",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "file",
|
||||
FileData: []*apimessage.FileData{
|
||||
{
|
||||
URI: "file_id_9",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &conversation.MessageListResponse{
|
||||
Messages: []*conversation.Message{
|
||||
{
|
||||
ID: 8,
|
||||
Role: schema.User,
|
||||
|
||||
ContentType: "mix",
|
||||
MultiContent: []*conversation.Content{
|
||||
{Type: "text", Text: ptr.Of("hello")},
|
||||
{Type: "image", Uri: ptr.Of("file_id_8")},
|
||||
{Type: "file", Uri: ptr.Of("file_id_9")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs, err := convertMessage(tt.args.lr.Messages)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("convertMessage() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
for i, msg := range msgs {
|
||||
assert.Equal(t, msg.MultiContent, tt.want.Messages[i].MultiContent)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user