diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go
index 3806995bb..353975a3e 100644
--- a/backend/api/handler/coze/workflow_service_test.go
+++ b/backend/api/handler/coze/workflow_service_test.go
@@ -228,6 +228,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
h.POST("/api/workflow_api/chat_flow_role/delete", DeleteChatFlowRole)
h.POST("/api/workflow_api/chat_flow_role/create", CreateChatFlowRole)
h.GET("/api/workflow_api/chat_flow_role/get", GetChatFlowRole)
+ h.POST("/v1/workflows/chat", OpenAPIChatFlowRun)
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
mockIDGen := mock.NewMockIDGenerator(ctrl)
@@ -1082,6 +1083,46 @@ func (r *wfTestRunner) openapiResume(id string, eventID string, resumeData strin
return re
}
+func (r *wfTestRunner) openapiChatFlowRun(wfID string, cID, appID, botID *string, input any, additionalMessage []*workflow.EnterMessage) *sse.Reader {
+ inputStr, _ := sonic.MarshalString(input)
+
+ req := &workflow.ChatFlowRunRequest{
+ WorkflowID: wfID,
+ Parameters: ptr.Of(inputStr),
+ AdditionalMessages: additionalMessage,
+ }
+ if cID != nil {
+ req.ConversationID = cID
+ }
+ if appID != nil {
+ req.AppID = appID
+ }
+ if botID != nil {
+ req.BotID = botID
+ }
+
+ m, err := sonic.Marshal(req)
+ assert.NoError(r.t, err)
+
+ c, _ := client.NewClient()
+ hReq, hResp := protocol.AcquireRequest(), protocol.AcquireResponse()
+ hReq.SetRequestURI("http://localhost:8888" + "/v1/workflows/chat")
+ hReq.SetMethod("POST")
+ hReq.SetBody(m)
+ hReq.SetHeader("Content-Type", "application/json")
+ err = c.Do(context.Background(), hReq, hResp)
+ assert.NoError(r.t, err)
+
+ if hResp.StatusCode() != http.StatusOK {
+ r.t.Errorf("unexpected status code: %d, body: %s", hResp.StatusCode(), string(hResp.Body()))
+ }
+
+ re, err := sse.NewReader(hResp)
+ assert.NoError(r.t, err)
+
+ return re
+}
+
func (r *wfTestRunner) runServer() func() {
go func() {
_ = r.h.Run()
@@ -5491,7 +5532,7 @@ func TestConversationOfChatFlow(t *testing.T) {
if err != nil {
return err
}
- if v.Name == "CONVERSATION_NAME" {
+ if v.Name == vo.ConversationNameKey {
v.DefaultValue = cName
}
startNode.Data.Outputs[idx] = v
@@ -5522,7 +5563,7 @@ func TestConversationOfChatFlow(t *testing.T) {
for _, vAny := range node.Data.Outputs {
v, err := vo.ParseVariable(vAny)
assert.NoError(t, err)
- if v.Name == "CONVERSATION_NAME" {
+ if v.Name == vo.ConversationNameKey {
assert.Equal(t, v.DefaultValue, updateName)
}
}
@@ -5569,7 +5610,7 @@ func TestConversationOfChatFlow(t *testing.T) {
for _, vAny := range node.Data.Outputs {
v, err := vo.ParseVariable(vAny)
assert.NoError(t, err)
- if v.Name == "CONVERSATION_NAME" {
+ if v.Name == vo.ConversationNameKey {
assert.Equal(t, v.DefaultValue, cName+"copy")
}
}
@@ -5988,3 +6029,223 @@ func TestConversationHistoryNodes(t *testing.T) {
assert.Equal(t, []any{}, outputMap["history_list"])
})
}
+
+func TestChatFlowRun(t *testing.T) {
+ mockey.PatchConvey("chat flow run", t, func() {
+ r := newWfTestRunner(t)
+ appworkflow.SVC.IDGenerator = r.idGen
+ defer r.closeFn()
+ defer r.runServer()()
+
+ chatModel1 := &testutil.UTChatModel{
+ StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
+ sr := schema.StreamReaderFromArray([]*schema.Message{
+ {
+ Role: schema.Assistant,
+ Content: "I ",
+ },
+ {
+ Role: schema.Assistant,
+ Content: "don't know.",
+ },
+ })
+ return sr, nil
+ },
+ }
+ r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel1, nil, nil).AnyTimes()
+
+ id := r.load("chatflow/llm_chat.json", withMode(workflow.WorkflowMode_ChatFlow))
+ r.publish(id, "v0.0.1", true)
+ cID := time.Now().UnixNano()
+ cIDStr := strconv.FormatInt(cID, 10)
+ appID := time.Now().UnixNano()
+ appIDStr := strconv.FormatInt(appID, 10)
+
+ // Create conversation first
+ r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
+ ID: cID,
+ }, nil).AnyTimes()
+ idStr := r.load("conversation_manager/update_dynamic_conversation.json")
+ r.publish(idStr, "v0.0.1", true)
+ ret, _ := r.openapiSyncRun(idStr, map[string]string{
+ "input": "v1",
+ "new_name": "v2",
+ }, withRunProjectID(appID))
+ assert.Equal(t, map[string]any{"conversationId": strconv.FormatInt(cID, 10), "isExisted": false, "isSuccess": true}, ret["obj"])
+
+ msg := []*workflow.EnterMessage{
+ {
+ Role: "user",
+ ContentType: "text",
+ Content: "你好",
+ },
+ }
+ sID := time.Now().UnixNano()
+ r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
+ ID: cID,
+ SectionID: sID,
+ }, nil).AnyTimes()
+ rID := time.Now().UnixNano()
+ r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
+ ID: rID,
+ }, nil).AnyTimes()
+ mID := time.Now().Unix()
+ r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
+ ID: mID,
+ }, nil).AnyTimes()
+
+ t.Run("chat flow run in app", func(t *testing.T) {
+ sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, msg)
+ err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+ return nil
+ })
+ assert.NoError(t, err)
+ })
+
+ t.Run("chat flow run in bot", func(t *testing.T) {
+ botID := time.Now().UnixNano()
+ botIDStr := strconv.FormatInt(botID, 10)
+ sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), nil, ptr.Of(botIDStr), map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, msg)
+ err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+ return nil
+ })
+ assert.NoError(t, err)
+ })
+
+ t.Run("chat flow run without cID", func(t *testing.T) {
+ sseReader := r.openapiChatFlowRun(id, nil, ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, msg)
+ err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+ return nil
+ })
+ assert.NoError(t, err)
+ })
+
+ t.Run("chat flow run with additional messages", func(t *testing.T) {
+ additionalMsg := []*workflow.EnterMessage{
+ {
+ Role: "user",
+ ContentType: "text",
+ Content: "你好, 我叫小明",
+ },
+ {
+ Role: "assistant",
+ ContentType: "text",
+ Content: "你好小明, 很高兴认识你",
+ },
+ {
+ Role: "user",
+ ContentType: "text",
+ Content: "你好",
+ },
+ }
+ sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, additionalMsg)
+ err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+ return nil
+ })
+ assert.NoError(t, err)
+ })
+
+ t.Run("chat flow run with history messages", func(t *testing.T) {
+ id := r.load("chatflow/llm_chat_with_history.json", withMode(workflow.WorkflowMode_ChatFlow))
+ r.publish(id, "v0.0.1", true)
+ r.message.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{rID}, nil).AnyTimes()
+ r.message.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&message0.GetMessagesByRunIDsResponse{
+ Messages: []*message0.WfMessage{
+ {
+ ID: mID,
+ Role: schema.User,
+ Text: ptr.Of("你好"),
+ },
+ },
+ }, nil).AnyTimes()
+ sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, msg)
+ err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+ return nil
+ })
+ assert.NoError(t, err)
+ })
+
+ t.Run("chat flow run with interrupt nodes ", func(t *testing.T) {
+ // 生成一个携带 input, 问答文本 问答选项的三个中断节点 做测试
+ id := r.load("chatflow/chat_run_with_interrupt.json", withMode(workflow.WorkflowMode_ChatFlow))
+ r.publish(id, "v0.0.1", true)
+ sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, msg)
+
+ err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+ if e.ID == "3" {
+ assert.Equal(t, e.Type, "conversation.message.completed")
+ assert.Contains(t, string(e.Data), "7383997384420262000")
+ }
+ return nil
+ })
+ assert.NoError(t, err)
+
+ sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, []*workflow.EnterMessage{
+ {Role: string(schema.User), Content: "input:1", ContentType: "text"},
+ })
+
+ err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+ if e.ID == "4" {
+ assert.Equal(t, e.Type, "conversation.message.completed")
+ assert.Contains(t, string(e.Data), "你好")
+ }
+ return nil
+ })
+ assert.NoError(t, err)
+
+ sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, []*workflow.EnterMessage{
+ {Role: string(schema.User), Content: "hello", ContentType: "text"},
+ })
+
+ err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ if e.ID == "3" {
+ assert.Equal(t, e.Type, "conversation.message.completed")
+ assert.Contains(t, string(e.Data), "question_card_data", "请选择")
+ }
+ return nil
+ })
+ assert.NoError(t, err)
+
+ sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
+ vo.ConversationNameKey: "Default",
+ }, []*workflow.EnterMessage{
+ {Role: string(schema.User), Content: "A", ContentType: "text"},
+ })
+ err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
+ t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
+
+ if e.ID == "4" {
+ assert.Equal(t, e.Type, "conversation.message.completed")
+ assert.Contains(t, string(e.Data), "answer", "A")
+ }
+ return nil
+ })
+ assert.NoError(t, err)
+
+ })
+
+ })
+}
diff --git a/backend/application/workflow/chatflow.go b/backend/application/workflow/chatflow.go
index ef9d43252..137282498 100644
--- a/backend/application/workflow/chatflow.go
+++ b/backend/application/workflow/chatflow.go
@@ -221,7 +221,7 @@ const (
"id": "5fJt3qKpSz",
"name": "list",
"defaultValue": [
-
+
]
}
},
@@ -504,7 +504,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
workflowID = mustParseInt64(req.GetWorkflowID())
isDebug = req.GetExecuteMode() == "DEBUG"
appID, agentID *int64
- resolveAppID int64
+ bizID int64
conversationID int64
sectionID int64
version string
@@ -521,11 +521,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
if req.IsSetAppID() {
appID = ptr.Of(mustParseInt64(req.GetAppID()))
- resolveAppID = mustParseInt64(req.GetAppID())
+ bizID = mustParseInt64(req.GetAppID())
}
if req.IsSetBotID() {
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
- resolveAppID = mustParseInt64(req.GetBotID())
+ bizID = mustParseInt64(req.GetBotID())
}
if appID != nil && agentID != nil {
@@ -564,16 +564,16 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
sectionID = cInfo.SectionID
// only trust the conversation name under the app
- conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, conversationID)
+ conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, conversationID)
if err != nil {
return nil, err
}
if !existed {
return nil, fmt.Errorf("conversation not found")
}
- parameters["CONVERSATION_NAME"] = conversationName
+ parameters[vo.ConversationNameKey] = conversationName
} else if req.IsSetConversationID() && req.IsSetBotID() {
- parameters["CONVERSATION_NAME"] = "Default"
+ parameters[vo.ConversationNameKey] = "Default"
conversationID = mustParseInt64(req.GetConversationID())
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, conversationID)
if err != nil {
@@ -581,11 +581,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
}
sectionID = cInfo.SectionID
} else {
- conversationName, ok := parameters["CONVERSATION_NAME"].(string)
+ conversationName, ok := parameters[vo.ConversationNameKey].(string)
if !ok {
return nil, fmt.Errorf("conversation name is requried")
}
- cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName)
+ cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, userID, conversationName)
if err != nil {
return nil, err
}
@@ -594,7 +594,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
}
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
- AgentID: resolveAppID,
+ AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@@ -606,7 +606,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
roundID := runRecord.ID
- userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
+ userMessage, err := toConversationMessage(ctx, bizID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
if err != nil {
return nil, err
}
@@ -648,7 +648,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
return nil, err
}
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
- appID: resolveAppID,
+ bizID: bizID,
conversationID: conversationID,
roundID: roundID,
workflowID: workflowID,
@@ -684,7 +684,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
Cancellable: isDebug,
}
- historyMessages, err := makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
+ historyMessages, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
if err != nil {
return nil, err
}
@@ -706,7 +706,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
logs.CtxWarnf(ctx, "create history message failed, err=%v", err)
}
}
- parameters["USER_INPUT"], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
+ parameters[vo.UserInputKey], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
if err != nil {
return nil, err
}
@@ -717,7 +717,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
}
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
- appID: resolveAppID,
+ bizID: bizID,
conversationID: conversationID,
roundID: roundID,
workflowID: workflowID,
@@ -731,7 +731,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Context, info convertToChatFlowInfo) func(msg *entity.Message) (responses []*workflow.ChatFlowRunResponse, err error) {
var (
- appID = info.appID
+ bizID = info.bizID
conversationID = info.conversationID
roundID = info.roundID
workflowID = info.workflowID
@@ -798,7 +798,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
- BotID: strconv.FormatInt(appID, 10),
+ BotID: strconv.FormatInt(bizID, 10),
Role: string(schema.Assistant),
Type: "follow_up",
ContentType: "text",
@@ -815,7 +815,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
- BotID: strconv.FormatInt(appID, 10),
+ BotID: strconv.FormatInt(bizID, 10),
Status: vo.Completed,
ExecuteID: strconv.FormatInt(executeID, 10),
Usage: &vo.Usage{
@@ -929,7 +929,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
}
_, err = crossmessage.DefaultSVC().Create(ctx, &message.Message{
- AgentID: appID,
+ AgentID: bizID,
RunID: roundID,
SectionID: sectionID,
Content: msgContent,
@@ -947,7 +947,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
- BotID: strconv.FormatInt(appID, 10),
+ BotID: strconv.FormatInt(bizID, 10),
Role: string(schema.Assistant),
Type: string(entity.Answer),
ContentType: string(contentType),
@@ -1046,7 +1046,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
}
intermediateMessage = &message.Message{
ID: id,
- AgentID: appID,
+ AgentID: bizID,
RunID: roundID,
SectionID: sectionID,
ConversationID: conversationID,
@@ -1066,7 +1066,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
- BotID: strconv.FormatInt(appID, 10),
+ BotID: strconv.FormatInt(bizID, 10),
Role: string(dataMessage.Role),
Type: string(dataMessage.Type),
ContentType: string(message.ContentTypeText),
@@ -1092,7 +1092,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
- BotID: strconv.FormatInt(appID, 10),
+ BotID: strconv.FormatInt(bizID, 10),
Role: string(dataMessage.Role),
Type: string(dataMessage.Type),
ContentType: string(message.ContentTypeText),
@@ -1155,9 +1155,9 @@ func (w *ApplicationService) makeChatFlowUserInput(ctx context.Context, message
} else {
return "", fmt.Errorf("invalid message ccontent type %v", message.ContentType)
}
-
}
-func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
+
+func makeChatFlowHistoryMessages(ctx context.Context, bizID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
var (
rID int64
@@ -1170,7 +1170,7 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use
for _, msg := range messages {
if msg.Role == userRole {
runRecord, err = crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
- AgentID: appID,
+ AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@@ -1180,13 +1180,15 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use
return nil, err
}
rID = runRecord.ID
- } else if msg.Role == assistantRole && rID == 0 {
- continue
+ } else if msg.Role == assistantRole {
+ if rID == 0 {
+ continue
+ }
} else {
return nil, fmt.Errorf("invalid role type %v", msg.Role)
}
- m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
+ m, err := toConversationMessage(ctx, bizID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
if err != nil {
return nil, err
}
@@ -1274,7 +1276,7 @@ func (w *ApplicationService) OpenAPICreateConversation(ctx context.Context, req
}, nil
}
-func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
+func toConversationMessage(ctx context.Context, bizID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
type content struct {
Type string `json:"type"`
FileID *string `json:"file_id"`
@@ -1284,7 +1286,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec
return &message.Message{
Role: schema.User,
ConversationID: cid,
- AgentID: appID,
+ AgentID: bizID,
RunID: roundID,
Content: msg.Content,
ContentType: message.ContentTypeText,
@@ -1304,7 +1306,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec
Role: schema.User,
MessageType: messageType,
ConversationID: cid,
- AgentID: appID,
+ AgentID: bizID,
UserID: strconv.FormatInt(userID, 10),
RunID: roundID,
ContentType: message.ContentTypeMix,
@@ -1432,7 +1434,7 @@ func toSchemaMessage(ctx context.Context, msg *workflow.EnterMessage) (*schema.M
type convertToChatFlowInfo struct {
userMessage *schema.Message
- appID int64
+ bizID int64
conversationID int64
roundID int64
workflowID int64
diff --git a/backend/application/workflow/chatflow_test.go b/backend/application/workflow/chatflow_test.go
new file mode 100644
index 000000000..1bcd4a5b1
--- /dev/null
+++ b/backend/application/workflow/chatflow_test.go
@@ -0,0 +1,604 @@
+/*
+ * Copyright 2025 coze-dev Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package workflow
+
+import (
+ "context"
+ "errors"
+ "strconv"
+ "testing"
+
+ "github.com/cloudwego/eino/schema"
+ "github.com/stretchr/testify/assert"
+ "go.uber.org/mock/gomock"
+
+ messageentity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
+ "github.com/coze-dev/coze-studio/backend/api/model/workflow"
+ crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
+ "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun/agentrunmock"
+ crossupload "github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload"
+ "github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload/uploadmock"
+ agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
+ uploadentity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
+ "github.com/coze-dev/coze-studio/backend/domain/upload/service"
+)
+
+func TestApplicationService_makeChatFlowUserInput(t *testing.T) {
+ ctx := context.Background()
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockUpload := uploadmock.NewMockUploader(ctrl)
+ crossupload.SetDefaultSVC(mockUpload)
+
+ tests := []struct {
+ name string
+ message *workflow.EnterMessage
+ setupMock func()
+ expected string
+ expectErr bool
+ }{
+ {
+ name: "content type text",
+ message: &workflow.EnterMessage{
+ ContentType: "text",
+ Content: "hello",
+ },
+ setupMock: func() {},
+ expected: "hello",
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with text",
+ message: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "text", "text": "hello world"}]`,
+ },
+ setupMock: func() {},
+ expected: "hello world",
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with file",
+ message: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
+ File: &uploadentity.File{Url: "https://example.com/file"},
+ }, nil)
+ },
+ expected: "https://example.com/file",
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with text and file",
+ message: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "text", "text": "see this file"}, {"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
+ File: &uploadentity.File{Url: "https://example.com/file"},
+ }, nil)
+ },
+ expected: "see this file,https://example.com/file",
+ expectErr: false,
+ },
+ {
+ name: "get file error",
+ message: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
+ },
+ expectErr: true,
+ },
+ {
+ name: "file not found",
+ message: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
+ File: nil,
+ }, nil)
+ },
+ expectErr: true,
+ },
+ {
+ name: "invalid content type",
+ message: &workflow.EnterMessage{
+ ContentType: "invalid",
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ {
+ name: "invalid json",
+ message: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `invalid-json`,
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ }
+
+ w := &ApplicationService{}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.setupMock()
+ result, err := w.makeChatFlowUserInput(ctx, tt.message)
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func Test_toConversationMessage(t *testing.T) {
+ ctx := context.Background()
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockUpload := uploadmock.NewMockUploader(ctrl)
+ crossupload.SetDefaultSVC(mockUpload)
+
+ bizID, cid, userID, roundID, sectionID := int64(2), int64(1), int64(4), int64(3), int64(5)
+
+ tests := []struct {
+ name string
+ msg *workflow.EnterMessage
+ messageType messageentity.MessageType
+ setupMock func()
+ expected *messageentity.Message
+ expectErr bool
+ }{
+ {
+ name: "content type text",
+ msg: &workflow.EnterMessage{
+ ContentType: "text",
+ Content: "hello",
+ },
+ messageType: messageentity.MessageTypeQuestion,
+ setupMock: func() {},
+ expected: &messageentity.Message{
+ Role: schema.User,
+ ConversationID: cid,
+ AgentID: bizID,
+ RunID: roundID,
+ Content: "hello",
+ ContentType: messageentity.ContentTypeText,
+ MessageType: messageentity.MessageTypeQuestion,
+ UserID: strconv.FormatInt(userID, 10),
+ SectionID: sectionID,
+ },
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with text",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "text", "text": "hello"}]`,
+ },
+ messageType: messageentity.MessageTypeQuestion,
+ setupMock: func() {},
+ expected: &messageentity.Message{
+ Role: schema.User,
+ MessageType: messageentity.MessageTypeQuestion,
+ ConversationID: cid,
+ AgentID: bizID,
+ UserID: strconv.FormatInt(userID, 10),
+ RunID: roundID,
+ ContentType: messageentity.ContentTypeMix,
+ MultiContent: []*messageentity.InputMetaData{
+ {Type: messageentity.InputTypeText, Text: "hello"},
+ },
+ SectionID: sectionID,
+ },
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with file",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ messageType: messageentity.MessageTypeQuestion,
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
+ File: &uploadentity.File{Url: "https://example.com/file", TosURI: "tos://uri", Name: "file.txt"},
+ }, nil)
+ },
+ expected: &messageentity.Message{
+ Role: schema.User,
+ MessageType: messageentity.MessageTypeQuestion,
+ ConversationID: cid,
+ AgentID: bizID,
+ UserID: strconv.FormatInt(userID, 10),
+ RunID: roundID,
+ ContentType: messageentity.ContentTypeMix,
+ MultiContent: []*messageentity.InputMetaData{
+ {
+ Type: "file",
+ FileData: []*messageentity.FileData{
+ {Url: "https://example.com/file", URI: "tos://uri", Name: "file.txt"},
+ },
+ },
+ },
+ SectionID: sectionID,
+ },
+ expectErr: false,
+ },
+ {
+ name: "get file error",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
+ },
+ expectErr: true,
+ },
+ {
+ name: "file not found",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
+ },
+ expectErr: true,
+ },
+ {
+ name: "invalid content type",
+ msg: &workflow.EnterMessage{
+ ContentType: "invalid",
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ {
+ name: "invalid json",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: "invalid-json",
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ {
+ name: "invalid input type",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "invalid"}]`,
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.setupMock()
+ result, err := toConversationMessage(ctx, bizID, cid, userID, roundID, sectionID, tt.messageType, tt.msg)
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func Test_toSchemaMessage(t *testing.T) {
+ ctx := context.Background()
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockUpload := uploadmock.NewMockUploader(ctrl)
+ crossupload.SetDefaultSVC(mockUpload)
+
+ tests := []struct {
+ name string
+ msg *workflow.EnterMessage
+ setupMock func()
+ expected *schema.Message
+ expectErr bool
+ }{
+ {
+ name: "content type text",
+ msg: &workflow.EnterMessage{
+ ContentType: "text",
+ Content: "hello",
+ },
+ setupMock: func() {},
+ expected: &schema.Message{
+ Role: schema.User,
+ Content: "hello",
+ },
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with text",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "text", "text": "hello"}]`,
+ },
+ setupMock: func() {},
+ expected: &schema.Message{
+ Role: schema.User,
+ MultiContent: []schema.ChatMessagePart{
+ {Type: schema.ChatMessagePartTypeText, Text: "hello"},
+ },
+ },
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with image",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "image", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
+ File: &uploadentity.File{Url: "https://example.com/image.png"},
+ }, nil)
+ },
+ expected: &schema.Message{
+ Role: schema.User,
+ MultiContent: []schema.ChatMessagePart{
+ {
+ Type: schema.ChatMessagePartTypeImageURL,
+ ImageURL: &schema.ChatMessageImageURL{URL: "https://example.com/image.png"},
+ },
+ },
+ },
+ expectErr: false,
+ },
+ {
+ name: "content type object_string with various file types",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "1"}, {"type": "audio", "file_id": "2"}, {"type": "video", "file_id": "3"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 1}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/file"}}, nil)
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 2}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/audio"}}, nil)
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 3}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/video"}}, nil)
+ },
+ expected: &schema.Message{
+ Role: schema.User,
+ MultiContent: []schema.ChatMessagePart{
+ {Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{URL: "https://example.com/file"}},
+ {Type: schema.ChatMessagePartTypeAudioURL, AudioURL: &schema.ChatMessageAudioURL{URL: "https://example.com/audio"}},
+ {Type: schema.ChatMessagePartTypeVideoURL, VideoURL: &schema.ChatMessageVideoURL{URL: "https://example.com/video"}},
+ },
+ },
+ expectErr: false,
+ },
+ {
+ name: "get file error",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
+ },
+ expectErr: true,
+ },
+ {
+ name: "file not found",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "file", "file_id": "123"}]`,
+ },
+ setupMock: func() {
+ mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
+ },
+ expectErr: true,
+ },
+ {
+ name: "invalid content type",
+ msg: &workflow.EnterMessage{
+ ContentType: "invalid",
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ {
+ name: "invalid json",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: "invalid-json",
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ {
+ name: "invalid input type",
+ msg: &workflow.EnterMessage{
+ ContentType: "object_string",
+ Content: `[{"type": "invalid"}]`,
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.setupMock()
+ result, err := toSchemaMessage(ctx, tt.msg)
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func Test_makeChatFlowHistoryMessages(t *testing.T) {
+ ctx := context.Background()
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockAgentRun := agentrunmock.NewMockAgentRun(ctrl)
+ crossagentrun.SetDefaultSVC(mockAgentRun)
+ mockUpload := uploadmock.NewMockUploader(ctrl)
+ crossupload.SetDefaultSVC(mockUpload)
+
+ bizID, conversationID, userID, sectionID, connectorID := int64(2), int64(1), int64(3), int64(4), int64(5)
+
+ tests := []struct {
+ name string
+ messages []*workflow.EnterMessage
+ setupMock func()
+ expected []*messageentity.Message
+ expectErr bool
+ }{
+ {
+ name: "empty messages",
+ messages: []*workflow.EnterMessage{},
+ setupMock: func() {},
+ expected: []*messageentity.Message{},
+ expectErr: false,
+ },
+ {
+ name: "one user message",
+ messages: []*workflow.EnterMessage{
+ {Role: "user", ContentType: "text", Content: "hello"},
+ },
+ setupMock: func() {
+ mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
+ },
+ expected: []*messageentity.Message{
+ {
+ Role: schema.User,
+ ConversationID: conversationID,
+ AgentID: bizID,
+ RunID: 100,
+ Content: "hello",
+ ContentType: messageentity.ContentTypeText,
+ MessageType: messageentity.MessageTypeQuestion,
+ UserID: strconv.FormatInt(userID, 10),
+ SectionID: sectionID,
+ },
+ },
+ expectErr: false,
+ },
+ {
+ name: "user and assistant message",
+ messages: []*workflow.EnterMessage{
+ {Role: "user", ContentType: "text", Content: "hello"},
+ {Role: "assistant", ContentType: "text", Content: "hi"},
+ },
+ setupMock: func() {
+ mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
+ },
+ expected: []*messageentity.Message{
+ {
+ Role: schema.User,
+ ConversationID: conversationID,
+ AgentID: bizID,
+ RunID: 100,
+ Content: "hello",
+ ContentType: messageentity.ContentTypeText,
+ MessageType: messageentity.MessageTypeQuestion,
+ UserID: strconv.FormatInt(userID, 10),
+ SectionID: sectionID,
+ },
+ {
+ Role: schema.User,
+ ConversationID: conversationID,
+ AgentID: bizID,
+ RunID: 100,
+ Content: "hi",
+ ContentType: messageentity.ContentTypeText,
+ MessageType: messageentity.MessageTypeAnswer,
+ UserID: strconv.FormatInt(userID, 10),
+ SectionID: sectionID,
+ },
+ },
+ expectErr: false,
+ },
+ {
+ name: "only assistant message",
+ messages: []*workflow.EnterMessage{
+ {Role: "assistant", ContentType: "text", Content: "hi"},
+ },
+ setupMock: func() {},
+ expected: []*messageentity.Message{},
+ expectErr: false,
+ },
+ {
+ name: "create run record error",
+ messages: []*workflow.EnterMessage{
+ {Role: "user", ContentType: "text", Content: "hello"},
+ },
+ setupMock: func() {
+ mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
+ },
+ expectErr: true,
+ },
+ {
+ name: "invalid role",
+ messages: []*workflow.EnterMessage{
+ {Role: "system", ContentType: "text", Content: "hello"},
+ },
+ setupMock: func() {},
+ expectErr: true,
+ },
+ {
+ name: "toConversationMessage error",
+ messages: []*workflow.EnterMessage{
+ {Role: "user", ContentType: "invalid", Content: "hello"},
+ },
+ setupMock: func() {
+ mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil)
+ },
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.setupMock()
+ result, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, tt.messages)
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
diff --git a/backend/crossdomain/contract/message/message.go b/backend/crossdomain/contract/message/message.go
index 0a3d564ec..20ad431ce 100644
--- a/backend/crossdomain/contract/message/message.go
+++ b/backend/crossdomain/contract/message/message.go
@@ -58,7 +58,7 @@ type MessageListRequest struct {
BeforeID *string
AfterID *string
UserID int64
- AppID int64
+ BizID int64
OrderBy *string
}
@@ -88,7 +88,7 @@ type WfMessage struct {
type GetLatestRunIDsRequest struct {
ConversationID int64
UserID int64
- AppID int64
+ BizID int64
Rounds int64
SectionID int64
InitRunID *int64
diff --git a/backend/crossdomain/contract/upload/upload.go b/backend/crossdomain/contract/upload/upload.go
index 1c29e1e50..2be76435d 100644
--- a/backend/crossdomain/contract/upload/upload.go
+++ b/backend/crossdomain/contract/upload/upload.go
@@ -24,6 +24,7 @@ import (
var defaultSVC Uploader
+//go:generate mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
type Uploader interface {
GetFile(ctx context.Context, req *service.GetFileRequest) (resp *service.GetFileResponse, err error)
}
diff --git a/backend/crossdomain/contract/upload/uploadmock/upload_mock.go b/backend/crossdomain/contract/upload/uploadmock/upload_mock.go
new file mode 100644
index 000000000..f094a580c
--- /dev/null
+++ b/backend/crossdomain/contract/upload/uploadmock/upload_mock.go
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2025 coze-dev Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Code generated by MockGen. DO NOT EDIT.
+// Source: upload.go
+//
+// Generated by this command:
+//
+// mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
+//
+
+// Package uploadmock is a generated GoMock package.
+package uploadmock
+
+import (
+ context "context"
+ reflect "reflect"
+
+ service "github.com/coze-dev/coze-studio/backend/domain/upload/service"
+ gomock "go.uber.org/mock/gomock"
+)
+
+// MockUploader is a mock of Uploader interface.
+type MockUploader struct {
+ ctrl *gomock.Controller
+ recorder *MockUploaderMockRecorder
+ isgomock struct{}
+}
+
+// MockUploaderMockRecorder is the mock recorder for MockUploader.
+type MockUploaderMockRecorder struct {
+ mock *MockUploader
+}
+
+// NewMockUploader creates a new mock instance.
+func NewMockUploader(ctrl *gomock.Controller) *MockUploader {
+ mock := &MockUploader{ctrl: ctrl}
+ mock.recorder = &MockUploaderMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockUploader) EXPECT() *MockUploaderMockRecorder {
+ return m.recorder
+}
+
+// GetFile mocks base method.
+func (m *MockUploader) GetFile(ctx context.Context, req *service.GetFileRequest) (*service.GetFileResponse, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetFile", ctx, req)
+ ret0, _ := ret[0].(*service.GetFileResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetFile indicates an expected call of GetFile.
+func (mr *MockUploaderMockRecorder) GetFile(ctx, req any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFile", reflect.TypeOf((*MockUploader)(nil).GetFile), ctx, req)
+}
diff --git a/backend/crossdomain/impl/message/message.go b/backend/crossdomain/impl/message/message.go
index 9fd7fdf9a..2bd0b217d 100644
--- a/backend/crossdomain/impl/message/message.go
+++ b/backend/crossdomain/impl/message/message.go
@@ -54,7 +54,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
ConversationID: req.ConversationID,
Limit: int(req.Limit), // Since the value of limit is checked inside the node, the type cast here is safe
UserID: strconv.FormatInt(req.UserID, 10),
- AgentID: req.AppID,
+ AgentID: req.BizID,
OrderBy: req.OrderBy,
}
if req.BeforeID != nil {
@@ -96,7 +96,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
func (c *impl) GetLatestRunIDs(ctx context.Context, req *crossmessage.GetLatestRunIDsRequest) ([]int64, error) {
listMeta := &agententity.ListRunRecordMeta{
ConversationID: req.ConversationID,
- AgentID: req.AppID,
+ AgentID: req.BizID,
Limit: int32(req.Rounds),
SectionID: req.SectionID,
}
diff --git a/backend/domain/workflow/component_interface.go b/backend/domain/workflow/component_interface.go
index 9ebc924f4..9e40bc3ca 100644
--- a/backend/domain/workflow/component_interface.go
+++ b/backend/domain/workflow/component_interface.go
@@ -75,11 +75,11 @@ type Conversation interface {
ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error)
ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error
InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error
- GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error)
+ GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error)
UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error)
GetTemplateByName(ctx context.Context, env vo.Env, appID int64, templateName string) (*entity.ConversationTemplate, bool, error)
GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error)
- GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
+ GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
}
type InterruptEventStore interface {
@@ -143,8 +143,8 @@ type ConversationRepository interface {
UpdateStaticConversation(ctx context.Context, env vo.Env, templateID int64, connectorID int64, userID int64, newConversationID int64) error
UpdateDynamicConversation(ctx context.Context, env vo.Env, conversationID, newConversationID int64) error
CopyTemplateConversationByAppID(ctx context.Context, appID int64, toAppID int64) error
- GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
- GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
+ GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
+ GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
}
type WorkflowConfig interface {
GetNodeOfCodeConfig() *config.NodeOfCodeConfig
diff --git a/backend/domain/workflow/entity/vo/chatflow.go b/backend/domain/workflow/entity/vo/chatflow.go
index 133b7a6f0..4bb1c5e9e 100644
--- a/backend/domain/workflow/entity/vo/chatflow.go
+++ b/backend/domain/workflow/entity/vo/chatflow.go
@@ -32,6 +32,11 @@ const (
ChatFlowMessageCompleted ChatFlowEvent = "conversation.message.completed"
)
+const (
+ ConversationNameKey = "CONVERSATION_NAME"
+ UserInputKey = "USER_INPUT"
+)
+
type Usage struct {
TokenCount *int32 `form:"token_count" json:"token_count,omitempty"`
OutputTokens *int32 `form:"output_count" json:"output_count,omitempty"`
diff --git a/backend/domain/workflow/entity/vo/conversation.go b/backend/domain/workflow/entity/vo/conversation.go
index 1ed5c739a..dbd238ad2 100644
--- a/backend/domain/workflow/entity/vo/conversation.go
+++ b/backend/domain/workflow/entity/vo/conversation.go
@@ -59,14 +59,14 @@ type ListConversationPolicy struct {
}
type CreateStaticConversation struct {
- AppID int64
+ BizID int64
UserID int64
ConnectorID int64
TemplateID int64
}
type CreateDynamicConversation struct {
- AppID int64
+ BizID int64
UserID int64
ConnectorID int64
diff --git a/backend/domain/workflow/internal/canvas/adaptor/from_node.go b/backend/domain/workflow/internal/canvas/adaptor/from_node.go
index f93b45ede..3fdfe8e8c 100644
--- a/backend/domain/workflow/internal/canvas/adaptor/from_node.go
+++ b/backend/domain/workflow/internal/canvas/adaptor/from_node.go
@@ -235,6 +235,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
if enabled {
trimmedSC.GeneratedNodes = append(trimmedSC.GeneratedNodes, ns.Key)
}
+ trimmedSC.Init()
return trimmedSC, nil
}
diff --git a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go
index 63aa47902..d8d875610 100644
--- a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go
+++ b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go
@@ -446,7 +446,7 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node)
func parseBatchMode(n *vo.Node) (
batchN *vo.Node, // the new batch node
- enabled bool, // whether the node has enabled batch mode
+ enabled bool, // whether the node has enabled batch mode
err error) {
if n.Data == nil || n.Data.Inputs == nil {
return nil, false, nil
diff --git a/backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json b/backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json
new file mode 100644
index 000000000..dcf277ac6
--- /dev/null
+++ b/backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json
@@ -0,0 +1,275 @@
+{
+ "nodes": [
+ {
+ "id": "100001",
+ "type": "1",
+ "meta": {
+ "position": {
+ "x": 180,
+ "y": 79.2
+ }
+ },
+ "data": {
+ "outputs": [
+ {
+ "type": "string",
+ "name": "USER_INPUT",
+ "required": false
+ },
+ {
+ "type": "string",
+ "name": "CONVERSATION_NAME",
+ "required": false,
+ "description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
+ "defaultValue": "dhl"
+ }
+ ],
+ "nodeMeta": {
+ "title": "开始",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
+ "description": "工作流的起始节点,用于设定启动工作流需要的信息",
+ "subTitle": ""
+ },
+ "trigger_parameters": []
+ }
+ },
+ {
+ "id": "900001",
+ "type": "2",
+ "meta": {
+ "position": {
+ "x": 2020,
+ "y": 66.2
+ }
+ },
+ "data": {
+ "nodeMeta": {
+ "title": "结束",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
+ "description": "工作流的最终节点,用于返回工作流运行后的结果信息",
+ "subTitle": ""
+ },
+ "inputs": {
+ "terminatePlan": "useAnswerContent",
+ "streamingOutput": true,
+ "inputParameters": [
+ {
+ "name": "output",
+ "input": {
+ "type": "string",
+ "value": {
+ "type": "ref",
+ "content": {
+ "source": "block-output",
+ "blockID": "142077",
+ "name": "optionContent"
+ },
+ "rawMeta": {
+ "type": 1
+ }
+ }
+ }
+ }
+ ],
+ "content": {
+ "type": "string",
+ "value": {
+ "type": "literal",
+ "content": "{{output}}"
+ }
+ }
+ }
+ }
+ },
+ {
+ "id": "190196",
+ "type": "30",
+ "meta": {
+ "position": {
+ "x": 640,
+ "y": 78.5
+ }
+ },
+ "data": {
+ "outputs": [
+ {
+ "type": "string",
+ "name": "input",
+ "required": true
+ }
+ ],
+ "nodeMeta": {
+ "title": "输入",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
+ "description": "支持中间过程的信息输入",
+ "mainColor": "#5C62FF",
+ "subTitle": "输入"
+ },
+ "inputs": {
+ "outputSchema": "[{\"type\":\"string\",\"name\":\"input\",\"required\":true}]"
+ }
+ }
+ },
+ {
+ "id": "133775",
+ "type": "18",
+ "meta": {
+ "position": {
+ "x": 1100,
+ "y": 39
+ }
+ },
+ "data": {
+ "inputs": {
+ "llmParam": {
+ "modelType": 1001,
+ "modelName": "Doubao-Seed-1.6",
+ "generationDiversity": "balance",
+ "temperature": 0.8,
+ "maxTokens": 4096,
+ "topP": 0.7,
+ "responseFormat": 2,
+ "systemPrompt": ""
+ },
+ "inputParameters": [],
+ "extra_output": false,
+ "answer_type": "text",
+ "option_type": "static",
+ "dynamic_option": {
+ "type": "string",
+ "value": {
+ "type": "ref",
+ "content": {
+ "source": "block-output",
+ "blockID": "",
+ "name": ""
+ }
+ }
+ },
+ "question": "你好",
+ "options": [
+ {
+ "name": ""
+ },
+ {
+ "name": ""
+ }
+ ],
+ "limit": 3
+ },
+ "nodeMeta": {
+ "title": "问答",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
+ "description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
+ "mainColor": "#3071F2",
+ "subTitle": "问答"
+ },
+ "outputs": [
+ {
+ "type": "string",
+ "name": "USER_RESPONSE",
+ "required": true,
+ "description": "用户本轮对话输入内容"
+ }
+ ]
+ }
+ },
+ {
+ "id": "142077",
+ "type": "18",
+ "meta": {
+ "position": {
+ "x": 1560,
+ "y": 0
+ }
+ },
+ "data": {
+ "inputs": {
+ "llmParam": {
+ "modelType": 1001,
+ "modelName": "Doubao-Seed-1.6",
+ "generationDiversity": "balance",
+ "temperature": 0.8,
+ "maxTokens": 4096,
+ "topP": 0.7,
+ "responseFormat": 2,
+ "systemPrompt": ""
+ },
+ "inputParameters": [],
+ "extra_output": false,
+ "answer_type": "option",
+ "option_type": "static",
+ "dynamic_option": {
+ "type": "string",
+ "value": {
+ "type": "ref",
+ "content": {
+ "source": "block-output",
+ "blockID": "",
+ "name": ""
+ }
+ }
+ },
+ "question": "请选择",
+ "options": [
+ {
+ "name": "A"
+ },
+ {
+ "name": "B"
+ }
+ ],
+ "limit": 3
+ },
+ "nodeMeta": {
+ "title": "问答_1",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
+ "description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
+ "mainColor": "#3071F2",
+ "subTitle": "问答"
+ },
+ "outputs": [
+ {
+ "type": "string",
+ "name": "optionId",
+ "required": false
+ },
+ {
+ "type": "string",
+ "name": "optionContent",
+ "required": false
+ }
+ ]
+ }
+ }
+ ],
+ "edges": [
+ {
+ "sourceNodeID": "100001",
+ "targetNodeID": "190196"
+ },
+ {
+ "sourceNodeID": "142077",
+ "targetNodeID": "900001",
+ "sourcePortID": "branch_0"
+ },
+ {
+ "sourceNodeID": "142077",
+ "targetNodeID": "900001",
+ "sourcePortID": "branch_1"
+ },
+ {
+ "sourceNodeID": "142077",
+ "targetNodeID": "900001",
+ "sourcePortID": "default"
+ },
+ {
+ "sourceNodeID": "190196",
+ "targetNodeID": "133775"
+ },
+ {
+ "sourceNodeID": "133775",
+ "targetNodeID": "142077"
+ }
+ ]
+}
diff --git a/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json
new file mode 100644
index 000000000..7b28b91be
--- /dev/null
+++ b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json
@@ -0,0 +1,397 @@
+{
+ "nodes": [
+ {
+ "blocks": [],
+ "data": {
+ "nodeMeta": {
+ "description": "工作流的起始节点,用于设定启动工作流需要的信息",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
+ "subTitle": "",
+ "title": "开始"
+ },
+ "outputs": [
+ {
+ "name": "USER_INPUT",
+ "required": false,
+ "type": "string"
+ },
+ {
+ "defaultValue": "Default",
+ "description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
+ "name": "CONVERSATION_NAME",
+ "required": false,
+ "type": "string"
+ }
+ ],
+ "trigger_parameters": []
+ },
+ "edges": null,
+ "id": "100001",
+ "meta": {
+ "position": {
+ "x": 0,
+ "y": 0
+ }
+ },
+ "type": "1"
+ },
+ {
+ "blocks": [],
+ "data": {
+ "inputs": {
+ "content": {
+ "type": "string",
+ "value": {
+ "content": "{{output}}",
+ "type": "literal"
+ }
+ },
+ "inputParameters": [
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": {
+ "blockID": "123887",
+ "name": "output",
+ "source": "block-output"
+ },
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "ref"
+ }
+ },
+ "name": "output"
+ }
+ ],
+ "streamingOutput": true,
+ "terminatePlan": "useAnswerContent"
+ },
+ "nodeMeta": {
+ "description": "工作流的最终节点,用于返回工作流运行后的结果信息",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
+ "subTitle": "",
+ "title": "结束"
+ }
+ },
+ "edges": null,
+ "id": "900001",
+ "meta": {
+ "position": {
+ "x": 926,
+ "y": -13
+ }
+ },
+ "type": "2"
+ },
+ {
+ "blocks": [],
+ "data": {
+ "inputs": {
+ "fcParamVar": {
+ "knowledgeFCParam": {}
+ },
+ "inputParameters": [
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": {
+ "blockID": "100001",
+ "name": "USER_INPUT",
+ "source": "block-output"
+ },
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "ref"
+ }
+ },
+ "name": "input"
+ }
+ ],
+ "llmParam": [
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "1737521813",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "modelType"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "豆包·1.5·Pro·32k",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "modleName"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "balance",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "generationDiversity"
+ },
+ {
+ "input": {
+ "type": "float",
+ "value": {
+ "content": "0.8",
+ "rawMeta": {
+ "type": 4
+ },
+ "type": "literal"
+ }
+ },
+ "name": "temperature"
+ },
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "4096",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "maxTokens"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "spCurrentTime"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "spAntiLeak"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "prefixCache"
+ },
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "2",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "responseFormat"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "{{input}}",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "prompt"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "enableChatHistory"
+ },
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "3",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "chatHistoryRound"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "systemPrompt"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "stableSystemPrompt"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "canContinue"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "loopPromptVersion"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "loopPromptName"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "loopPromptId"
+ }
+ ],
+ "settingOnError": {
+ "processType": 1,
+ "retryTimes": 0,
+ "timeoutMs": 180000
+ }
+ },
+ "nodeMeta": {
+ "description": "调用大语言模型,使用变量和提示词生成回复",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
+ "mainColor": "#5C62FF",
+ "subTitle": "大模型",
+ "title": "大模型"
+ },
+ "outputs": [
+ {
+ "name": "output",
+ "type": "string"
+ }
+ ],
+ "version": "3"
+ },
+ "edges": null,
+ "id": "123887",
+ "meta": {
+ "position": {
+ "x": 463,
+ "y": -39
+ }
+ },
+ "type": "3"
+ }
+ ],
+ "edges": [
+ {
+ "sourceNodeID": "100001",
+ "targetNodeID": "123887",
+ "sourcePortID": ""
+ },
+ {
+ "sourceNodeID": "123887",
+ "targetNodeID": "900001",
+ "sourcePortID": ""
+ }
+ ],
+ "versions": {
+ "loop": "v2"
+ }
+}
diff --git a/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json
new file mode 100644
index 000000000..c88945037
--- /dev/null
+++ b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json
@@ -0,0 +1,397 @@
+{
+ "nodes": [
+ {
+ "blocks": [],
+ "data": {
+ "nodeMeta": {
+ "description": "工作流的起始节点,用于设定启动工作流需要的信息",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
+ "subTitle": "",
+ "title": "开始"
+ },
+ "outputs": [
+ {
+ "name": "USER_INPUT",
+ "required": false,
+ "type": "string"
+ },
+ {
+ "defaultValue": "Default",
+ "description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
+ "name": "CONVERSATION_NAME",
+ "required": false,
+ "type": "string"
+ }
+ ],
+ "trigger_parameters": []
+ },
+ "edges": null,
+ "id": "100001",
+ "meta": {
+ "position": {
+ "x": 0,
+ "y": 0
+ }
+ },
+ "type": "1"
+ },
+ {
+ "blocks": [],
+ "data": {
+ "inputs": {
+ "content": {
+ "type": "string",
+ "value": {
+ "content": "{{output}}",
+ "type": "literal"
+ }
+ },
+ "inputParameters": [
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": {
+ "blockID": "123887",
+ "name": "output",
+ "source": "block-output"
+ },
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "ref"
+ }
+ },
+ "name": "output"
+ }
+ ],
+ "streamingOutput": true,
+ "terminatePlan": "useAnswerContent"
+ },
+ "nodeMeta": {
+ "description": "工作流的最终节点,用于返回工作流运行后的结果信息",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
+ "subTitle": "",
+ "title": "结束"
+ }
+ },
+ "edges": null,
+ "id": "900001",
+ "meta": {
+ "position": {
+ "x": 926,
+ "y": -13
+ }
+ },
+ "type": "2"
+ },
+ {
+ "blocks": [],
+ "data": {
+ "inputs": {
+ "fcParamVar": {
+ "knowledgeFCParam": {}
+ },
+ "inputParameters": [
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": {
+ "blockID": "100001",
+ "name": "USER_INPUT",
+ "source": "block-output"
+ },
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "ref"
+ }
+ },
+ "name": "input"
+ }
+ ],
+ "llmParam": [
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "1737521813",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "modelType"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "豆包·1.5·Pro·32k",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "modleName"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "balance",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "generationDiversity"
+ },
+ {
+ "input": {
+ "type": "float",
+ "value": {
+ "content": "0.8",
+ "rawMeta": {
+ "type": 4
+ },
+ "type": "literal"
+ }
+ },
+ "name": "temperature"
+ },
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "4096",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "maxTokens"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "spCurrentTime"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "spAntiLeak"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "prefixCache"
+ },
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "2",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "responseFormat"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "{{input}}",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "prompt"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": true,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "enableChatHistory"
+ },
+ {
+ "input": {
+ "type": "integer",
+ "value": {
+ "content": "3",
+ "rawMeta": {
+ "type": 2
+ },
+ "type": "literal"
+ }
+ },
+ "name": "chatHistoryRound"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "systemPrompt"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "stableSystemPrompt"
+ },
+ {
+ "input": {
+ "type": "boolean",
+ "value": {
+ "content": false,
+ "rawMeta": {
+ "type": 3
+ },
+ "type": "literal"
+ }
+ },
+ "name": "canContinue"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "loopPromptVersion"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "loopPromptName"
+ },
+ {
+ "input": {
+ "type": "string",
+ "value": {
+ "content": "",
+ "rawMeta": {
+ "type": 1
+ },
+ "type": "literal"
+ }
+ },
+ "name": "loopPromptId"
+ }
+ ],
+ "settingOnError": {
+ "processType": 1,
+ "retryTimes": 0,
+ "timeoutMs": 180000
+ }
+ },
+ "nodeMeta": {
+ "description": "调用大语言模型,使用变量和提示词生成回复",
+ "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
+ "mainColor": "#5C62FF",
+ "subTitle": "大模型",
+ "title": "大模型"
+ },
+ "outputs": [
+ {
+ "name": "output",
+ "type": "string"
+ }
+ ],
+ "version": "3"
+ },
+ "edges": null,
+ "id": "123887",
+ "meta": {
+ "position": {
+ "x": 463,
+ "y": -39
+ }
+ },
+ "type": "3"
+ }
+ ],
+ "edges": [
+ {
+ "sourceNodeID": "100001",
+ "targetNodeID": "123887",
+ "sourcePortID": ""
+ },
+ {
+ "sourceNodeID": "123887",
+ "targetNodeID": "900001",
+ "sourcePortID": ""
+ }
+ ],
+ "versions": {
+ "loop": "v2"
+ }
+}
diff --git a/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go b/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go
index 35ef1b2a5..91d3d10ad 100644
--- a/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go
+++ b/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go
@@ -148,7 +148,7 @@ func (ch *ConversationHistory) Invoke(ctx context.Context, input map[string]any)
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
ConversationID: conversationID,
UserID: userID,
- AppID: *appID,
+ BizID: *appID,
Rounds: rounds,
InitRunID: initRunID,
SectionID: sectionID,
diff --git a/backend/domain/workflow/internal/nodes/conversation/createconversation.go b/backend/domain/workflow/internal/nodes/conversation/createconversation.go
index 671dcb498..2b3c71295 100644
--- a/backend/domain/workflow/internal/nodes/conversation/createconversation.go
+++ b/backend/domain/workflow/internal/nodes/conversation/createconversation.go
@@ -109,7 +109,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
if existed {
cID, _, existed, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
- AppID: ptr.From(appID),
+ BizID: ptr.From(appID),
TemplateID: template.TemplateID,
UserID: userID,
ConnectorID: connectorID,
@@ -125,7 +125,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
}
cID, _, existed, err := workflow.GetRepository().GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
- AppID: ptr.From(appID),
+ BizID: ptr.From(appID),
UserID: userID,
ConnectorID: connectorID,
Name: conversationName,
diff --git a/backend/domain/workflow/internal/nodes/conversation/createmessage.go b/backend/domain/workflow/internal/nodes/conversation/createmessage.go
index 16b038e48..0579a7406 100644
--- a/backend/domain/workflow/internal/nodes/conversation/createmessage.go
+++ b/backend/domain/workflow/internal/nodes/conversation/createmessage.go
@@ -98,7 +98,7 @@ func (c *CreateMessage) getConversationIDByName(ctx context.Context, env vo.Env,
var conversationID int64
if isExist {
cID, _, _, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
- AppID: ptr.From(appID),
+ BizID: ptr.From(appID),
TemplateID: template.TemplateID,
UserID: userID,
ConnectorID: connectorID,
@@ -150,7 +150,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
var conversationID int64
var err error
- var resolvedAppID int64
+ var bizID int64
if appID == nil {
if conversationName != "Default" {
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
@@ -167,13 +167,13 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
}, nil
}
conversationID = *execCtx.ExeCfg.ConversationID
- resolvedAppID = *agentID
+ bizID = *agentID
} else {
conversationID, err = c.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
if err != nil {
return nil, err
}
- resolvedAppID = *appID
+ bizID = *appID
}
if conversationID == 0 {
@@ -209,7 +209,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
if role == "user" {
// For user messages, always create a new run and store the ID in the context.
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
- AgentID: resolvedAppID,
+ AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@@ -244,7 +244,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
ConversationID: conversationID,
UserID: userID,
- AppID: resolvedAppID,
+ BizID: bizID,
Rounds: 1,
})
if err != nil {
@@ -254,7 +254,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
runID = runIDs[0]
} else {
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
- AgentID: resolvedAppID,
+ AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@@ -273,7 +273,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
Content: content,
ContentType: model.ContentType("text"),
UserID: strconv.FormatInt(userID, 10),
- AgentID: resolvedAppID,
+ AgentID: bizID,
RunID: runID,
SectionID: sectionID,
}
diff --git a/backend/domain/workflow/internal/nodes/conversation/messagelist.go b/backend/domain/workflow/internal/nodes/conversation/messagelist.go
index be50af48b..db8611fcf 100644
--- a/backend/domain/workflow/internal/nodes/conversation/messagelist.go
+++ b/backend/domain/workflow/internal/nodes/conversation/messagelist.go
@@ -115,7 +115,7 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
var conversationID int64
var err error
- var resolvedAppID int64
+ var bizID int64
if appID == nil {
if conversationName != "Default" {
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
@@ -129,18 +129,18 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
}, nil
}
conversationID = *execCtx.ExeCfg.ConversationID
- resolvedAppID = *agentID
+ bizID = *agentID
} else {
conversationID, err = m.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
if err != nil {
return nil, err
}
- resolvedAppID = *appID
+ bizID = *appID
}
req := &crossmessage.MessageListRequest{
UserID: userID,
- AppID: resolvedAppID,
+ BizID: bizID,
ConversationID: conversationID,
}
diff --git a/backend/domain/workflow/internal/nodes/llm/llm.go b/backend/domain/workflow/internal/nodes/llm/llm.go
index b1652eab4..b45f29a6c 100644
--- a/backend/domain/workflow/internal/nodes/llm/llm.go
+++ b/backend/domain/workflow/internal/nodes/llm/llm.go
@@ -100,9 +100,9 @@ const (
ReasoningOutputKey = "reasoning_content"
)
-const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
- 1.如果引用的内容里面包含
的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
- 2.如果引用的内容不包含
的标签, 你回答问题时不需要展示图片 。
+const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
+ 1.如果引用的内容里面包含
的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
+ 2.如果引用的内容不包含
的标签, 你回答问题时不需要展示图片 。
例如:
如果内容为
一只小猫,你的输出应为:。
如果内容为
一只小猫 和
一只小狗 和
一只小牛,你的输出应为: 和  和 
@@ -290,7 +290,7 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*
c.AssociateStartNodeUserInputFields = make(map[string]struct{})
for _, info := range ns.InputSources {
if len(info.Path) == 1 && info.Source.Ref != nil && info.Source.Ref.FromNodeKey == entity.EntryNodeKey {
- if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField("USER_INPUT")) {
+ if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField(vo.UserInputKey)) {
c.AssociateStartNodeUserInputFields[info.Path[0]] = struct{}{}
}
}
diff --git a/backend/domain/workflow/internal/nodes/node.go b/backend/domain/workflow/internal/nodes/node.go
index 634da62f5..575123fe1 100644
--- a/backend/domain/workflow/internal/nodes/node.go
+++ b/backend/domain/workflow/internal/nodes/node.go
@@ -192,8 +192,3 @@ type StreamGenerator interface {
FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
}
-
-type ChatHistoryAware interface {
- ChatHistoryEnabled() bool
- ChatHistoryRounds() int64
-}
diff --git a/backend/domain/workflow/internal/repo/conversation_repository.go b/backend/domain/workflow/internal/repo/conversation_repository.go
index a3ae07bea..85d75d947 100644
--- a/backend/domain/workflow/internal/repo/conversation_repository.go
+++ b/backend/domain/workflow/internal/repo/conversation_repository.go
@@ -432,7 +432,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
- appDynamicConversationDraft.AppID.Eq(meta.AppID),
+ appDynamicConversationDraft.AppID.Eq(meta.BizID),
appDynamicConversationDraft.ConnectorID.Eq(meta.ConnectorID),
appDynamicConversationDraft.UserID.Eq(meta.UserID),
appDynamicConversationDraft.Name.Eq(meta.Name),
@@ -452,7 +452,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
}
- conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
+ conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@@ -464,7 +464,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
err = r.query.AppDynamicConversationDraft.WithContext(ctx).Create(&model.AppDynamicConversationDraft{
ID: id,
- AppID: meta.AppID,
+ AppID: meta.BizID,
Name: meta.Name,
UserID: meta.UserID,
ConnectorID: meta.ConnectorID,
@@ -479,7 +479,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
} else if env == vo.Online {
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
- appDynamicConversationOnline.AppID.Eq(meta.AppID),
+ appDynamicConversationOnline.AppID.Eq(meta.BizID),
appDynamicConversationOnline.ConnectorID.Eq(meta.ConnectorID),
appDynamicConversationOnline.UserID.Eq(meta.UserID),
appDynamicConversationOnline.Name.Eq(meta.Name),
@@ -498,7 +498,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
}
- conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
+ conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@@ -509,7 +509,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
err = r.query.AppDynamicConversationOnline.WithContext(ctx).Create(&model.AppDynamicConversationOnline{
ID: id,
- AppID: meta.AppID,
+ AppID: meta.BizID,
Name: meta.Name,
UserID: meta.UserID,
ConnectorID: meta.ConnectorID,
@@ -586,7 +586,7 @@ func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context,
return cs[0].ConversationID, cInfo.SectionID, true, nil
}
- conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
+ conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@@ -627,7 +627,7 @@ func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context
return cs[0].ConversationID, cInfo.SectionID, true, nil
}
- conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
+ conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@@ -841,7 +841,7 @@ func (r *RepositoryImpl) CopyTemplateConversationByAppID(ctx context.Context, ap
}
-func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
+func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
if env == vo.Draft {
appStaticConversationDraft := r.query.AppStaticConversationDraft
ret, err := appStaticConversationDraft.WithContext(ctx).Where(
@@ -857,7 +857,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
appConversationTemplateDraft := r.query.AppConversationTemplateDraft
template, err := appConversationTemplateDraft.WithContext(ctx).Where(
appConversationTemplateDraft.TemplateID.Eq(ret.TemplateID),
- appConversationTemplateDraft.AppID.Eq(appID),
+ appConversationTemplateDraft.AppID.Eq(bizID),
).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -881,7 +881,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
appConversationTemplateOnline := r.query.AppConversationTemplateOnline
template, err := appConversationTemplateOnline.WithContext(ctx).Where(
appConversationTemplateOnline.TemplateID.Eq(ret.TemplateID),
- appConversationTemplateOnline.AppID.Eq(appID),
+ appConversationTemplateOnline.AppID.Eq(bizID),
).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -894,11 +894,11 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
return "", false, fmt.Errorf("unknown env %v", env)
}
-func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
+func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
if env == vo.Draft {
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
- appDynamicConversationDraft.AppID.Eq(appID),
+ appDynamicConversationDraft.AppID.Eq(bizID),
appDynamicConversationDraft.ConnectorID.Eq(connectorID),
appDynamicConversationDraft.ConversationID.Eq(conversationID),
).First()
@@ -918,7 +918,7 @@ func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.
} else if env == vo.Online {
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
- appDynamicConversationOnline.AppID.Eq(appID),
+ appDynamicConversationOnline.AppID.Eq(bizID),
appDynamicConversationOnline.ConnectorID.Eq(connectorID),
appDynamicConversationOnline.ConversationID.Eq(conversationID),
).First()
diff --git a/backend/domain/workflow/internal/schema/node_schema.go b/backend/domain/workflow/internal/schema/node_schema.go
index 61b85f0fc..e8c07f194 100644
--- a/backend/domain/workflow/internal/schema/node_schema.go
+++ b/backend/domain/workflow/internal/schema/node_schema.go
@@ -129,3 +129,8 @@ func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
s.OutputSources = append(s.OutputSources, info...)
}
+
+type ChatHistoryAware interface {
+ ChatHistoryEnabled() bool
+ ChatHistoryRounds() int64
+}
diff --git a/backend/domain/workflow/internal/schema/workflow_schema.go b/backend/domain/workflow/internal/schema/workflow_schema.go
index 224008232..a2efc2cc4 100644
--- a/backend/domain/workflow/internal/schema/workflow_schema.go
+++ b/backend/domain/workflow/internal/schema/workflow_schema.go
@@ -38,6 +38,7 @@ type WorkflowSchema struct {
compositeNodes []*CompositeNode // won't serialize this
requireCheckPoint bool // won't serialize this
requireStreaming bool
+ historyRounds int64
once sync.Once
}
@@ -69,15 +70,22 @@ func (w *WorkflowSchema) Init() {
w.doGetCompositeNodes()
+ historyRounds := int64(0)
for _, node := range w.Nodes {
if node.Type == entity.NodeTypeSubWorkflow {
node.SubWorkflowSchema.Init()
+ historyRounds = max(historyRounds, node.SubWorkflowSchema.HistoryRounds())
if node.SubWorkflowSchema.requireCheckPoint {
w.requireCheckPoint = true
break
}
}
+ chatHistoryAware, ok := node.Configs.(ChatHistoryAware)
+ if ok && chatHistoryAware.ChatHistoryEnabled() {
+ historyRounds = max(historyRounds, chatHistoryAware.ChatHistoryRounds())
+ }
+
if rc, ok := node.Configs.(RequireCheckpoint); ok {
if rc.RequireCheckpoint() {
w.requireCheckPoint = true
@@ -86,6 +94,7 @@ func (w *WorkflowSchema) Init() {
}
}
+ w.historyRounds = historyRounds
w.requireStreaming = w.doRequireStreaming()
})
}
@@ -122,6 +131,12 @@ func (w *WorkflowSchema) RequireStreaming() bool {
return w.requireStreaming
}
+func (w *WorkflowSchema) HistoryRounds() int64 { return w.historyRounds }
+
+func (w *WorkflowSchema) SetHistoryRounds(historyRounds int64) {
+ w.historyRounds = historyRounds
+}
+
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
if w.Hierarchy == nil {
return nil
diff --git a/backend/domain/workflow/service/conversation_impl.go b/backend/domain/workflow/service/conversation_impl.go
index ccec12ea2..53e08600d 100644
--- a/backend/domain/workflow/service/conversation_impl.go
+++ b/backend/domain/workflow/service/conversation_impl.go
@@ -248,7 +248,7 @@ func (c *conversationImpl) findReplaceWorkflowByConversationName(ctx context.Con
if err != nil {
return false, err
}
- if v.Name == "CONVERSATION_NAME" && v.DefaultValue == name {
+ if v.Name == vo.ConversationNameKey && v.DefaultValue == name {
return true, nil
}
}
@@ -296,7 +296,7 @@ func (c *conversationImpl) replaceWorkflowsConversationName(ctx context.Context,
if err != nil {
return err
}
- if v.Name == "CONVERSATION_NAME" {
+ if v.Name == vo.ConversationNameKey {
v.DefaultValue = conversionName
}
startNode.Data.Outputs[idx] = v
@@ -351,18 +351,18 @@ func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo
return c.repo.DeleteDynamicConversation(ctx, env, templateID)
}
-func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
+func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error) {
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
- AppID: ptr.Of(appID),
+ AppID: ptr.Of(bizID),
Name: ptr.Of(conversationName),
})
if err != nil {
return 0, 0, err
}
- conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (*conventity.Conversation, error) {
+ conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, bizID int64, userID, connectorID int64) (*conventity.Conversation, error) {
return crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
- AgentID: appID,
+ AgentID: bizID,
UserID: userID,
ConnectorID: connectorID,
Scene: common.Scene_SceneWorkflow,
@@ -371,7 +371,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
if existed {
conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
- AppID: appID,
+ BizID: bizID,
ConnectorID: connectorID,
UserID: userID,
TemplateID: t.TemplateID,
@@ -383,7 +383,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
}
conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
- AppID: appID,
+ BizID: bizID,
ConnectorID: connectorID,
UserID: userID,
Name: conversationName,
@@ -465,8 +465,8 @@ func (c *conversationImpl) GetDynamicConversationByName(ctx context.Context, env
return c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, name)
}
-func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
- sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, appID, connectorID, conversationID)
+func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
+ sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, bizID, connectorID, conversationID)
if err != nil {
return "", false, err
}
@@ -474,7 +474,7 @@ func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.E
return sc, true, nil
}
- dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID)
+ dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, bizID, connectorID, conversationID)
if err != nil {
return "", false, err
}
diff --git a/backend/domain/workflow/service/executable_impl.go b/backend/domain/workflow/service/executable_impl.go
index d0734d401..1d854b630 100644
--- a/backend/domain/workflow/service/executable_impl.go
+++ b/backend/domain/workflow/service/executable_impl.go
@@ -22,6 +22,8 @@ import (
"fmt"
"time"
+ "github.com/coze-dev/coze-studio/backend/types/consts"
+
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
@@ -282,6 +284,50 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
return executeID, nil
}
+func (i *impl) handleHistory(ctx context.Context, config *workflowModel.ExecuteConfig, input map[string]any, historyRounds int64, shouldFetchConversationByName bool) error {
+ if historyRounds <= 0 {
+ return nil
+ }
+
+ if shouldFetchConversationByName {
+ var cID, sID, bizID int64
+ var err error
+ if config.AppID != nil {
+ bizID = *config.AppID
+ } else if config.AgentID != nil {
+ bizID = *config.AgentID
+ }
+ for k, v := range input {
+ if k == vo.ConversationNameKey {
+ cName, ok := v.(string)
+ if !ok {
+ return errors.New("CONVERSATION_NAME must be string")
+ }
+ cID, sID, err = i.GetOrCreateConversation(ctx, vo.Draft, bizID, consts.CozeConnectorID, config.Operator, cName)
+ if err != nil {
+ return err
+ }
+ config.ConversationID = ptr.Of(cID)
+ config.SectionID = ptr.Of(sID)
+ }
+ }
+ }
+
+ messages, scMessages, err := i.prefetchChatHistory(ctx, *config, historyRounds)
+ if err != nil {
+ logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
+ }
+
+ if len(messages) > 0 {
+ config.ConversationHistory = messages
+ }
+
+ if len(scMessages) > 0 {
+ config.ConversationHistorySchemaMessages = scMessages
+ }
+ return nil
+}
+
func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error) {
var (
err error
@@ -308,30 +354,6 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
}
}
- historyRounds := int64(0)
- if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
-
- historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
- if err != nil {
- return 0, err
- }
- }
-
- if historyRounds > 0 {
- messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
- if err != nil {
- logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
- }
-
- if len(messages) > 0 {
- config.ConversationHistory = messages
- }
-
- if len(scMessages) > 0 {
- config.ConversationHistorySchemaMessages = scMessages
- }
-
- }
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
@@ -342,6 +364,17 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
}
+ historyRounds := int64(0)
+ if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
+ historyRounds = workflowSC.HistoryRounds()
+ }
+
+ if historyRounds > 0 {
+ if err = i.handleHistory(ctx, &config, input, historyRounds, true); err != nil {
+ return 0, err
+ }
+ }
+
wf, err := compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID), einoCompose.WithGraphName(fmt.Sprintf("%d", wfEntity.ID)))
if err != nil {
return 0, fmt.Errorf("failed to create workflow: %w", err)
@@ -417,29 +450,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
}
}
- historyRounds := int64(0)
- if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
- historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
- if err != nil {
- return nil, err
- }
- }
-
- if historyRounds > 0 {
- messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
- if err != nil {
- logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
- }
-
- if len(messages) > 0 {
- config.ConversationHistory = messages
- }
-
- if len(scMessages) > 0 {
- config.ConversationHistorySchemaMessages = scMessages
- }
-
- }
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
@@ -450,6 +460,17 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
return nil, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
}
+ historyRounds := int64(0)
+ if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
+ historyRounds = workflowSC.HistoryRounds()
+ }
+
+ if historyRounds > 0 {
+ if err = i.handleHistory(ctx, &config, input, historyRounds, false); err != nil {
+ return nil, err
+ }
+ }
+
var wfOpts []compose.WorkflowOption
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
@@ -997,20 +1018,6 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
return nil
}
-const maxHistoryRounds int64 = 30
-
-func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
- if wfEntity == nil {
- return 0, nil
- }
-
- maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
- if err != nil {
- return 0, err
- }
- return min(maxRounds, maxHistoryRounds), nil
-}
-
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
convID := config.ConversationID
agentID := config.AgentID
@@ -1027,11 +1034,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
return nil, nil, nil
}
- var resolvedAppID int64
+ var bizID int64
if appID != nil {
- resolvedAppID = *appID
+ bizID = *appID
} else if agentID != nil {
- resolvedAppID = *agentID
+ bizID = *agentID
} else {
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
return nil, nil, nil
@@ -1039,7 +1046,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
ConversationID: *convID,
- AppID: resolvedAppID,
+ BizID: bizID,
UserID: userID,
Rounds: historyRounds + 1,
SectionID: *sectionID,
@@ -1048,7 +1055,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
if err != nil {
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
- return nil, nil, nil
+ return nil, nil, err
}
if len(runIds) <= 1 {
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
@@ -1061,7 +1068,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
})
if err != nil {
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
- return nil, nil, nil
+ return nil, nil, err
}
return response.Messages, response.SchemaMessages, nil
diff --git a/backend/domain/workflow/service/executable_impl_test.go b/backend/domain/workflow/service/executable_impl_test.go
new file mode 100644
index 000000000..e3224d6e0
--- /dev/null
+++ b/backend/domain/workflow/service/executable_impl_test.go
@@ -0,0 +1,286 @@
+/*
+ * Copyright 2025 coze-dev Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/cloudwego/eino/schema"
+ "github.com/stretchr/testify/assert"
+ "go.uber.org/mock/gomock"
+
+ workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
+ crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
+ messagemock "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message/messagemock"
+ "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
+ mock_workflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow"
+ "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
+)
+
+func TestImpl_handleHistory(t *testing.T) {
+ ctx := context.Background()
+ ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
+ defer ctrl.Finish()
+
+ // Setup for cross-domain service mock
+ mockMessage := messagemock.NewMockMessage(ctrl)
+ crossmessage.SetDefaultSVC(mockMessage)
+
+ tests := []struct {
+ name string
+ setupMock func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository)
+ config *workflowModel.ExecuteConfig
+ input map[string]any
+ historyRounds int64
+ shouldFetch bool
+ expectErr bool
+ expectedHistory []*crossmessage.WfMessage
+ expectedSchemaHistory []*schema.Message
+ }{
+ {
+ name: "historyRounds is zero",
+ historyRounds: 0,
+ shouldFetch: true,
+ config: &workflowModel.ExecuteConfig{},
+ setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
+ },
+ expectErr: false,
+ },
+ {
+ name: "shouldFetch is false",
+ historyRounds: 5,
+ shouldFetch: false,
+ config: &workflowModel.ExecuteConfig{
+ AppID: ptr.Of(int64(1)),
+ ConversationID: ptr.Of(int64(100)),
+ SectionID: ptr.Of(int64(101)),
+ },
+ setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
+ msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2}, nil).AnyTimes()
+ msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
+ Messages: []*crossmessage.WfMessage{{ID: 1}},
+ SchemaMessages: []*schema.Message{{
+ Role: schema.User,
+ Content: "123",
+ }},
+ }, nil).AnyTimes()
+ },
+ expectErr: false,
+ expectedHistory: []*crossmessage.WfMessage{{ID: 1}},
+ expectedSchemaHistory: []*schema.Message{{
+ Role: schema.User,
+ Content: "123",
+ }},
+ },
+ {
+ name: "fetch conversation by name - conversation exists",
+ historyRounds: 3,
+ shouldFetch: true,
+ config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
+ input: map[string]any{"CONVERSATION_NAME": "test-conv"},
+ setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
+ service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "test-conv").Return(int64(200), int64(201), nil).AnyTimes()
+ msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{3, 4}, nil).AnyTimes()
+ msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
+ Messages: []*crossmessage.WfMessage{{ID: 2}},
+ SchemaMessages: []*schema.Message{{
+ Role: schema.Assistant,
+ Content: "123",
+ }},
+ }, nil).AnyTimes()
+ repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
+ TemplateID: int64(202),
+ SpaceID: int64(203),
+ AppID: int64(204),
+ }, true, nil).AnyTimes()
+ repo.EXPECT().GetOrCreateStaticConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
+ },
+ expectErr: false,
+ expectedHistory: []*crossmessage.WfMessage{{ID: 2}},
+ expectedSchemaHistory: []*schema.Message{{
+ Role: schema.Assistant,
+ Content: "123",
+ }},
+ },
+ {
+ name: "fetch conversation by name - conversation not exists",
+ historyRounds: 3,
+ shouldFetch: true,
+ config: &workflowModel.ExecuteConfig{AgentID: ptr.Of(int64(2))},
+ input: map[string]any{"CONVERSATION_NAME": "new-conv"},
+ setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
+ service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "new-conv").Return(int64(300), int64(301), nil).AnyTimes()
+ msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{5, 6}, nil).AnyTimes()
+ msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
+ Messages: []*crossmessage.WfMessage{{ID: 3}},
+ }, nil).AnyTimes()
+ repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
+ TemplateID: int64(202),
+ SpaceID: int64(203),
+ AppID: int64(204),
+ }, false, nil).AnyTimes()
+ repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
+ },
+ expectErr: false,
+ expectedHistory: []*crossmessage.WfMessage{{ID: 3}},
+ },
+ {
+ name: "input with wrong type for conversation name",
+ historyRounds: 5,
+ shouldFetch: true,
+ config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
+ input: map[string]any{"CONVERSATION_NAME": 12345},
+ setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
+ },
+ expectErr: true,
+ },
+ {
+ name: "GetOrCreateConversation returns error",
+ historyRounds: 5,
+ shouldFetch: true,
+ config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
+ input: map[string]any{"CONVERSATION_NAME": "fail-conv"},
+ setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
+ service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "fail-conv").Return(int64(0), int64(0), errors.New("db error")).AnyTimes()
+ repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
+ TemplateID: int64(202),
+ SpaceID: int64(203),
+ AppID: int64(204),
+ }, false, nil).AnyTimes()
+ repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, errors.New("db error")).AnyTimes()
+ },
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockService := mock_workflow.NewMockService(ctrl)
+ mockRepo := mock_workflow.NewMockRepository(ctrl)
+ testImpl := &impl{repo: mockRepo, conversationImpl: &conversationImpl{repo: mockRepo}}
+
+ tt.setupMock(mockService, mockMessage, mockRepo)
+
+ err := testImpl.handleHistory(ctx, tt.config, tt.input, tt.historyRounds, tt.shouldFetch)
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ if tt.expectedHistory != nil {
+ assert.Equal(t, tt.expectedHistory, tt.config.ConversationHistory)
+ } else if tt.historyRounds == 0 {
+ assert.Nil(t, tt.config.ConversationHistory)
+ } else if tt.expectedSchemaHistory != nil {
+ assert.Equal(t, tt.expectedSchemaHistory, tt.config.ConversationHistorySchemaMessages)
+ }
+ }
+ })
+ }
+}
+
+func TestImpl_prefetchChatHistory(t *testing.T) {
+ ctx := context.Background()
+ ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
+ defer ctrl.Finish()
+
+ mockMessage := messagemock.NewMockMessage(ctrl)
+ crossmessage.SetDefaultSVC(mockMessage)
+
+ tests := []struct {
+ name string
+ setupMock func(msgSvc *messagemock.MockMessage)
+ config workflowModel.ExecuteConfig
+ historyRounds int64
+ expectErr bool
+ }{
+ {
+ name: "SectionID is nil",
+ config: workflowModel.ExecuteConfig{
+ ConversationID: ptr.Of(int64(100)),
+ AppID: ptr.Of(int64(1)),
+ },
+ historyRounds: 5,
+ setupMock: func(msgSvc *messagemock.MockMessage) {},
+ expectErr: false,
+ },
+ {
+ name: "ConversationID is nil",
+ config: workflowModel.ExecuteConfig{
+ SectionID: ptr.Of(int64(101)),
+ AppID: ptr.Of(int64(1)),
+ },
+ historyRounds: 5,
+ setupMock: func(msgSvc *messagemock.MockMessage) {},
+ expectErr: false,
+ },
+ {
+ name: "AppID and AgentID are both nil",
+ config: workflowModel.ExecuteConfig{
+ ConversationID: ptr.Of(int64(100)),
+ SectionID: ptr.Of(int64(101)),
+ },
+ historyRounds: 5,
+ setupMock: func(msgSvc *messagemock.MockMessage) {},
+ expectErr: false,
+ },
+ {
+ name: "GetLatestRunIDs returns error",
+ config: workflowModel.ExecuteConfig{
+ AppID: ptr.Of(int64(1)),
+ ConversationID: ptr.Of(int64(100)),
+ SectionID: ptr.Of(int64(101)),
+ },
+ historyRounds: 5,
+ setupMock: func(msgSvc *messagemock.MockMessage) {
+ msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
+ },
+ expectErr: true,
+ },
+ {
+ name: "GetMessagesByRunIDs returns error",
+ config: workflowModel.ExecuteConfig{
+ AppID: ptr.Of(int64(1)),
+ ConversationID: ptr.Of(int64(100)),
+ SectionID: ptr.Of(int64(101)),
+ },
+ historyRounds: 5,
+ setupMock: func(msgSvc *messagemock.MockMessage) {
+ msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2, 3}, nil)
+ msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
+ },
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ testImpl := &impl{}
+ tt.setupMock(mockMessage)
+
+ _, _, err := testImpl.prefetchChatHistory(ctx, tt.config, tt.historyRounds)
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/backend/domain/workflow/service/service_impl.go b/backend/domain/workflow/service/service_impl.go
index a007064a2..50b3ab019 100644
--- a/backend/domain/workflow/service/service_impl.go
+++ b/backend/domain/workflow/service/service_impl.go
@@ -39,7 +39,6 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
- "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
@@ -522,7 +521,7 @@ func isEnableChatHistory(s *schema.NodeSchema) bool {
return false
}
- chatHistoryAware, ok := s.Configs.(nodes.ChatHistoryAware)
+ chatHistoryAware, ok := s.Configs.(schema.ChatHistoryAware)
if !ok {
return false
}
@@ -2171,15 +2170,15 @@ func (i *impl) adaptToChatFlow(ctx context.Context, wID int64) error {
vMap[v.Name] = true
}
- if _, ok := vMap["USER_INPUT"]; !ok {
+ if _, ok := vMap[vo.UserInputKey]; !ok {
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
- Name: "USER_INPUT",
+ Name: vo.UserInputKey,
Type: vo.VariableTypeString,
})
}
- if _, ok := vMap["CONVERSATION_NAME"]; !ok {
+ if _, ok := vMap[vo.ConversationNameKey]; !ok {
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
- Name: "CONVERSATION_NAME",
+ Name: vo.ConversationNameKey,
Type: vo.VariableTypeString,
DefaultValue: "Default",
})
diff --git a/backend/domain/workflow/service/utils.go b/backend/domain/workflow/service/utils.go
index 572f1320e..dd5608f8c 100644
--- a/backend/domain/workflow/service/utils.go
+++ b/backend/domain/workflow/service/utils.go
@@ -22,15 +22,11 @@ import (
"strconv"
"strings"
- workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
- wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
- "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
- "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@@ -201,214 +197,3 @@ func isIncremental(prev version, next version) bool {
return next.Patch > prev.Patch
}
-
-func getMaxHistoryRoundsRecursively(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository) (int64, error) {
- visited := make(map[string]struct{})
- maxRounds := int64(0)
- err := getMaxHistoryRoundsRecursiveHelper(ctx, wfEntity, repo, visited, &maxRounds)
- return maxRounds, err
-}
-
-func getMaxHistoryRoundsRecursiveHelper(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
- visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
- if _, ok := visited[visitedKey]; ok {
- return nil
- }
- visited[visitedKey] = struct{}{}
-
- var canvas vo.Canvas
- if err := sonic.UnmarshalString(wfEntity.Canvas, &canvas); err != nil {
- return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wfEntity.ID, err)
- }
-
- return collectMaxHistoryRounds(ctx, canvas.Nodes, repo, visited, maxRounds)
-}
-
-func collectMaxHistoryRounds(ctx context.Context, nodes []*vo.Node, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
- for _, node := range nodes {
- if node == nil {
- continue
- }
-
- if node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.ChatHistorySetting != nil && node.Data.Inputs.ChatHistorySetting.EnableChatHistory {
- if node.Data.Inputs.ChatHistorySetting.ChatHistoryRound > *maxRounds {
- *maxRounds = node.Data.Inputs.ChatHistorySetting.ChatHistoryRound
- }
- } else if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLMParam != nil {
- param := node.Data.Inputs.LLMParam
- bs, _ := sonic.Marshal(param)
- llmParam := make(vo.LLMParam, 0)
- if err := sonic.Unmarshal(bs, &llmParam); err != nil {
- return err
- }
- var chatHistoryEnabled bool
- var chatHistoryRound int64
- for _, param := range llmParam {
- switch param.Name {
- case "enableChatHistory":
- if val, ok := param.Input.Value.Content.(bool); ok {
- b := val
- chatHistoryEnabled = b
- }
- case "chatHistoryRound":
- if strVal, ok := param.Input.Value.Content.(string); ok {
- int64Val, err := strconv.ParseInt(strVal, 10, 64)
- if err != nil {
- return err
- }
- chatHistoryRound = int64Val
- }
- }
- }
-
- if chatHistoryEnabled {
- if chatHistoryRound > *maxRounds {
- *maxRounds = chatHistoryRound
- }
- }
- }
-
- isSubWorkflow := node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil
- if isSubWorkflow {
- workflowIDStr := node.Data.Inputs.WorkflowID
- if workflowIDStr == "" {
- continue
- }
-
- workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
- if err != nil {
- return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
- }
-
- subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
- ID: workflowID,
- QType: ternary.IFElse(len(node.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
- Version: node.Data.Inputs.WorkflowVersion,
- })
- if err != nil {
- return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
- }
-
- if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, maxRounds); err != nil {
- return err
- }
- }
-
- if len(node.Blocks) > 0 {
- if err := collectMaxHistoryRounds(ctx, node.Blocks, repo, visited, maxRounds); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
-
-func getHistoryRoundsFromNode(ctx context.Context, wfEntity *entity.Workflow, nodeID string, repo wf.Repository) (int64, error) {
- if wfEntity == nil {
- return 0, nil
- }
- visited := make(map[string]struct{})
- visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
- if _, ok := visited[visitedKey]; ok {
- return 0, nil
- }
- visited[visitedKey] = struct{}{}
- maxRounds := int64(0)
- c := &vo.Canvas{}
- if err := sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
- return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
- }
- var (
- n *vo.Node
- nodeFinder func(nodes []*vo.Node) *vo.Node
- )
- nodeFinder = func(nodes []*vo.Node) *vo.Node {
- for i := range nodes {
- if nodes[i].ID == nodeID {
- return nodes[i]
- }
- if len(nodes[i].Blocks) > 0 {
- if n := nodeFinder(nodes[i].Blocks); n != nil {
- return n
- }
- }
- }
- return nil
- }
-
- n = nodeFinder(c.Nodes)
- if n.Type == entity.NodeTypeLLM.IDStr() {
- if n.Data == nil || n.Data.Inputs == nil {
- return 0, nil
- }
- param := n.Data.Inputs.LLMParam
- bs, _ := sonic.Marshal(param)
- llmParam := make(vo.LLMParam, 0)
- if err := sonic.Unmarshal(bs, &llmParam); err != nil {
- return 0, err
- }
- var chatHistoryEnabled bool
- var chatHistoryRound int64
- for _, param := range llmParam {
- switch param.Name {
- case "enableChatHistory":
- if val, ok := param.Input.Value.Content.(bool); ok {
- b := val
- chatHistoryEnabled = b
- }
- case "chatHistoryRound":
- if strVal, ok := param.Input.Value.Content.(string); ok {
- int64Val, err := strconv.ParseInt(strVal, 10, 64)
- if err != nil {
- return 0, err
- }
- chatHistoryRound = int64Val
- }
- }
- }
- if chatHistoryEnabled {
- return chatHistoryRound, nil
- }
- return 0, nil
- }
-
- if n.Type == entity.NodeTypeIntentDetector.IDStr() || n.Type == entity.NodeTypeKnowledgeRetriever.IDStr() {
- if n.Data != nil && n.Data.Inputs != nil && n.Data.Inputs.ChatHistorySetting != nil && n.Data.Inputs.ChatHistorySetting.EnableChatHistory {
- return n.Data.Inputs.ChatHistorySetting.ChatHistoryRound, nil
- }
- return 0, nil
- }
-
- if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
- if n.Data != nil && n.Data.Inputs != nil {
- workflowIDStr := n.Data.Inputs.WorkflowID
- if workflowIDStr == "" {
- return 0, nil
- }
- workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
- if err != nil {
- return 0, fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", n.ID, err)
- }
- subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
- ID: workflowID,
- QType: ternary.IFElse(len(n.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
- Version: n.Data.Inputs.WorkflowVersion,
- })
- if err != nil {
- return 0, fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
- }
- if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, &maxRounds); err != nil {
- return 0, err
- }
- return maxRounds, nil
- }
- }
-
- if len(n.Blocks) > 0 {
- if err := collectMaxHistoryRounds(ctx, n.Blocks, repo, visited, &maxRounds); err != nil {
- return 0, err
- }
- }
- return maxRounds, nil
-}