refactor(workflow): Calculate chat history rounds during schema convertion (#1990)
Co-authored-by: zhuangjie.1125 <zhuangjie.1125@bytedance.com>
This commit is contained in:
@ -228,6 +228,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
h.POST("/api/workflow_api/chat_flow_role/delete", DeleteChatFlowRole)
|
||||
h.POST("/api/workflow_api/chat_flow_role/create", CreateChatFlowRole)
|
||||
h.GET("/api/workflow_api/chat_flow_role/get", GetChatFlowRole)
|
||||
h.POST("/v1/workflows/chat", OpenAPIChatFlowRun)
|
||||
|
||||
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
|
||||
mockIDGen := mock.NewMockIDGenerator(ctrl)
|
||||
@ -1082,6 +1083,46 @@ func (r *wfTestRunner) openapiResume(id string, eventID string, resumeData strin
|
||||
return re
|
||||
}
|
||||
|
||||
func (r *wfTestRunner) openapiChatFlowRun(wfID string, cID, appID, botID *string, input any, additionalMessage []*workflow.EnterMessage) *sse.Reader {
|
||||
inputStr, _ := sonic.MarshalString(input)
|
||||
|
||||
req := &workflow.ChatFlowRunRequest{
|
||||
WorkflowID: wfID,
|
||||
Parameters: ptr.Of(inputStr),
|
||||
AdditionalMessages: additionalMessage,
|
||||
}
|
||||
if cID != nil {
|
||||
req.ConversationID = cID
|
||||
}
|
||||
if appID != nil {
|
||||
req.AppID = appID
|
||||
}
|
||||
if botID != nil {
|
||||
req.BotID = botID
|
||||
}
|
||||
|
||||
m, err := sonic.Marshal(req)
|
||||
assert.NoError(r.t, err)
|
||||
|
||||
c, _ := client.NewClient()
|
||||
hReq, hResp := protocol.AcquireRequest(), protocol.AcquireResponse()
|
||||
hReq.SetRequestURI("http://localhost:8888" + "/v1/workflows/chat")
|
||||
hReq.SetMethod("POST")
|
||||
hReq.SetBody(m)
|
||||
hReq.SetHeader("Content-Type", "application/json")
|
||||
err = c.Do(context.Background(), hReq, hResp)
|
||||
assert.NoError(r.t, err)
|
||||
|
||||
if hResp.StatusCode() != http.StatusOK {
|
||||
r.t.Errorf("unexpected status code: %d, body: %s", hResp.StatusCode(), string(hResp.Body()))
|
||||
}
|
||||
|
||||
re, err := sse.NewReader(hResp)
|
||||
assert.NoError(r.t, err)
|
||||
|
||||
return re
|
||||
}
|
||||
|
||||
func (r *wfTestRunner) runServer() func() {
|
||||
go func() {
|
||||
_ = r.h.Run()
|
||||
@ -5491,7 +5532,7 @@ func TestConversationOfChatFlow(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
v.DefaultValue = cName
|
||||
}
|
||||
startNode.Data.Outputs[idx] = v
|
||||
@ -5522,7 +5563,7 @@ func TestConversationOfChatFlow(t *testing.T) {
|
||||
for _, vAny := range node.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
assert.NoError(t, err)
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
assert.Equal(t, v.DefaultValue, updateName)
|
||||
}
|
||||
}
|
||||
@ -5569,7 +5610,7 @@ func TestConversationOfChatFlow(t *testing.T) {
|
||||
for _, vAny := range node.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
assert.NoError(t, err)
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
assert.Equal(t, v.DefaultValue, cName+"copy")
|
||||
}
|
||||
}
|
||||
@ -5988,3 +6029,223 @@ func TestConversationHistoryNodes(t *testing.T) {
|
||||
assert.Equal(t, []any{}, outputMap["history_list"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatFlowRun(t *testing.T) {
|
||||
mockey.PatchConvey("chat flow run", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
appworkflow.SVC.IDGenerator = r.idGen
|
||||
defer r.closeFn()
|
||||
defer r.runServer()()
|
||||
|
||||
chatModel1 := &testutil.UTChatModel{
|
||||
StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
sr := schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "I ",
|
||||
},
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "don't know.",
|
||||
},
|
||||
})
|
||||
return sr, nil
|
||||
},
|
||||
}
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel1, nil, nil).AnyTimes()
|
||||
|
||||
id := r.load("chatflow/llm_chat.json", withMode(workflow.WorkflowMode_ChatFlow))
|
||||
r.publish(id, "v0.0.1", true)
|
||||
cID := time.Now().UnixNano()
|
||||
cIDStr := strconv.FormatInt(cID, 10)
|
||||
appID := time.Now().UnixNano()
|
||||
appIDStr := strconv.FormatInt(appID, 10)
|
||||
|
||||
// Create conversation first
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
idStr := r.load("conversation_manager/update_dynamic_conversation.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
ret, _ := r.openapiSyncRun(idStr, map[string]string{
|
||||
"input": "v1",
|
||||
"new_name": "v2",
|
||||
}, withRunProjectID(appID))
|
||||
assert.Equal(t, map[string]any{"conversationId": strconv.FormatInt(cID, 10), "isExisted": false, "isSuccess": true}, ret["obj"])
|
||||
|
||||
msg := []*workflow.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Content: "你好",
|
||||
},
|
||||
}
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
|
||||
t.Run("chat flow run in app", func(t *testing.T) {
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run in bot", func(t *testing.T) {
|
||||
botID := time.Now().UnixNano()
|
||||
botIDStr := strconv.FormatInt(botID, 10)
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), nil, ptr.Of(botIDStr), map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run without cID", func(t *testing.T) {
|
||||
sseReader := r.openapiChatFlowRun(id, nil, ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run with additional messages", func(t *testing.T) {
|
||||
additionalMsg := []*workflow.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Content: "你好, 我叫小明",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ContentType: "text",
|
||||
Content: "你好小明, 很高兴认识你",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Content: "你好",
|
||||
},
|
||||
}
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, additionalMsg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run with history messages", func(t *testing.T) {
|
||||
id := r.load("chatflow/llm_chat_with_history.json", withMode(workflow.WorkflowMode_ChatFlow))
|
||||
r.publish(id, "v0.0.1", true)
|
||||
r.message.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{rID}, nil).AnyTimes()
|
||||
r.message.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&message0.GetMessagesByRunIDsResponse{
|
||||
Messages: []*message0.WfMessage{
|
||||
{
|
||||
ID: mID,
|
||||
Role: schema.User,
|
||||
Text: ptr.Of("你好"),
|
||||
},
|
||||
},
|
||||
}, nil).AnyTimes()
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run with interrupt nodes ", func(t *testing.T) {
|
||||
// 生成一个携带 input, 问答文本 问答选项的三个中断节点 做测试
|
||||
id := r.load("chatflow/chat_run_with_interrupt.json", withMode(workflow.WorkflowMode_ChatFlow))
|
||||
r.publish(id, "v0.0.1", true)
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
if e.ID == "3" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "7383997384420262000")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, []*workflow.EnterMessage{
|
||||
{Role: string(schema.User), Content: "input:1", ContentType: "text"},
|
||||
})
|
||||
|
||||
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
if e.ID == "4" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "你好")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, []*workflow.EnterMessage{
|
||||
{Role: string(schema.User), Content: "hello", ContentType: "text"},
|
||||
})
|
||||
|
||||
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
if e.ID == "3" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "question_card_data", "请选择")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, []*workflow.EnterMessage{
|
||||
{Role: string(schema.User), Content: "A", ContentType: "text"},
|
||||
})
|
||||
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
|
||||
if e.ID == "4" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "answer", "A")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@ -221,7 +221,7 @@ const (
|
||||
"id": "5fJt3qKpSz",
|
||||
"name": "list",
|
||||
"defaultValue": [
|
||||
|
||||
|
||||
]
|
||||
}
|
||||
},
|
||||
@ -504,7 +504,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
workflowID = mustParseInt64(req.GetWorkflowID())
|
||||
isDebug = req.GetExecuteMode() == "DEBUG"
|
||||
appID, agentID *int64
|
||||
resolveAppID int64
|
||||
bizID int64
|
||||
conversationID int64
|
||||
sectionID int64
|
||||
version string
|
||||
@ -521,11 +521,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
|
||||
if req.IsSetAppID() {
|
||||
appID = ptr.Of(mustParseInt64(req.GetAppID()))
|
||||
resolveAppID = mustParseInt64(req.GetAppID())
|
||||
bizID = mustParseInt64(req.GetAppID())
|
||||
}
|
||||
if req.IsSetBotID() {
|
||||
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
|
||||
resolveAppID = mustParseInt64(req.GetBotID())
|
||||
bizID = mustParseInt64(req.GetBotID())
|
||||
}
|
||||
|
||||
if appID != nil && agentID != nil {
|
||||
@ -564,16 +564,16 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
sectionID = cInfo.SectionID
|
||||
|
||||
// only trust the conversation name under the app
|
||||
conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, conversationID)
|
||||
conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !existed {
|
||||
return nil, fmt.Errorf("conversation not found")
|
||||
}
|
||||
parameters["CONVERSATION_NAME"] = conversationName
|
||||
parameters[vo.ConversationNameKey] = conversationName
|
||||
} else if req.IsSetConversationID() && req.IsSetBotID() {
|
||||
parameters["CONVERSATION_NAME"] = "Default"
|
||||
parameters[vo.ConversationNameKey] = "Default"
|
||||
conversationID = mustParseInt64(req.GetConversationID())
|
||||
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, conversationID)
|
||||
if err != nil {
|
||||
@ -581,11 +581,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
}
|
||||
sectionID = cInfo.SectionID
|
||||
} else {
|
||||
conversationName, ok := parameters["CONVERSATION_NAME"].(string)
|
||||
conversationName, ok := parameters[vo.ConversationNameKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("conversation name is requried")
|
||||
}
|
||||
cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName)
|
||||
cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, userID, conversationName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -594,7 +594,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
}
|
||||
|
||||
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: resolveAppID,
|
||||
AgentID: bizID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -606,7 +606,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
|
||||
roundID := runRecord.ID
|
||||
|
||||
userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
|
||||
userMessage, err := toConversationMessage(ctx, bizID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -648,7 +648,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
return nil, err
|
||||
}
|
||||
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
|
||||
appID: resolveAppID,
|
||||
bizID: bizID,
|
||||
conversationID: conversationID,
|
||||
roundID: roundID,
|
||||
workflowID: workflowID,
|
||||
@ -684,7 +684,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
Cancellable: isDebug,
|
||||
}
|
||||
|
||||
historyMessages, err := makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
|
||||
historyMessages, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -706,7 +706,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
logs.CtxWarnf(ctx, "create history message failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
parameters["USER_INPUT"], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
|
||||
parameters[vo.UserInputKey], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -717,7 +717,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
}
|
||||
|
||||
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
|
||||
appID: resolveAppID,
|
||||
bizID: bizID,
|
||||
conversationID: conversationID,
|
||||
roundID: roundID,
|
||||
workflowID: workflowID,
|
||||
@ -731,7 +731,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
|
||||
func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Context, info convertToChatFlowInfo) func(msg *entity.Message) (responses []*workflow.ChatFlowRunResponse, err error) {
|
||||
var (
|
||||
appID = info.appID
|
||||
bizID = info.bizID
|
||||
conversationID = info.conversationID
|
||||
roundID = info.roundID
|
||||
workflowID = info.workflowID
|
||||
@ -798,7 +798,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
Role: string(schema.Assistant),
|
||||
Type: "follow_up",
|
||||
ContentType: "text",
|
||||
@ -815,7 +815,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
Status: vo.Completed,
|
||||
ExecuteID: strconv.FormatInt(executeID, 10),
|
||||
Usage: &vo.Usage{
|
||||
@ -929,7 +929,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
}
|
||||
|
||||
_, err = crossmessage.DefaultSVC().Create(ctx, &message.Message{
|
||||
AgentID: appID,
|
||||
AgentID: bizID,
|
||||
RunID: roundID,
|
||||
SectionID: sectionID,
|
||||
Content: msgContent,
|
||||
@ -947,7 +947,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
Role: string(schema.Assistant),
|
||||
Type: string(entity.Answer),
|
||||
ContentType: string(contentType),
|
||||
@ -1046,7 +1046,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
}
|
||||
intermediateMessage = &message.Message{
|
||||
ID: id,
|
||||
AgentID: appID,
|
||||
AgentID: bizID,
|
||||
RunID: roundID,
|
||||
SectionID: sectionID,
|
||||
ConversationID: conversationID,
|
||||
@ -1066,7 +1066,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
Role: string(dataMessage.Role),
|
||||
Type: string(dataMessage.Type),
|
||||
ContentType: string(message.ContentTypeText),
|
||||
@ -1092,7 +1092,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
Role: string(dataMessage.Role),
|
||||
Type: string(dataMessage.Type),
|
||||
ContentType: string(message.ContentTypeText),
|
||||
@ -1155,9 +1155,9 @@ func (w *ApplicationService) makeChatFlowUserInput(ctx context.Context, message
|
||||
} else {
|
||||
return "", fmt.Errorf("invalid message ccontent type %v", message.ContentType)
|
||||
}
|
||||
|
||||
}
|
||||
func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
|
||||
|
||||
func makeChatFlowHistoryMessages(ctx context.Context, bizID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
|
||||
|
||||
var (
|
||||
rID int64
|
||||
@ -1170,7 +1170,7 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use
|
||||
for _, msg := range messages {
|
||||
if msg.Role == userRole {
|
||||
runRecord, err = crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: appID,
|
||||
AgentID: bizID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -1180,13 +1180,15 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use
|
||||
return nil, err
|
||||
}
|
||||
rID = runRecord.ID
|
||||
} else if msg.Role == assistantRole && rID == 0 {
|
||||
continue
|
||||
} else if msg.Role == assistantRole {
|
||||
if rID == 0 {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid role type %v", msg.Role)
|
||||
}
|
||||
|
||||
m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
|
||||
m, err := toConversationMessage(ctx, bizID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1274,7 +1276,7 @@ func (w *ApplicationService) OpenAPICreateConversation(ctx context.Context, req
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
|
||||
func toConversationMessage(ctx context.Context, bizID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
|
||||
type content struct {
|
||||
Type string `json:"type"`
|
||||
FileID *string `json:"file_id"`
|
||||
@ -1284,7 +1286,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec
|
||||
return &message.Message{
|
||||
Role: schema.User,
|
||||
ConversationID: cid,
|
||||
AgentID: appID,
|
||||
AgentID: bizID,
|
||||
RunID: roundID,
|
||||
Content: msg.Content,
|
||||
ContentType: message.ContentTypeText,
|
||||
@ -1304,7 +1306,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec
|
||||
Role: schema.User,
|
||||
MessageType: messageType,
|
||||
ConversationID: cid,
|
||||
AgentID: appID,
|
||||
AgentID: bizID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
RunID: roundID,
|
||||
ContentType: message.ContentTypeMix,
|
||||
@ -1432,7 +1434,7 @@ func toSchemaMessage(ctx context.Context, msg *workflow.EnterMessage) (*schema.M
|
||||
|
||||
type convertToChatFlowInfo struct {
|
||||
userMessage *schema.Message
|
||||
appID int64
|
||||
bizID int64
|
||||
conversationID int64
|
||||
roundID int64
|
||||
workflowID int64
|
||||
|
||||
604
backend/application/workflow/chatflow_test.go
Normal file
604
backend/application/workflow/chatflow_test.go
Normal file
@ -0,0 +1,604 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
messageentity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun/agentrunmock"
|
||||
crossupload "github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload/uploadmock"
|
||||
agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
uploadentity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/upload/service"
|
||||
)
|
||||
|
||||
func TestApplicationService_makeChatFlowUserInput(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message *workflow.EnterMessage
|
||||
setupMock func()
|
||||
expected string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "content type text",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "text",
|
||||
Content: "hello",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: "hello",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "hello world"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: "hello world",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with file",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/file"},
|
||||
}, nil)
|
||||
},
|
||||
expected: "https://example.com/file",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text and file",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "see this file"}, {"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/file"},
|
||||
}, nil)
|
||||
},
|
||||
expected: "see this file,https://example.com/file",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "get file error",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: nil,
|
||||
}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid content type",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "invalid",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `invalid-json`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
w := &ApplicationService{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := w.makeChatFlowUserInput(ctx, tt.message)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toConversationMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
bizID, cid, userID, roundID, sectionID := int64(2), int64(1), int64(4), int64(3), int64(5)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *workflow.EnterMessage
|
||||
messageType messageentity.MessageType
|
||||
setupMock func()
|
||||
expected *messageentity.Message
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "content type text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "text",
|
||||
Content: "hello",
|
||||
},
|
||||
messageType: messageentity.MessageTypeQuestion,
|
||||
setupMock: func() {},
|
||||
expected: &messageentity.Message{
|
||||
Role: schema.User,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
RunID: roundID,
|
||||
Content: "hello",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "hello"}]`,
|
||||
},
|
||||
messageType: messageentity.MessageTypeQuestion,
|
||||
setupMock: func() {},
|
||||
expected: &messageentity.Message{
|
||||
Role: schema.User,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
RunID: roundID,
|
||||
ContentType: messageentity.ContentTypeMix,
|
||||
MultiContent: []*messageentity.InputMetaData{
|
||||
{Type: messageentity.InputTypeText, Text: "hello"},
|
||||
},
|
||||
SectionID: sectionID,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with file",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
messageType: messageentity.MessageTypeQuestion,
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/file", TosURI: "tos://uri", Name: "file.txt"},
|
||||
}, nil)
|
||||
},
|
||||
expected: &messageentity.Message{
|
||||
Role: schema.User,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
RunID: roundID,
|
||||
ContentType: messageentity.ContentTypeMix,
|
||||
MultiContent: []*messageentity.InputMetaData{
|
||||
{
|
||||
Type: "file",
|
||||
FileData: []*messageentity.FileData{
|
||||
{Url: "https://example.com/file", URI: "tos://uri", Name: "file.txt"},
|
||||
},
|
||||
},
|
||||
},
|
||||
SectionID: sectionID,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "get file error",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid content type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "invalid",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: "invalid-json",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid input type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "invalid"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := toConversationMessage(ctx, bizID, cid, userID, roundID, sectionID, tt.messageType, tt.msg)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toSchemaMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *workflow.EnterMessage
|
||||
setupMock func()
|
||||
expected *schema.Message
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "content type text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "text",
|
||||
Content: "hello",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
Content: "hello",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "hello"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{Type: schema.ChatMessagePartTypeText, Text: "hello"},
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with image",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "image", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/image.png"},
|
||||
}, nil)
|
||||
},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{URL: "https://example.com/image.png"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with various file types",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "1"}, {"type": "audio", "file_id": "2"}, {"type": "video", "file_id": "3"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 1}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/file"}}, nil)
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 2}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/audio"}}, nil)
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 3}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/video"}}, nil)
|
||||
},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{URL: "https://example.com/file"}},
|
||||
{Type: schema.ChatMessagePartTypeAudioURL, AudioURL: &schema.ChatMessageAudioURL{URL: "https://example.com/audio"}},
|
||||
{Type: schema.ChatMessagePartTypeVideoURL, VideoURL: &schema.ChatMessageVideoURL{URL: "https://example.com/video"}},
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "get file error",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid content type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "invalid",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: "invalid-json",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid input type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "invalid"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := toSchemaMessage(ctx, tt.msg)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_makeChatFlowHistoryMessages(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockAgentRun := agentrunmock.NewMockAgentRun(ctrl)
|
||||
crossagentrun.SetDefaultSVC(mockAgentRun)
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
bizID, conversationID, userID, sectionID, connectorID := int64(2), int64(1), int64(3), int64(4), int64(5)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []*workflow.EnterMessage
|
||||
setupMock func()
|
||||
expected []*messageentity.Message
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty messages",
|
||||
messages: []*workflow.EnterMessage{},
|
||||
setupMock: func() {},
|
||||
expected: []*messageentity.Message{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "one user message",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "text", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
|
||||
},
|
||||
expected: []*messageentity.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
ConversationID: conversationID,
|
||||
AgentID: bizID,
|
||||
RunID: 100,
|
||||
Content: "hello",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "user and assistant message",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "text", Content: "hello"},
|
||||
{Role: "assistant", ContentType: "text", Content: "hi"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
|
||||
},
|
||||
expected: []*messageentity.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
ConversationID: conversationID,
|
||||
AgentID: bizID,
|
||||
RunID: 100,
|
||||
Content: "hello",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
{
|
||||
Role: schema.User,
|
||||
ConversationID: conversationID,
|
||||
AgentID: bizID,
|
||||
RunID: 100,
|
||||
Content: "hi",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeAnswer,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "only assistant message",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "assistant", ContentType: "text", Content: "hi"},
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: []*messageentity.Message{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "create run record error",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "text", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid role",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "system", ContentType: "text", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "toConversationMessage error",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "invalid", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, tt.messages)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -58,7 +58,7 @@ type MessageListRequest struct {
|
||||
BeforeID *string
|
||||
AfterID *string
|
||||
UserID int64
|
||||
AppID int64
|
||||
BizID int64
|
||||
OrderBy *string
|
||||
}
|
||||
|
||||
@ -88,7 +88,7 @@ type WfMessage struct {
|
||||
type GetLatestRunIDsRequest struct {
|
||||
ConversationID int64
|
||||
UserID int64
|
||||
AppID int64
|
||||
BizID int64
|
||||
Rounds int64
|
||||
SectionID int64
|
||||
InitRunID *int64
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
|
||||
var defaultSVC Uploader
|
||||
|
||||
//go:generate mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
|
||||
type Uploader interface {
|
||||
GetFile(ctx context.Context, req *service.GetFileRequest) (resp *service.GetFileResponse, err error)
|
||||
}
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: upload.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
|
||||
//
|
||||
|
||||
// Package uploadmock is a generated GoMock package.
|
||||
package uploadmock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
service "github.com/coze-dev/coze-studio/backend/domain/upload/service"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockUploader is a mock of Uploader interface.
|
||||
type MockUploader struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUploaderMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockUploaderMockRecorder is the mock recorder for MockUploader.
|
||||
type MockUploaderMockRecorder struct {
|
||||
mock *MockUploader
|
||||
}
|
||||
|
||||
// NewMockUploader creates a new mock instance.
|
||||
func NewMockUploader(ctrl *gomock.Controller) *MockUploader {
|
||||
mock := &MockUploader{ctrl: ctrl}
|
||||
mock.recorder = &MockUploaderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockUploader) EXPECT() *MockUploaderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetFile mocks base method.
|
||||
func (m *MockUploader) GetFile(ctx context.Context, req *service.GetFileRequest) (*service.GetFileResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFile", ctx, req)
|
||||
ret0, _ := ret[0].(*service.GetFileResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetFile indicates an expected call of GetFile.
|
||||
func (mr *MockUploaderMockRecorder) GetFile(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFile", reflect.TypeOf((*MockUploader)(nil).GetFile), ctx, req)
|
||||
}
|
||||
@ -54,7 +54,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
|
||||
ConversationID: req.ConversationID,
|
||||
Limit: int(req.Limit), // Since the value of limit is checked inside the node, the type cast here is safe
|
||||
UserID: strconv.FormatInt(req.UserID, 10),
|
||||
AgentID: req.AppID,
|
||||
AgentID: req.BizID,
|
||||
OrderBy: req.OrderBy,
|
||||
}
|
||||
if req.BeforeID != nil {
|
||||
@ -96,7 +96,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
|
||||
func (c *impl) GetLatestRunIDs(ctx context.Context, req *crossmessage.GetLatestRunIDsRequest) ([]int64, error) {
|
||||
listMeta := &agententity.ListRunRecordMeta{
|
||||
ConversationID: req.ConversationID,
|
||||
AgentID: req.AppID,
|
||||
AgentID: req.BizID,
|
||||
Limit: int32(req.Rounds),
|
||||
SectionID: req.SectionID,
|
||||
}
|
||||
|
||||
@ -75,11 +75,11 @@ type Conversation interface {
|
||||
ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error)
|
||||
ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error
|
||||
InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error
|
||||
GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error)
|
||||
GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error)
|
||||
UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error)
|
||||
GetTemplateByName(ctx context.Context, env vo.Env, appID int64, templateName string) (*entity.ConversationTemplate, bool, error)
|
||||
GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error)
|
||||
GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
|
||||
GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
|
||||
}
|
||||
|
||||
type InterruptEventStore interface {
|
||||
@ -143,8 +143,8 @@ type ConversationRepository interface {
|
||||
UpdateStaticConversation(ctx context.Context, env vo.Env, templateID int64, connectorID int64, userID int64, newConversationID int64) error
|
||||
UpdateDynamicConversation(ctx context.Context, env vo.Env, conversationID, newConversationID int64) error
|
||||
CopyTemplateConversationByAppID(ctx context.Context, appID int64, toAppID int64) error
|
||||
GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
|
||||
GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
|
||||
GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
|
||||
GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
|
||||
}
|
||||
type WorkflowConfig interface {
|
||||
GetNodeOfCodeConfig() *config.NodeOfCodeConfig
|
||||
|
||||
@ -32,6 +32,11 @@ const (
|
||||
ChatFlowMessageCompleted ChatFlowEvent = "conversation.message.completed"
|
||||
)
|
||||
|
||||
const (
|
||||
ConversationNameKey = "CONVERSATION_NAME"
|
||||
UserInputKey = "USER_INPUT"
|
||||
)
|
||||
|
||||
type Usage struct {
|
||||
TokenCount *int32 `form:"token_count" json:"token_count,omitempty"`
|
||||
OutputTokens *int32 `form:"output_count" json:"output_count,omitempty"`
|
||||
|
||||
@ -59,14 +59,14 @@ type ListConversationPolicy struct {
|
||||
}
|
||||
|
||||
type CreateStaticConversation struct {
|
||||
AppID int64
|
||||
BizID int64
|
||||
UserID int64
|
||||
ConnectorID int64
|
||||
|
||||
TemplateID int64
|
||||
}
|
||||
type CreateDynamicConversation struct {
|
||||
AppID int64
|
||||
BizID int64
|
||||
UserID int64
|
||||
ConnectorID int64
|
||||
|
||||
|
||||
@ -235,6 +235,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
|
||||
if enabled {
|
||||
trimmedSC.GeneratedNodes = append(trimmedSC.GeneratedNodes, ns.Key)
|
||||
}
|
||||
trimmedSC.Init()
|
||||
|
||||
return trimmedSC, nil
|
||||
}
|
||||
|
||||
@ -446,7 +446,7 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node)
|
||||
|
||||
func parseBatchMode(n *vo.Node) (
|
||||
batchN *vo.Node, // the new batch node
|
||||
enabled bool, // whether the node has enabled batch mode
|
||||
enabled bool, // whether the node has enabled batch mode
|
||||
err error) {
|
||||
if n.Data == nil || n.Data.Inputs == nil {
|
||||
return nil, false, nil
|
||||
|
||||
@ -0,0 +1,275 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 180,
|
||||
"y": 79.2
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "USER_INPUT",
|
||||
"required": false
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"defaultValue": "dhl"
|
||||
}
|
||||
],
|
||||
"nodeMeta": {
|
||||
"title": "开始",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"subTitle": ""
|
||||
},
|
||||
"trigger_parameters": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 2020,
|
||||
"y": 66.2
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"title": "结束",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"subTitle": ""
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "useAnswerContent",
|
||||
"streamingOutput": true,
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "142077",
|
||||
"name": "optionContent"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": "{{output}}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "190196",
|
||||
"type": "30",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 640,
|
||||
"y": 78.5
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"nodeMeta": {
|
||||
"title": "输入",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
|
||||
"description": "支持中间过程的信息输入",
|
||||
"mainColor": "#5C62FF",
|
||||
"subTitle": "输入"
|
||||
},
|
||||
"inputs": {
|
||||
"outputSchema": "[{\"type\":\"string\",\"name\":\"input\",\"required\":true}]"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "133775",
|
||||
"type": "18",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1100,
|
||||
"y": 39
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"inputs": {
|
||||
"llmParam": {
|
||||
"modelType": 1001,
|
||||
"modelName": "Doubao-Seed-1.6",
|
||||
"generationDiversity": "balance",
|
||||
"temperature": 0.8,
|
||||
"maxTokens": 4096,
|
||||
"topP": 0.7,
|
||||
"responseFormat": 2,
|
||||
"systemPrompt": ""
|
||||
},
|
||||
"inputParameters": [],
|
||||
"extra_output": false,
|
||||
"answer_type": "text",
|
||||
"option_type": "static",
|
||||
"dynamic_option": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"question": "你好",
|
||||
"options": [
|
||||
{
|
||||
"name": ""
|
||||
},
|
||||
{
|
||||
"name": ""
|
||||
}
|
||||
],
|
||||
"limit": 3
|
||||
},
|
||||
"nodeMeta": {
|
||||
"title": "问答",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
|
||||
"description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
|
||||
"mainColor": "#3071F2",
|
||||
"subTitle": "问答"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "USER_RESPONSE",
|
||||
"required": true,
|
||||
"description": "用户本轮对话输入内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "142077",
|
||||
"type": "18",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1560,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"inputs": {
|
||||
"llmParam": {
|
||||
"modelType": 1001,
|
||||
"modelName": "Doubao-Seed-1.6",
|
||||
"generationDiversity": "balance",
|
||||
"temperature": 0.8,
|
||||
"maxTokens": 4096,
|
||||
"topP": 0.7,
|
||||
"responseFormat": 2,
|
||||
"systemPrompt": ""
|
||||
},
|
||||
"inputParameters": [],
|
||||
"extra_output": false,
|
||||
"answer_type": "option",
|
||||
"option_type": "static",
|
||||
"dynamic_option": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"question": "请选择",
|
||||
"options": [
|
||||
{
|
||||
"name": "A"
|
||||
},
|
||||
{
|
||||
"name": "B"
|
||||
}
|
||||
],
|
||||
"limit": 3
|
||||
},
|
||||
"nodeMeta": {
|
||||
"title": "问答_1",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
|
||||
"description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
|
||||
"mainColor": "#3071F2",
|
||||
"subTitle": "问答"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "optionId",
|
||||
"required": false
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "optionContent",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "190196"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "142077",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": "branch_0"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "142077",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": "branch_1"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "142077",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": "default"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "190196",
|
||||
"targetNodeID": "133775"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "133775",
|
||||
"targetNodeID": "142077"
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -0,0 +1,397 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{output}}",
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "123887",
|
||||
"name": "output",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "output"
|
||||
}
|
||||
],
|
||||
"streamingOutput": true,
|
||||
"terminatePlan": "useAnswerContent"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 926,
|
||||
"y": -13
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"fcParamVar": {
|
||||
"knowledgeFCParam": {}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "USER_INPUT",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "input"
|
||||
}
|
||||
],
|
||||
"llmParam": [
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "1737521813",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modelType"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "豆包·1.5·Pro·32k",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modleName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "balance",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "generationDiversity"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "float",
|
||||
"value": {
|
||||
"content": "0.8",
|
||||
"rawMeta": {
|
||||
"type": 4
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "temperature"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "4096",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "maxTokens"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spCurrentTime"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spAntiLeak"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prefixCache"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "2",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "responseFormat"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{input}}",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "enableChatHistory"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "3",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "chatHistoryRound"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "systemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "stableSystemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "canContinue"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptVersion"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptId"
|
||||
}
|
||||
],
|
||||
"settingOnError": {
|
||||
"processType": 1,
|
||||
"retryTimes": 0,
|
||||
"timeoutMs": 180000
|
||||
}
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "调用大语言模型,使用变量和提示词生成回复",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
|
||||
"mainColor": "#5C62FF",
|
||||
"subTitle": "大模型",
|
||||
"title": "大模型"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "output",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"version": "3"
|
||||
},
|
||||
"edges": null,
|
||||
"id": "123887",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 463,
|
||||
"y": -39
|
||||
}
|
||||
},
|
||||
"type": "3"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "123887",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "123887",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,397 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{output}}",
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "123887",
|
||||
"name": "output",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "output"
|
||||
}
|
||||
],
|
||||
"streamingOutput": true,
|
||||
"terminatePlan": "useAnswerContent"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 926,
|
||||
"y": -13
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"fcParamVar": {
|
||||
"knowledgeFCParam": {}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "USER_INPUT",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "input"
|
||||
}
|
||||
],
|
||||
"llmParam": [
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "1737521813",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modelType"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "豆包·1.5·Pro·32k",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modleName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "balance",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "generationDiversity"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "float",
|
||||
"value": {
|
||||
"content": "0.8",
|
||||
"rawMeta": {
|
||||
"type": 4
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "temperature"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "4096",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "maxTokens"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spCurrentTime"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spAntiLeak"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prefixCache"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "2",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "responseFormat"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{input}}",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": true,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "enableChatHistory"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "3",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "chatHistoryRound"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "systemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "stableSystemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "canContinue"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptVersion"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptId"
|
||||
}
|
||||
],
|
||||
"settingOnError": {
|
||||
"processType": 1,
|
||||
"retryTimes": 0,
|
||||
"timeoutMs": 180000
|
||||
}
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "调用大语言模型,使用变量和提示词生成回复",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
|
||||
"mainColor": "#5C62FF",
|
||||
"subTitle": "大模型",
|
||||
"title": "大模型"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "output",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"version": "3"
|
||||
},
|
||||
"edges": null,
|
||||
"id": "123887",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 463,
|
||||
"y": -39
|
||||
}
|
||||
},
|
||||
"type": "3"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "123887",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "123887",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -148,7 +148,7 @@ func (ch *ConversationHistory) Invoke(ctx context.Context, input map[string]any)
|
||||
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
AppID: *appID,
|
||||
BizID: *appID,
|
||||
Rounds: rounds,
|
||||
InitRunID: initRunID,
|
||||
SectionID: sectionID,
|
||||
|
||||
@ -109,7 +109,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
|
||||
|
||||
if existed {
|
||||
cID, _, existed, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
AppID: ptr.From(appID),
|
||||
BizID: ptr.From(appID),
|
||||
TemplateID: template.TemplateID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
@ -125,7 +125,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
|
||||
}
|
||||
|
||||
cID, _, existed, err := workflow.GetRepository().GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
AppID: ptr.From(appID),
|
||||
BizID: ptr.From(appID),
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Name: conversationName,
|
||||
|
||||
@ -98,7 +98,7 @@ func (c *CreateMessage) getConversationIDByName(ctx context.Context, env vo.Env,
|
||||
var conversationID int64
|
||||
if isExist {
|
||||
cID, _, _, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
AppID: ptr.From(appID),
|
||||
BizID: ptr.From(appID),
|
||||
TemplateID: template.TemplateID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
@ -150,7 +150,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
|
||||
var conversationID int64
|
||||
var err error
|
||||
var resolvedAppID int64
|
||||
var bizID int64
|
||||
if appID == nil {
|
||||
if conversationName != "Default" {
|
||||
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
|
||||
@ -167,13 +167,13 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
}, nil
|
||||
}
|
||||
conversationID = *execCtx.ExeCfg.ConversationID
|
||||
resolvedAppID = *agentID
|
||||
bizID = *agentID
|
||||
} else {
|
||||
conversationID, err = c.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolvedAppID = *appID
|
||||
bizID = *appID
|
||||
}
|
||||
|
||||
if conversationID == 0 {
|
||||
@ -209,7 +209,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
if role == "user" {
|
||||
// For user messages, always create a new run and store the ID in the context.
|
||||
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: resolvedAppID,
|
||||
AgentID: bizID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -244,7 +244,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
AppID: resolvedAppID,
|
||||
BizID: bizID,
|
||||
Rounds: 1,
|
||||
})
|
||||
if err != nil {
|
||||
@ -254,7 +254,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
runID = runIDs[0]
|
||||
} else {
|
||||
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: resolvedAppID,
|
||||
AgentID: bizID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -273,7 +273,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
Content: content,
|
||||
ContentType: model.ContentType("text"),
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
AgentID: resolvedAppID,
|
||||
AgentID: bizID,
|
||||
RunID: runID,
|
||||
SectionID: sectionID,
|
||||
}
|
||||
|
||||
@ -115,7 +115,7 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
|
||||
|
||||
var conversationID int64
|
||||
var err error
|
||||
var resolvedAppID int64
|
||||
var bizID int64
|
||||
if appID == nil {
|
||||
if conversationName != "Default" {
|
||||
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
|
||||
@ -129,18 +129,18 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
|
||||
}, nil
|
||||
}
|
||||
conversationID = *execCtx.ExeCfg.ConversationID
|
||||
resolvedAppID = *agentID
|
||||
bizID = *agentID
|
||||
} else {
|
||||
conversationID, err = m.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolvedAppID = *appID
|
||||
bizID = *appID
|
||||
}
|
||||
|
||||
req := &crossmessage.MessageListRequest{
|
||||
UserID: userID,
|
||||
AppID: resolvedAppID,
|
||||
BizID: bizID,
|
||||
ConversationID: conversationID,
|
||||
}
|
||||
|
||||
|
||||
@ -100,9 +100,9 @@ const (
|
||||
ReasoningOutputKey = "reasoning_content"
|
||||
)
|
||||
|
||||
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
|
||||
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
|
||||
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
|
||||
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
|
||||
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
|
||||
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
|
||||
例如:
|
||||
如果内容为<img src="https://example.com/image.jpg">一只小猫,你的输出应为:。
|
||||
如果内容为<img src="https://example.com/image1.jpg">一只小猫 和 <img src="https://example.com/image2.jpg">一只小狗 和 <img src="https://example.com/image3.jpg">一只小牛,你的输出应为: 和  和 
|
||||
@ -290,7 +290,7 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*
|
||||
c.AssociateStartNodeUserInputFields = make(map[string]struct{})
|
||||
for _, info := range ns.InputSources {
|
||||
if len(info.Path) == 1 && info.Source.Ref != nil && info.Source.Ref.FromNodeKey == entity.EntryNodeKey {
|
||||
if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField("USER_INPUT")) {
|
||||
if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField(vo.UserInputKey)) {
|
||||
c.AssociateStartNodeUserInputFields[info.Path[0]] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,8 +192,3 @@ type StreamGenerator interface {
|
||||
FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
|
||||
sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
|
||||
}
|
||||
|
||||
type ChatHistoryAware interface {
|
||||
ChatHistoryEnabled() bool
|
||||
ChatHistoryRounds() int64
|
||||
}
|
||||
|
||||
@ -432,7 +432,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
|
||||
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
|
||||
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
|
||||
appDynamicConversationDraft.AppID.Eq(meta.AppID),
|
||||
appDynamicConversationDraft.AppID.Eq(meta.BizID),
|
||||
appDynamicConversationDraft.ConnectorID.Eq(meta.ConnectorID),
|
||||
appDynamicConversationDraft.UserID.Eq(meta.UserID),
|
||||
appDynamicConversationDraft.Name.Eq(meta.Name),
|
||||
@ -452,7 +452,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -464,7 +464,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
|
||||
err = r.query.AppDynamicConversationDraft.WithContext(ctx).Create(&model.AppDynamicConversationDraft{
|
||||
ID: id,
|
||||
AppID: meta.AppID,
|
||||
AppID: meta.BizID,
|
||||
Name: meta.Name,
|
||||
UserID: meta.UserID,
|
||||
ConnectorID: meta.ConnectorID,
|
||||
@ -479,7 +479,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
} else if env == vo.Online {
|
||||
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
|
||||
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
|
||||
appDynamicConversationOnline.AppID.Eq(meta.AppID),
|
||||
appDynamicConversationOnline.AppID.Eq(meta.BizID),
|
||||
appDynamicConversationOnline.ConnectorID.Eq(meta.ConnectorID),
|
||||
appDynamicConversationOnline.UserID.Eq(meta.UserID),
|
||||
appDynamicConversationOnline.Name.Eq(meta.Name),
|
||||
@ -498,7 +498,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -509,7 +509,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
|
||||
err = r.query.AppDynamicConversationOnline.WithContext(ctx).Create(&model.AppDynamicConversationOnline{
|
||||
ID: id,
|
||||
AppID: meta.AppID,
|
||||
AppID: meta.BizID,
|
||||
Name: meta.Name,
|
||||
UserID: meta.UserID,
|
||||
ConnectorID: meta.ConnectorID,
|
||||
@ -586,7 +586,7 @@ func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context,
|
||||
return cs[0].ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -627,7 +627,7 @@ func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context
|
||||
return cs[0].ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -841,7 +841,7 @@ func (r *RepositoryImpl) CopyTemplateConversationByAppID(ctx context.Context, ap
|
||||
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
|
||||
func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
|
||||
if env == vo.Draft {
|
||||
appStaticConversationDraft := r.query.AppStaticConversationDraft
|
||||
ret, err := appStaticConversationDraft.WithContext(ctx).Where(
|
||||
@ -857,7 +857,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
|
||||
appConversationTemplateDraft := r.query.AppConversationTemplateDraft
|
||||
template, err := appConversationTemplateDraft.WithContext(ctx).Where(
|
||||
appConversationTemplateDraft.TemplateID.Eq(ret.TemplateID),
|
||||
appConversationTemplateDraft.AppID.Eq(appID),
|
||||
appConversationTemplateDraft.AppID.Eq(bizID),
|
||||
).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@ -881,7 +881,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
|
||||
appConversationTemplateOnline := r.query.AppConversationTemplateOnline
|
||||
template, err := appConversationTemplateOnline.WithContext(ctx).Where(
|
||||
appConversationTemplateOnline.TemplateID.Eq(ret.TemplateID),
|
||||
appConversationTemplateOnline.AppID.Eq(appID),
|
||||
appConversationTemplateOnline.AppID.Eq(bizID),
|
||||
).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@ -894,11 +894,11 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
|
||||
return "", false, fmt.Errorf("unknown env %v", env)
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
|
||||
func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
|
||||
if env == vo.Draft {
|
||||
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
|
||||
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
|
||||
appDynamicConversationDraft.AppID.Eq(appID),
|
||||
appDynamicConversationDraft.AppID.Eq(bizID),
|
||||
appDynamicConversationDraft.ConnectorID.Eq(connectorID),
|
||||
appDynamicConversationDraft.ConversationID.Eq(conversationID),
|
||||
).First()
|
||||
@ -918,7 +918,7 @@ func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.
|
||||
} else if env == vo.Online {
|
||||
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
|
||||
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
|
||||
appDynamicConversationOnline.AppID.Eq(appID),
|
||||
appDynamicConversationOnline.AppID.Eq(bizID),
|
||||
appDynamicConversationOnline.ConnectorID.Eq(connectorID),
|
||||
appDynamicConversationOnline.ConversationID.Eq(conversationID),
|
||||
).First()
|
||||
|
||||
@ -129,3 +129,8 @@ func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
|
||||
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
|
||||
s.OutputSources = append(s.OutputSources, info...)
|
||||
}
|
||||
|
||||
type ChatHistoryAware interface {
|
||||
ChatHistoryEnabled() bool
|
||||
ChatHistoryRounds() int64
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ type WorkflowSchema struct {
|
||||
compositeNodes []*CompositeNode // won't serialize this
|
||||
requireCheckPoint bool // won't serialize this
|
||||
requireStreaming bool
|
||||
historyRounds int64
|
||||
|
||||
once sync.Once
|
||||
}
|
||||
@ -69,15 +70,22 @@ func (w *WorkflowSchema) Init() {
|
||||
|
||||
w.doGetCompositeNodes()
|
||||
|
||||
historyRounds := int64(0)
|
||||
for _, node := range w.Nodes {
|
||||
if node.Type == entity.NodeTypeSubWorkflow {
|
||||
node.SubWorkflowSchema.Init()
|
||||
historyRounds = max(historyRounds, node.SubWorkflowSchema.HistoryRounds())
|
||||
if node.SubWorkflowSchema.requireCheckPoint {
|
||||
w.requireCheckPoint = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
chatHistoryAware, ok := node.Configs.(ChatHistoryAware)
|
||||
if ok && chatHistoryAware.ChatHistoryEnabled() {
|
||||
historyRounds = max(historyRounds, chatHistoryAware.ChatHistoryRounds())
|
||||
}
|
||||
|
||||
if rc, ok := node.Configs.(RequireCheckpoint); ok {
|
||||
if rc.RequireCheckpoint() {
|
||||
w.requireCheckPoint = true
|
||||
@ -86,6 +94,7 @@ func (w *WorkflowSchema) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
w.historyRounds = historyRounds
|
||||
w.requireStreaming = w.doRequireStreaming()
|
||||
})
|
||||
}
|
||||
@ -122,6 +131,12 @@ func (w *WorkflowSchema) RequireStreaming() bool {
|
||||
return w.requireStreaming
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) HistoryRounds() int64 { return w.historyRounds }
|
||||
|
||||
func (w *WorkflowSchema) SetHistoryRounds(historyRounds int64) {
|
||||
w.historyRounds = historyRounds
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
|
||||
if w.Hierarchy == nil {
|
||||
return nil
|
||||
|
||||
@ -248,7 +248,7 @@ func (c *conversationImpl) findReplaceWorkflowByConversationName(ctx context.Con
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if v.Name == "CONVERSATION_NAME" && v.DefaultValue == name {
|
||||
if v.Name == vo.ConversationNameKey && v.DefaultValue == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
@ -296,7 +296,7 @@ func (c *conversationImpl) replaceWorkflowsConversationName(ctx context.Context,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
v.DefaultValue = conversionName
|
||||
}
|
||||
startNode.Data.Outputs[idx] = v
|
||||
@ -351,18 +351,18 @@ func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo
|
||||
return c.repo.DeleteDynamicConversation(ctx, env, templateID)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
|
||||
AppID: ptr.Of(appID),
|
||||
AppID: ptr.Of(bizID),
|
||||
Name: ptr.Of(conversationName),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (*conventity.Conversation, error) {
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, bizID int64, userID, connectorID int64) (*conventity.Conversation, error) {
|
||||
return crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
|
||||
AgentID: appID,
|
||||
AgentID: bizID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneWorkflow,
|
||||
@ -371,7 +371,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
|
||||
|
||||
if existed {
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
AppID: appID,
|
||||
BizID: bizID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
TemplateID: t.TemplateID,
|
||||
@ -383,7 +383,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
|
||||
}
|
||||
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
AppID: appID,
|
||||
BizID: bizID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
Name: conversationName,
|
||||
@ -465,8 +465,8 @@ func (c *conversationImpl) GetDynamicConversationByName(ctx context.Context, env
|
||||
return c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, name)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
|
||||
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, appID, connectorID, conversationID)
|
||||
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
|
||||
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, bizID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
@ -474,7 +474,7 @@ func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.E
|
||||
return sc, true, nil
|
||||
}
|
||||
|
||||
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID)
|
||||
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, bizID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
@ -22,6 +22,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
@ -282,6 +284,50 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
|
||||
return executeID, nil
|
||||
}
|
||||
|
||||
func (i *impl) handleHistory(ctx context.Context, config *workflowModel.ExecuteConfig, input map[string]any, historyRounds int64, shouldFetchConversationByName bool) error {
|
||||
if historyRounds <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if shouldFetchConversationByName {
|
||||
var cID, sID, bizID int64
|
||||
var err error
|
||||
if config.AppID != nil {
|
||||
bizID = *config.AppID
|
||||
} else if config.AgentID != nil {
|
||||
bizID = *config.AgentID
|
||||
}
|
||||
for k, v := range input {
|
||||
if k == vo.ConversationNameKey {
|
||||
cName, ok := v.(string)
|
||||
if !ok {
|
||||
return errors.New("CONVERSATION_NAME must be string")
|
||||
}
|
||||
cID, sID, err = i.GetOrCreateConversation(ctx, vo.Draft, bizID, consts.CozeConnectorID, config.Operator, cName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config.ConversationID = ptr.Of(cID)
|
||||
config.SectionID = ptr.Of(sID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, *config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error) {
|
||||
var (
|
||||
err error
|
||||
@ -308,30 +354,6 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
|
||||
historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@ -342,6 +364,17 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds = workflowSC.HistoryRounds()
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
if err = i.handleHistory(ctx, &config, input, historyRounds, true); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
wf, err := compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID), einoCompose.WithGraphName(fmt.Sprintf("%d", wfEntity.ID)))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create workflow: %w", err)
|
||||
@ -417,29 +450,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@ -450,6 +460,17 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
return nil, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds = workflowSC.HistoryRounds()
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
if err = i.handleHistory(ctx, &config, input, historyRounds, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var wfOpts []compose.WorkflowOption
|
||||
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
@ -997,20 +1018,6 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxHistoryRounds int64 = 30
|
||||
|
||||
func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
|
||||
if wfEntity == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return min(maxRounds, maxHistoryRounds), nil
|
||||
}
|
||||
|
||||
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
|
||||
convID := config.ConversationID
|
||||
agentID := config.AgentID
|
||||
@ -1027,11 +1034,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
var resolvedAppID int64
|
||||
var bizID int64
|
||||
if appID != nil {
|
||||
resolvedAppID = *appID
|
||||
bizID = *appID
|
||||
} else if agentID != nil {
|
||||
resolvedAppID = *agentID
|
||||
bizID = *agentID
|
||||
} else {
|
||||
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
|
||||
return nil, nil, nil
|
||||
@ -1039,7 +1046,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
|
||||
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: *convID,
|
||||
AppID: resolvedAppID,
|
||||
BizID: bizID,
|
||||
UserID: userID,
|
||||
Rounds: historyRounds + 1,
|
||||
SectionID: *sectionID,
|
||||
@ -1048,7 +1055,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
|
||||
return nil, nil, nil
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(runIds) <= 1 {
|
||||
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
|
||||
@ -1061,7 +1068,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
|
||||
return nil, nil, nil
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return response.Messages, response.SchemaMessages, nil
|
||||
|
||||
286
backend/domain/workflow/service/executable_impl_test.go
Normal file
286
backend/domain/workflow/service/executable_impl_test.go
Normal file
@ -0,0 +1,286 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
messagemock "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message/messagemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
mock_workflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestImpl_handleHistory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
|
||||
defer ctrl.Finish()
|
||||
|
||||
// Setup for cross-domain service mock
|
||||
mockMessage := messagemock.NewMockMessage(ctrl)
|
||||
crossmessage.SetDefaultSVC(mockMessage)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository)
|
||||
config *workflowModel.ExecuteConfig
|
||||
input map[string]any
|
||||
historyRounds int64
|
||||
shouldFetch bool
|
||||
expectErr bool
|
||||
expectedHistory []*crossmessage.WfMessage
|
||||
expectedSchemaHistory []*schema.Message
|
||||
}{
|
||||
{
|
||||
name: "historyRounds is zero",
|
||||
historyRounds: 0,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "shouldFetch is false",
|
||||
historyRounds: 5,
|
||||
shouldFetch: false,
|
||||
config: &workflowModel.ExecuteConfig{
|
||||
AppID: ptr.Of(int64(1)),
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2}, nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
|
||||
Messages: []*crossmessage.WfMessage{{ID: 1}},
|
||||
SchemaMessages: []*schema.Message{{
|
||||
Role: schema.User,
|
||||
Content: "123",
|
||||
}},
|
||||
}, nil).AnyTimes()
|
||||
},
|
||||
expectErr: false,
|
||||
expectedHistory: []*crossmessage.WfMessage{{ID: 1}},
|
||||
expectedSchemaHistory: []*schema.Message{{
|
||||
Role: schema.User,
|
||||
Content: "123",
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "fetch conversation by name - conversation exists",
|
||||
historyRounds: 3,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
|
||||
input: map[string]any{"CONVERSATION_NAME": "test-conv"},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "test-conv").Return(int64(200), int64(201), nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{3, 4}, nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
|
||||
Messages: []*crossmessage.WfMessage{{ID: 2}},
|
||||
SchemaMessages: []*schema.Message{{
|
||||
Role: schema.Assistant,
|
||||
Content: "123",
|
||||
}},
|
||||
}, nil).AnyTimes()
|
||||
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
|
||||
TemplateID: int64(202),
|
||||
SpaceID: int64(203),
|
||||
AppID: int64(204),
|
||||
}, true, nil).AnyTimes()
|
||||
repo.EXPECT().GetOrCreateStaticConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
|
||||
},
|
||||
expectErr: false,
|
||||
expectedHistory: []*crossmessage.WfMessage{{ID: 2}},
|
||||
expectedSchemaHistory: []*schema.Message{{
|
||||
Role: schema.Assistant,
|
||||
Content: "123",
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "fetch conversation by name - conversation not exists",
|
||||
historyRounds: 3,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AgentID: ptr.Of(int64(2))},
|
||||
input: map[string]any{"CONVERSATION_NAME": "new-conv"},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "new-conv").Return(int64(300), int64(301), nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{5, 6}, nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
|
||||
Messages: []*crossmessage.WfMessage{{ID: 3}},
|
||||
}, nil).AnyTimes()
|
||||
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
|
||||
TemplateID: int64(202),
|
||||
SpaceID: int64(203),
|
||||
AppID: int64(204),
|
||||
}, false, nil).AnyTimes()
|
||||
repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
|
||||
},
|
||||
expectErr: false,
|
||||
expectedHistory: []*crossmessage.WfMessage{{ID: 3}},
|
||||
},
|
||||
{
|
||||
name: "input with wrong type for conversation name",
|
||||
historyRounds: 5,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
|
||||
input: map[string]any{"CONVERSATION_NAME": 12345},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "GetOrCreateConversation returns error",
|
||||
historyRounds: 5,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
|
||||
input: map[string]any{"CONVERSATION_NAME": "fail-conv"},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "fail-conv").Return(int64(0), int64(0), errors.New("db error")).AnyTimes()
|
||||
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
|
||||
TemplateID: int64(202),
|
||||
SpaceID: int64(203),
|
||||
AppID: int64(204),
|
||||
}, false, nil).AnyTimes()
|
||||
repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, errors.New("db error")).AnyTimes()
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockService := mock_workflow.NewMockService(ctrl)
|
||||
mockRepo := mock_workflow.NewMockRepository(ctrl)
|
||||
testImpl := &impl{repo: mockRepo, conversationImpl: &conversationImpl{repo: mockRepo}}
|
||||
|
||||
tt.setupMock(mockService, mockMessage, mockRepo)
|
||||
|
||||
err := testImpl.handleHistory(ctx, tt.config, tt.input, tt.historyRounds, tt.shouldFetch)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedHistory != nil {
|
||||
assert.Equal(t, tt.expectedHistory, tt.config.ConversationHistory)
|
||||
} else if tt.historyRounds == 0 {
|
||||
assert.Nil(t, tt.config.ConversationHistory)
|
||||
} else if tt.expectedSchemaHistory != nil {
|
||||
assert.Equal(t, tt.expectedSchemaHistory, tt.config.ConversationHistorySchemaMessages)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImpl_prefetchChatHistory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockMessage := messagemock.NewMockMessage(ctrl)
|
||||
crossmessage.SetDefaultSVC(mockMessage)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(msgSvc *messagemock.MockMessage)
|
||||
config workflowModel.ExecuteConfig
|
||||
historyRounds int64
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "SectionID is nil",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
AppID: ptr.Of(int64(1)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "ConversationID is nil",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
AppID: ptr.Of(int64(1)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "AppID and AgentID are both nil",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "GetLatestRunIDs returns error",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
AppID: ptr.Of(int64(1)),
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "GetMessagesByRunIDs returns error",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
AppID: ptr.Of(int64(1)),
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2, 3}, nil)
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testImpl := &impl{}
|
||||
tt.setupMock(mockMessage)
|
||||
|
||||
_, _, err := testImpl.prefetchChatHistory(ctx, tt.config, tt.historyRounds)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -39,7 +39,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
@ -522,7 +521,7 @@ func isEnableChatHistory(s *schema.NodeSchema) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
chatHistoryAware, ok := s.Configs.(nodes.ChatHistoryAware)
|
||||
chatHistoryAware, ok := s.Configs.(schema.ChatHistoryAware)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@ -2171,15 +2170,15 @@ func (i *impl) adaptToChatFlow(ctx context.Context, wID int64) error {
|
||||
vMap[v.Name] = true
|
||||
}
|
||||
|
||||
if _, ok := vMap["USER_INPUT"]; !ok {
|
||||
if _, ok := vMap[vo.UserInputKey]; !ok {
|
||||
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
|
||||
Name: "USER_INPUT",
|
||||
Name: vo.UserInputKey,
|
||||
Type: vo.VariableTypeString,
|
||||
})
|
||||
}
|
||||
if _, ok := vMap["CONVERSATION_NAME"]; !ok {
|
||||
if _, ok := vMap[vo.ConversationNameKey]; !ok {
|
||||
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
|
||||
Name: "CONVERSATION_NAME",
|
||||
Name: vo.ConversationNameKey,
|
||||
Type: vo.VariableTypeString,
|
||||
DefaultValue: "Default",
|
||||
})
|
||||
|
||||
@ -22,15 +22,11 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
@ -201,214 +197,3 @@ func isIncremental(prev version, next version) bool {
|
||||
|
||||
return next.Patch > prev.Patch
|
||||
}
|
||||
|
||||
func getMaxHistoryRoundsRecursively(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository) (int64, error) {
|
||||
visited := make(map[string]struct{})
|
||||
maxRounds := int64(0)
|
||||
err := getMaxHistoryRoundsRecursiveHelper(ctx, wfEntity, repo, visited, &maxRounds)
|
||||
return maxRounds, err
|
||||
}
|
||||
|
||||
func getMaxHistoryRoundsRecursiveHelper(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
|
||||
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
return nil
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(wfEntity.Canvas, &canvas); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wfEntity.ID, err)
|
||||
}
|
||||
|
||||
return collectMaxHistoryRounds(ctx, canvas.Nodes, repo, visited, maxRounds)
|
||||
}
|
||||
|
||||
func collectMaxHistoryRounds(ctx context.Context, nodes []*vo.Node, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
|
||||
for _, node := range nodes {
|
||||
if node == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.ChatHistorySetting != nil && node.Data.Inputs.ChatHistorySetting.EnableChatHistory {
|
||||
if node.Data.Inputs.ChatHistorySetting.ChatHistoryRound > *maxRounds {
|
||||
*maxRounds = node.Data.Inputs.ChatHistorySetting.ChatHistoryRound
|
||||
}
|
||||
} else if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLMParam != nil {
|
||||
param := node.Data.Inputs.LLMParam
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return err
|
||||
}
|
||||
var chatHistoryEnabled bool
|
||||
var chatHistoryRound int64
|
||||
for _, param := range llmParam {
|
||||
switch param.Name {
|
||||
case "enableChatHistory":
|
||||
if val, ok := param.Input.Value.Content.(bool); ok {
|
||||
b := val
|
||||
chatHistoryEnabled = b
|
||||
}
|
||||
case "chatHistoryRound":
|
||||
if strVal, ok := param.Input.Value.Content.(string); ok {
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatHistoryRound = int64Val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chatHistoryEnabled {
|
||||
if chatHistoryRound > *maxRounds {
|
||||
*maxRounds = chatHistoryRound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isSubWorkflow := node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil
|
||||
if isSubWorkflow {
|
||||
workflowIDStr := node.Data.Inputs.WorkflowID
|
||||
if workflowIDStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
|
||||
}
|
||||
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: ternary.IFElse(len(node.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
|
||||
Version: node.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
|
||||
}
|
||||
|
||||
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, maxRounds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(node.Blocks) > 0 {
|
||||
if err := collectMaxHistoryRounds(ctx, node.Blocks, repo, visited, maxRounds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHistoryRoundsFromNode(ctx context.Context, wfEntity *entity.Workflow, nodeID string, repo wf.Repository) (int64, error) {
|
||||
if wfEntity == nil {
|
||||
return 0, nil
|
||||
}
|
||||
visited := make(map[string]struct{})
|
||||
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
return 0, nil
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
maxRounds := int64(0)
|
||||
c := &vo.Canvas{}
|
||||
if err := sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
}
|
||||
var (
|
||||
n *vo.Node
|
||||
nodeFinder func(nodes []*vo.Node) *vo.Node
|
||||
)
|
||||
nodeFinder = func(nodes []*vo.Node) *vo.Node {
|
||||
for i := range nodes {
|
||||
if nodes[i].ID == nodeID {
|
||||
return nodes[i]
|
||||
}
|
||||
if len(nodes[i].Blocks) > 0 {
|
||||
if n := nodeFinder(nodes[i].Blocks); n != nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
n = nodeFinder(c.Nodes)
|
||||
if n.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if n.Data == nil || n.Data.Inputs == nil {
|
||||
return 0, nil
|
||||
}
|
||||
param := n.Data.Inputs.LLMParam
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var chatHistoryEnabled bool
|
||||
var chatHistoryRound int64
|
||||
for _, param := range llmParam {
|
||||
switch param.Name {
|
||||
case "enableChatHistory":
|
||||
if val, ok := param.Input.Value.Content.(bool); ok {
|
||||
b := val
|
||||
chatHistoryEnabled = b
|
||||
}
|
||||
case "chatHistoryRound":
|
||||
if strVal, ok := param.Input.Value.Content.(string); ok {
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
chatHistoryRound = int64Val
|
||||
}
|
||||
}
|
||||
}
|
||||
if chatHistoryEnabled {
|
||||
return chatHistoryRound, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if n.Type == entity.NodeTypeIntentDetector.IDStr() || n.Type == entity.NodeTypeKnowledgeRetriever.IDStr() {
|
||||
if n.Data != nil && n.Data.Inputs != nil && n.Data.Inputs.ChatHistorySetting != nil && n.Data.Inputs.ChatHistorySetting.EnableChatHistory {
|
||||
return n.Data.Inputs.ChatHistorySetting.ChatHistoryRound, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
if n.Data != nil && n.Data.Inputs != nil {
|
||||
workflowIDStr := n.Data.Inputs.WorkflowID
|
||||
if workflowIDStr == "" {
|
||||
return 0, nil
|
||||
}
|
||||
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", n.ID, err)
|
||||
}
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: ternary.IFElse(len(n.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
|
||||
Version: n.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
|
||||
}
|
||||
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, &maxRounds); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return maxRounds, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(n.Blocks) > 0 {
|
||||
if err := collectMaxHistoryRounds(ctx, n.Blocks, repo, visited, &maxRounds); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return maxRounds, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user