Compare commits
5 Commits
v0.3.0
...
feat/addit
| Author | SHA1 | Date | |
|---|---|---|---|
| ca89a40eb4 | |||
| d84347da63 | |||
| 8ff79b7874 | |||
| 2d2250d51c | |||
| 2ceaad0c0f |
@ -228,7 +228,6 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
h.POST("/api/workflow_api/chat_flow_role/delete", DeleteChatFlowRole)
|
||||
h.POST("/api/workflow_api/chat_flow_role/create", CreateChatFlowRole)
|
||||
h.GET("/api/workflow_api/chat_flow_role/get", GetChatFlowRole)
|
||||
h.POST("/v1/workflows/chat", OpenAPIChatFlowRun)
|
||||
|
||||
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
|
||||
mockIDGen := mock.NewMockIDGenerator(ctrl)
|
||||
@ -1083,46 +1082,6 @@ func (r *wfTestRunner) openapiResume(id string, eventID string, resumeData strin
|
||||
return re
|
||||
}
|
||||
|
||||
func (r *wfTestRunner) openapiChatFlowRun(wfID string, cID, appID, botID *string, input any, additionalMessage []*workflow.EnterMessage) *sse.Reader {
|
||||
inputStr, _ := sonic.MarshalString(input)
|
||||
|
||||
req := &workflow.ChatFlowRunRequest{
|
||||
WorkflowID: wfID,
|
||||
Parameters: ptr.Of(inputStr),
|
||||
AdditionalMessages: additionalMessage,
|
||||
}
|
||||
if cID != nil {
|
||||
req.ConversationID = cID
|
||||
}
|
||||
if appID != nil {
|
||||
req.AppID = appID
|
||||
}
|
||||
if botID != nil {
|
||||
req.BotID = botID
|
||||
}
|
||||
|
||||
m, err := sonic.Marshal(req)
|
||||
assert.NoError(r.t, err)
|
||||
|
||||
c, _ := client.NewClient()
|
||||
hReq, hResp := protocol.AcquireRequest(), protocol.AcquireResponse()
|
||||
hReq.SetRequestURI("http://localhost:8888" + "/v1/workflows/chat")
|
||||
hReq.SetMethod("POST")
|
||||
hReq.SetBody(m)
|
||||
hReq.SetHeader("Content-Type", "application/json")
|
||||
err = c.Do(context.Background(), hReq, hResp)
|
||||
assert.NoError(r.t, err)
|
||||
|
||||
if hResp.StatusCode() != http.StatusOK {
|
||||
r.t.Errorf("unexpected status code: %d, body: %s", hResp.StatusCode(), string(hResp.Body()))
|
||||
}
|
||||
|
||||
re, err := sse.NewReader(hResp)
|
||||
assert.NoError(r.t, err)
|
||||
|
||||
return re
|
||||
}
|
||||
|
||||
func (r *wfTestRunner) runServer() func() {
|
||||
go func() {
|
||||
_ = r.h.Run()
|
||||
@ -2495,7 +2454,7 @@ func TestStartNodeDefaultValues(t *testing.T) {
|
||||
result, _ := r.openapiSyncRun(idStr, input)
|
||||
assert.Equal(t, result, map[string]any{
|
||||
"ts": "2025-07-09 21:43:34",
|
||||
"files": "http://imagex.fanlv.fun/tos-cn-i-1heqlfnr21/e81acc11277f421390770618e24e01ce.jpeg~tplv-1heqlfnr21-image.image",
|
||||
"files": "http://imagex.fanlv.fun/tos-cn-i-1heqlfnr21/e81acc11277f421390770618e24e01ce.jpeg~tplv-1heqlfnr21-image.image?x-wf-file_name=20250317-154742.jpeg",
|
||||
"str": "str",
|
||||
"object": map[string]any{
|
||||
"a": "1",
|
||||
@ -2519,7 +2478,7 @@ func TestStartNodeDefaultValues(t *testing.T) {
|
||||
result, _ := r.openapiSyncRun(idStr, input)
|
||||
assert.Equal(t, result, map[string]any{
|
||||
"ts": "2025-07-09 21:43:34",
|
||||
"files": "http://imagex.fanlv.fun/tos-cn-i-1heqlfnr21/e81acc11277f421390770618e24e01ce.jpeg~tplv-1heqlfnr21-image.image",
|
||||
"files": "http://imagex.fanlv.fun/tos-cn-i-1heqlfnr21/e81acc11277f421390770618e24e01ce.jpeg~tplv-1heqlfnr21-image.image?x-wf-file_name=20250317-154742.jpeg",
|
||||
"str": "str",
|
||||
"object": map[string]any{
|
||||
"a": "1",
|
||||
@ -2544,7 +2503,7 @@ func TestStartNodeDefaultValues(t *testing.T) {
|
||||
result, _ := r.openapiSyncRun(idStr, input)
|
||||
assert.Equal(t, result, map[string]any{
|
||||
"ts": "2025-07-09 21:43:34",
|
||||
"files": "http://imagex.fanlv.fun/tos-cn-i-1heqlfnr21/e81acc11277f421390770618e24e01ce.jpeg~tplv-1heqlfnr21-image.image",
|
||||
"files": "http://imagex.fanlv.fun/tos-cn-i-1heqlfnr21/e81acc11277f421390770618e24e01ce.jpeg~tplv-1heqlfnr21-image.image?x-wf-file_name=20250317-154742.jpeg",
|
||||
"str": "value",
|
||||
"object": map[string]any{
|
||||
"a": "1",
|
||||
@ -5532,7 +5491,7 @@ func TestConversationOfChatFlow(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
v.DefaultValue = cName
|
||||
}
|
||||
startNode.Data.Outputs[idx] = v
|
||||
@ -5563,7 +5522,7 @@ func TestConversationOfChatFlow(t *testing.T) {
|
||||
for _, vAny := range node.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
assert.NoError(t, err)
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
assert.Equal(t, v.DefaultValue, updateName)
|
||||
}
|
||||
}
|
||||
@ -5610,7 +5569,7 @@ func TestConversationOfChatFlow(t *testing.T) {
|
||||
for _, vAny := range node.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
assert.NoError(t, err)
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
assert.Equal(t, v.DefaultValue, cName+"copy")
|
||||
}
|
||||
}
|
||||
@ -6029,266 +5988,3 @@ func TestConversationHistoryNodes(t *testing.T) {
|
||||
assert.Equal(t, []any{}, outputMap["history_list"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkflowRunWithFiles(t *testing.T) {
|
||||
mockey.PatchConvey("workflow run with files", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
r.knowledge.EXPECT().Store(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, document *knowledge.CreateDocumentRequest) (*knowledge.CreateDocumentResponse, error) {
|
||||
|
||||
assert.Equal(t, "北京旅游景点.txt", document.FileName)
|
||||
return &knowledge.CreateDocumentResponse{
|
||||
DocumentID: 1,
|
||||
FileURL: document.FileURL,
|
||||
FileName: document.FileName,
|
||||
}, nil
|
||||
}).AnyTimes()
|
||||
|
||||
runner := mockcode.NewMockRunner(r.ctrl)
|
||||
runner.EXPECT().Run(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, request *coderunner.RunRequest) (*coderunner.RunResponse, error) {
|
||||
|
||||
return &coderunner.RunResponse{
|
||||
Result: request.Params,
|
||||
}, nil
|
||||
}).AnyTimes()
|
||||
|
||||
mockey.Mock(code.GetCodeRunner).Return(runner).Build()
|
||||
|
||||
idStr := r.load("workflow_wf_file_name.json")
|
||||
r.publish(idStr, "v0.1.1", true)
|
||||
m, execID := r.openapiSyncRun(idStr, map[string]string{
|
||||
"f": "http://coze.fanlv.fun:8889/opencoze/tos-cn-i-v4nquku3lp/27b01dd5-b0f5-4dbd-a075-a48c14162d23.txt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250910%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250910T074412Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=2f3051a0645c9ed260f7cb6c93954147ceb347a61366c9f70b98d43c299a7732&x-wf-file_name=%E5%8C%97%E4%BA%AC%E6%97%85%E6%B8%B8%E6%99%AF%E7%82%B9.txt",
|
||||
"fs": "[\"http://coze.fanlv.fun:8889/opencoze/tos-cn-i-v4nquku3lp/85056c12-ea40-4588-a2a2-5eab56b94e4c.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250910%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250910T074404Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=59f3a7b6774a33de127e42878e4821635ce74e1fc29237ba03b13d67a068fedf&x-wf-file_name=%E5%BE%AE%E4%BF%A1%E5%9B%BE%E7%89%87_2025-07-02_154139_105.jpg\",\"http://coze.fanlv.fun:8889/opencoze/tos-cn-i-v4nquku3lp/5ec9856d-0db0-44a1-9b82-43628221a928.jpeg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250910%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250910T074410Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=3887b0583084b0294b91e93e307c61ce3b910531d0e33a08e7c7d57de24c71ec&x-wf-file_name=20250317-154742.jpeg\"]",
|
||||
})
|
||||
assert.NotNil(t, execID)
|
||||
|
||||
assert.Equal(t, m["output"], []any{
|
||||
"http://coze.fanlv.fun:8889/opencoze/tos-cn-i-v4nquku3lp/85056c12-ea40-4588-a2a2-5eab56b94e4c.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250910%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250910T074404Z&X-Amz-Expires=604800&X-Amz-Signature=59f3a7b6774a33de127e42878e4821635ce74e1fc29237ba03b13d67a068fedf&X-Amz-SignedHeaders=host",
|
||||
"http://coze.fanlv.fun:8889/opencoze/tos-cn-i-v4nquku3lp/5ec9856d-0db0-44a1-9b82-43628221a928.jpeg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250910%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250910T074410Z&X-Amz-Expires=604800&X-Amz-Signature=3887b0583084b0294b91e93e307c61ce3b910531d0e33a08e7c7d57de24c71ec&X-Amz-SignedHeaders=host"})
|
||||
assert.Equal(t, m["filename"], "北京旅游景点.txt")
|
||||
fmt.Println(m, execID)
|
||||
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestChatFlowRun(t *testing.T) {
|
||||
mockey.PatchConvey("chat flow run", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
appworkflow.SVC.IDGenerator = r.idGen
|
||||
defer r.closeFn()
|
||||
defer r.runServer()()
|
||||
|
||||
chatModel1 := &testutil.UTChatModel{
|
||||
StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
sr := schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "I ",
|
||||
},
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "don't know.",
|
||||
},
|
||||
})
|
||||
return sr, nil
|
||||
},
|
||||
}
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel1, nil, nil).AnyTimes()
|
||||
|
||||
id := r.load("chatflow/llm_chat.json", withMode(workflow.WorkflowMode_ChatFlow))
|
||||
r.publish(id, "v0.0.1", true)
|
||||
cID := time.Now().UnixNano()
|
||||
cIDStr := strconv.FormatInt(cID, 10)
|
||||
appID := time.Now().UnixNano()
|
||||
appIDStr := strconv.FormatInt(appID, 10)
|
||||
|
||||
// Create conversation first
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
idStr := r.load("conversation_manager/update_dynamic_conversation.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
ret, _ := r.openapiSyncRun(idStr, map[string]string{
|
||||
"input": "v1",
|
||||
"new_name": "v2",
|
||||
}, withRunProjectID(appID))
|
||||
assert.Equal(t, map[string]any{"conversationId": strconv.FormatInt(cID, 10), "isExisted": false, "isSuccess": true}, ret["obj"])
|
||||
|
||||
msg := []*workflow.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Content: "你好",
|
||||
},
|
||||
}
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
|
||||
t.Run("chat flow run in app", func(t *testing.T) {
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run in bot", func(t *testing.T) {
|
||||
botID := time.Now().UnixNano()
|
||||
botIDStr := strconv.FormatInt(botID, 10)
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), nil, ptr.Of(botIDStr), map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run without cID", func(t *testing.T) {
|
||||
sseReader := r.openapiChatFlowRun(id, nil, ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run with additional messages", func(t *testing.T) {
|
||||
additionalMsg := []*workflow.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Content: "你好, 我叫小明",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ContentType: "text",
|
||||
Content: "你好小明, 很高兴认识你",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Content: "你好",
|
||||
},
|
||||
}
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, additionalMsg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run with history messages", func(t *testing.T) {
|
||||
id := r.load("chatflow/llm_chat_with_history.json", withMode(workflow.WorkflowMode_ChatFlow))
|
||||
r.publish(id, "v0.0.1", true)
|
||||
r.message.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{rID}, nil).AnyTimes()
|
||||
r.message.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&message0.GetMessagesByRunIDsResponse{
|
||||
Messages: []*message0.WfMessage{
|
||||
{
|
||||
ID: mID,
|
||||
Role: schema.User,
|
||||
Text: ptr.Of("你好"),
|
||||
},
|
||||
},
|
||||
}, nil).AnyTimes()
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("chat flow run with interrupt nodes ", func(t *testing.T) {
|
||||
// 生成一个携带 input, 问答文本 问答选项的三个中断节点 做测试
|
||||
id := r.load("chatflow/chat_run_with_interrupt.json", withMode(workflow.WorkflowMode_ChatFlow))
|
||||
r.publish(id, "v0.0.1", true)
|
||||
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, msg)
|
||||
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
if e.ID == "3" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "7383997384420262000")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, []*workflow.EnterMessage{
|
||||
{Role: string(schema.User), Content: "input:1", ContentType: "text"},
|
||||
})
|
||||
|
||||
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
if e.ID == "4" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "你好")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, []*workflow.EnterMessage{
|
||||
{Role: string(schema.User), Content: "hello", ContentType: "text"},
|
||||
})
|
||||
|
||||
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
if e.ID == "3" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "question_card_data", "请选择")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
|
||||
vo.ConversationNameKey: "Default",
|
||||
}, []*workflow.EnterMessage{
|
||||
{Role: string(schema.User), Content: "A", ContentType: "text"},
|
||||
})
|
||||
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
|
||||
if e.ID == "4" {
|
||||
assert.Equal(t, e.Type, "conversation.message.completed")
|
||||
assert.Contains(t, string(e.Data), "answer", "A")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@ -56,7 +56,6 @@ type ExecuteConfig struct {
|
||||
ConversationHistorySchemaMessages []*schema.Message
|
||||
SectionID *int64
|
||||
MaxHistoryRounds *int32
|
||||
InputFileFields map[string]*FileInfo
|
||||
}
|
||||
|
||||
type ExecuteMode string
|
||||
@ -92,9 +91,3 @@ const (
|
||||
BizTypeAgent BizType = "agent"
|
||||
BizTypeWorkflow BizType = "workflow"
|
||||
)
|
||||
|
||||
type FileInfo struct {
|
||||
FileURL string `json:"file_url"`
|
||||
FileName string `json:"file_name"`
|
||||
FileExtension string `json:"file_extension"`
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@ -102,7 +103,7 @@ func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *
|
||||
return nil, err
|
||||
}
|
||||
if conData == nil {
|
||||
return nil, errors.New("conversation data is nil")
|
||||
return nil, errorx.New(errno.ErrConversationNotFound)
|
||||
}
|
||||
conversationData = conData
|
||||
|
||||
@ -110,7 +111,7 @@ func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *
|
||||
}
|
||||
|
||||
if conversationData.CreatorID != userID {
|
||||
return nil, errors.New("conversation data not match")
|
||||
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg","user not match"))
|
||||
}
|
||||
|
||||
return conversationData, nil
|
||||
@ -138,26 +139,31 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
multiContent, contentType, err := a.buildMultiContent(ctx, ar)
|
||||
multiAdditionalMessages, err := a.parseAdditionalMessages(ctx, ar)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filterMultiAdditionalMessages, multiContent, contentType, err := a.parseQueryContent(ctx, multiAdditionalMessages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
displayContent := a.buildDisplayContent(ctx, ar)
|
||||
arm := &entity.AgentRunMeta{
|
||||
ConversationID: ptr.From(ar.ConversationID),
|
||||
AgentID: ar.BotID,
|
||||
Content: multiContent,
|
||||
DisplayContent: displayContent,
|
||||
SpaceID: spaceID,
|
||||
UserID: ar.User,
|
||||
SectionID: conversationData.SectionID,
|
||||
PreRetrieveTools: shortcutCMDData,
|
||||
IsDraft: false,
|
||||
ConnectorID: connectorID,
|
||||
ContentType: contentType,
|
||||
Ext: ar.ExtraParams,
|
||||
CustomVariables: ar.CustomVariables,
|
||||
CozeUID: conversationData.CreatorID,
|
||||
ConversationID: ptr.From(ar.ConversationID),
|
||||
AgentID: ar.BotID,
|
||||
Content: multiContent,
|
||||
DisplayContent: displayContent,
|
||||
SpaceID: spaceID,
|
||||
UserID: ar.User,
|
||||
SectionID: conversationData.SectionID,
|
||||
PreRetrieveTools: shortcutCMDData,
|
||||
IsDraft: false,
|
||||
ConnectorID: connectorID,
|
||||
ContentType: contentType,
|
||||
Ext: ar.ExtraParams,
|
||||
CustomVariables: ar.CustomVariables,
|
||||
CozeUID: conversationData.CreatorID,
|
||||
AdditionalMessages: filterMultiAdditionalMessages,
|
||||
}
|
||||
return arm, nil
|
||||
}
|
||||
@ -200,29 +206,68 @@ func (a *OpenapiAgentRunApplication) buildDisplayContent(_ context.Context, ar *
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *run.ChatV3Request) ([]*message.InputMetaData, message.ContentType, error) {
|
||||
var multiContents []*message.InputMetaData
|
||||
contentType := message.ContentTypeText
|
||||
func (a *OpenapiAgentRunApplication) parseQueryContent(ctx context.Context, multiAdditionalMessages []*entity.AdditionalMessage) ([]*entity.AdditionalMessage, []*message.InputMetaData, message.ContentType, error) {
|
||||
|
||||
var multiContent []*message.InputMetaData
|
||||
var contentType message.ContentType
|
||||
var filterMultiAdditionalMessages []*entity.AdditionalMessage
|
||||
filterMultiAdditionalMessages = multiAdditionalMessages
|
||||
|
||||
if len(multiAdditionalMessages) > 0 {
|
||||
lastMessage := multiAdditionalMessages[len(multiAdditionalMessages)-1]
|
||||
if lastMessage != nil && lastMessage.Role == schema.User {
|
||||
multiContent = lastMessage.Content
|
||||
contentType = lastMessage.ContentType
|
||||
filterMultiAdditionalMessages = multiAdditionalMessages[:len(multiAdditionalMessages)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return filterMultiAdditionalMessages, multiContent, contentType, nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) parseAdditionalMessages(ctx context.Context, ar *run.ChatV3Request) ([]*entity.AdditionalMessage, error) {
|
||||
|
||||
additionalMessages := make([]*entity.AdditionalMessage, 0, len(ar.AdditionalMessages))
|
||||
|
||||
for _, item := range ar.AdditionalMessages {
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
if item.Role != string(schema.User) {
|
||||
return nil, contentType, errors.New("role not match")
|
||||
if item.Role != string(schema.User) && item.Role != string(schema.Assistant) {
|
||||
return nil, errors.New("additional message role only support user and assistant")
|
||||
}
|
||||
if item.Type != nil && !slices.Contains([]message.MessageType{message.MessageTypeQuestion, message.MessageTypeAnswer}, message.MessageType(*item.Type)) {
|
||||
return nil, errors.New("additional message type only support question and answer now")
|
||||
}
|
||||
|
||||
addOne := entity.AdditionalMessage{
|
||||
Role: schema.RoleType(item.Role),
|
||||
}
|
||||
if item.Type != nil {
|
||||
addOne.Type = message.MessageType(*item.Type)
|
||||
} else {
|
||||
addOne.Type = message.MessageTypeQuestion
|
||||
}
|
||||
|
||||
if item.ContentType == run.ContentTypeText {
|
||||
if item.Content == "" {
|
||||
continue
|
||||
}
|
||||
multiContents = append(multiContents, &message.InputMetaData{
|
||||
|
||||
addOne.ContentType = message.ContentTypeText
|
||||
addOne.Content = []*message.InputMetaData{{
|
||||
Type: message.InputTypeText,
|
||||
Text: item.Content,
|
||||
})
|
||||
}}
|
||||
}
|
||||
|
||||
if item.ContentType == run.ContentTypeMixApi {
|
||||
contentType = message.ContentTypeMix
|
||||
|
||||
if ptr.From(item.Type) == string(message.MessageTypeAnswer) {
|
||||
return nil, errors.New(" answer messages only support text content")
|
||||
}
|
||||
|
||||
addOne.ContentType = message.ContentTypeMix
|
||||
var inputs []*run.AdditionalContent
|
||||
err := json.Unmarshal([]byte(item.Content), &inputs)
|
||||
|
||||
@ -236,7 +281,8 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *
|
||||
}
|
||||
switch message.InputType(one.Type) {
|
||||
case message.InputTypeText:
|
||||
multiContents = append(multiContents, &message.InputMetaData{
|
||||
|
||||
addOne.Content = append(addOne.Content, &message.InputMetaData{
|
||||
Type: message.InputTypeText,
|
||||
Text: ptr.From(one.Text),
|
||||
})
|
||||
@ -250,12 +296,12 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *
|
||||
ID: one.GetFileID(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
return nil, err
|
||||
}
|
||||
fileUrl = fileInfo.File.Url
|
||||
fileURI = fileInfo.File.TosURI
|
||||
}
|
||||
multiContents = append(multiContents, &message.InputMetaData{
|
||||
addOne.Content = append(addOne.Content, &message.InputMetaData{
|
||||
Type: message.InputType(one.Type),
|
||||
FileData: []*message.FileData{
|
||||
{
|
||||
@ -269,10 +315,10 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
additionalMessages = append(additionalMessages, &addOne)
|
||||
}
|
||||
|
||||
return multiContents, contentType, nil
|
||||
return additionalMessages, nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, streamer *schema.StreamReader[*entity.AgentRunResponse]) {
|
||||
|
||||
903
backend/application/conversation/openapi_agent_run_test.go
Normal file
@ -0,0 +1,903 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
saEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
openapiEntity "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
|
||||
cmdEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
uploadEntity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
|
||||
uploadService "github.com/coze-dev/coze-studio/backend/domain/upload/service"
|
||||
sseImpl "github.com/coze-dev/coze-studio/backend/infra/impl/sse"
|
||||
mockSingleAgent "github.com/coze-dev/coze-studio/backend/internal/mock/domain/agent/singleagent"
|
||||
mockAgentRun "github.com/coze-dev/coze-studio/backend/internal/mock/domain/conversation/agentrun"
|
||||
mockConversation "github.com/coze-dev/coze-studio/backend/internal/mock/domain/conversation/conversation"
|
||||
mockShortcut "github.com/coze-dev/coze-studio/backend/internal/mock/domain/shortcutcmd"
|
||||
mockUpload "github.com/coze-dev/coze-studio/backend/internal/mock/domain/upload"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func setupMocks(t *testing.T) (*OpenapiAgentRunApplication, *mockShortcut.MockShortcutCmd, *mockUpload.MockUploadService, *mockAgentRun.MockRun, *mockConversation.MockConversation, *mockSingleAgent.MockSingleAgent, *gomock.Controller) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
mockShortcutSvc := mockShortcut.NewMockShortcutCmd(ctrl)
|
||||
mockUploadSvc := mockUpload.NewMockUploadService(ctrl)
|
||||
mockAgentRunSvc := mockAgentRun.NewMockRun(ctrl)
|
||||
mockConversationSvc := mockConversation.NewMockConversation(ctrl)
|
||||
mockSingleAgentSvc := mockSingleAgent.NewMockSingleAgent(ctrl)
|
||||
|
||||
app := &OpenapiAgentRunApplication{
|
||||
ShortcutDomainSVC: mockShortcutSvc,
|
||||
UploaodDomainSVC: mockUploadSvc,
|
||||
}
|
||||
|
||||
// Setup ConversationSVC mocks
|
||||
originalConversationSVC := ConversationSVC
|
||||
ConversationSVC = &ConversationApplicationService{
|
||||
AgentRunDomainSVC: mockAgentRunSvc,
|
||||
ConversationDomainSVC: mockConversationSvc,
|
||||
appContext: &ServiceComponents{
|
||||
SingleAgentDomainSVC: mockSingleAgentSvc,
|
||||
},
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
ConversationSVC = originalConversationSVC
|
||||
ctrl.Finish()
|
||||
})
|
||||
|
||||
return app, mockShortcutSvc, mockUploadSvc, mockAgentRunSvc, mockConversationSvc, mockSingleAgentSvc, ctrl
|
||||
}
|
||||
|
||||
func createTestContext() context.Context {
|
||||
ctx := context.Background()
|
||||
ctx = ctxcache.Init(ctx)
|
||||
apiKey := &openapiEntity.ApiKey{
|
||||
UserID: 12345,
|
||||
ConnectorID: consts.CozeConnectorID,
|
||||
}
|
||||
ctxcache.Store(ctx, consts.OpenapiAuthKeyInCtx, apiKey)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func createTestRequest() *run.ChatV3Request {
|
||||
return &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello, world!",
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createTestRequestWithMultipleMessages() *run.ChatV3Request {
|
||||
return &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello, I need help with something.",
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Sure, I'd be happy to help! What do you need assistance with?",
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: `{"type": "image", "url": "https://example.com/image.jpg"}`,
|
||||
ContentType: run.ContentTypeImage,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: `{"type": "file", "name": "document.pdf", "url": "https://example.com/doc.pdf"}`,
|
||||
ContentType: run.ContentTypeFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createTestRequestWithAssistantOnly() *run.ChatV3Request {
|
||||
return &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'm here to help you with any questions you might have.",
|
||||
ContentType: run.ContentTypeText, // assistant role only supports text content type
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_Success(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
|
||||
// Mock agent check
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_CheckAgentError(t *testing.T) {
|
||||
app, _, _, _, _, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
|
||||
// Mock agent check failure
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(nil, errors.New("agent not found"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "agent not found")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_AgentNotExists(t *testing.T) {
|
||||
app, _, _, _, _, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
|
||||
// Mock agent check returns nil (agent not exists)
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(nil, nil)
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_CheckConversationError(t *testing.T) {
|
||||
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check failure
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(nil, errors.New("conversation not found"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "conversation not found")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ConversationPermissionError(t *testing.T) {
|
||||
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation with different creator
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 99999, // Different from user ID (12345)
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_CreateNewConversation(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
req.ConversationID = ptr.Of(int64(0)) // No conversation ID
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock create new conversation
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 22222,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().Create(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, meta *convEntity.CreateMeta) (*convEntity.Conversation, error) {
|
||||
assert.Equal(t, int64(67890), meta.AgentID)
|
||||
assert.Equal(t, int64(12345), meta.UserID)
|
||||
assert.Equal(t, common.Scene_SceneOpenApi, meta.Scene)
|
||||
return mockConv, nil
|
||||
})
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, int64(22222), *req.ConversationID) // Should be updated
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_AgentRunError(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("agent run failed"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "agent run failed")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_WithShortcutCommand(t *testing.T) {
|
||||
app, mockShortcut, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequest()
|
||||
req.ShortcutCommand = &run.ShortcutCommandDetail{
|
||||
CommandID: 123,
|
||||
Parameters: map[string]string{"param1": "value1"},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock shortcut command
|
||||
mockCmd := &cmdEntity.ShortcutCmd{
|
||||
ID: 123,
|
||||
PluginID: 456,
|
||||
PluginToolName: "test-tool",
|
||||
PluginToolID: 789,
|
||||
ToolType: 1,
|
||||
}
|
||||
mockShortcut.EXPECT().GetByCmdID(ctx, int64(123), int32(0)).Return(mockCmd, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_WithMultipleMessages(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequestWithMultipleMessages()
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
|
||||
// Verify that the request contains multiple messages with different roles and content types
|
||||
assert.Len(t, req.AdditionalMessages, 4)
|
||||
assert.Equal(t, "user", req.AdditionalMessages[0].Role)
|
||||
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[0].ContentType)
|
||||
assert.Equal(t, "assistant", req.AdditionalMessages[1].Role)
|
||||
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[1].ContentType)
|
||||
assert.Equal(t, "user", req.AdditionalMessages[2].Role)
|
||||
assert.Equal(t, run.ContentTypeImage, req.AdditionalMessages[2].ContentType)
|
||||
assert.Equal(t, "user", req.AdditionalMessages[3].Role)
|
||||
assert.Equal(t, run.ContentTypeFile, req.AdditionalMessages[3].ContentType)
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_WithAssistantMessage(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
req := createTestRequestWithAssistantOnly()
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
|
||||
// Verify that the assistant message only supports text content type
|
||||
assert.Len(t, req.AdditionalMessages, 1)
|
||||
assert.Equal(t, "assistant", req.AdditionalMessages[0].Role)
|
||||
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[0].ContentType)
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_WithMixedContentTypes(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with various content types for user role
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Here's a text message",
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: `{"type": "audio", "url": "https://example.com/audio.mp3"}`,
|
||||
ContentType: run.ContentTypeAudio,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: `{"type": "video", "url": "https://example.com/video.mp4"}`,
|
||||
ContentType: run.ContentTypeVideo,
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I can only respond with text content.",
|
||||
ContentType: run.ContentTypeText, // assistant must use text
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: `{"type": "link", "url": "https://example.com"}`,
|
||||
ContentType: run.ContentTypeLink,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
|
||||
// Verify various content types are preserved
|
||||
assert.Len(t, req.AdditionalMessages, 5)
|
||||
|
||||
// Check user messages with different content types
|
||||
assert.Equal(t, "user", req.AdditionalMessages[0].Role)
|
||||
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[0].ContentType)
|
||||
|
||||
assert.Equal(t, "user", req.AdditionalMessages[1].Role)
|
||||
assert.Equal(t, run.ContentTypeAudio, req.AdditionalMessages[1].ContentType)
|
||||
|
||||
assert.Equal(t, "user", req.AdditionalMessages[2].Role)
|
||||
assert.Equal(t, run.ContentTypeVideo, req.AdditionalMessages[2].ContentType)
|
||||
|
||||
// Check assistant message (must be text)
|
||||
assert.Equal(t, "assistant", req.AdditionalMessages[3].Role)
|
||||
assert.Equal(t, run.ContentTypeText, req.AdditionalMessages[3].ContentType)
|
||||
|
||||
assert.Equal(t, "user", req.AdditionalMessages[4].Role)
|
||||
assert.Equal(t, run.ContentTypeLink, req.AdditionalMessages[4].ContentType)
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_InvalidRole(t *testing.T) {
|
||||
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with invalid role
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "system", // Invalid role
|
||||
Content: "System message",
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success to reach parseAdditionalMessages
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "additional message role only support user and assistant")
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_InvalidType(t *testing.T) {
|
||||
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with invalid message type
|
||||
invalidType := "invalid_type"
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Test message",
|
||||
ContentType: run.ContentTypeText,
|
||||
Type: &invalidType, // Invalid type
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success to reach parseAdditionalMessages
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "additional message type only support question and answer now")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_AnswerWithNonTextContent(t *testing.T) {
|
||||
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with answer type but non-text content
|
||||
answerType := "answer"
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: `[{"type": "image", "file_url": "https://example.com/image.jpg"}]`,
|
||||
ContentType: run.ContentTypeMixApi, // object_string
|
||||
Type: &answerType,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success to reach parseAdditionalMessages
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "answer messages only support text content")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_MixApiWithFileURL(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with object_string content type and file URL
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: `[{"type": "text", "text": "Here's an image:"}, {"type": "image", "file_url": "https://example.com/image.jpg"}]`,
|
||||
ContentType: run.ContentTypeMixApi,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_MixApiWithFileID(t *testing.T) {
|
||||
app, _, mockUpload, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with object_string content type and file ID
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: `[{"type": "file", "file_id": "12345"}]`,
|
||||
ContentType: run.ContentTypeMixApi,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock upload service to return file info
|
||||
mockUpload.EXPECT().GetFile(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, req *uploadService.GetFileRequest) (*uploadService.GetFileResponse, error) {
|
||||
assert.Equal(t, int64(12345), req.ID)
|
||||
return &uploadService.GetFileResponse{
|
||||
File: &uploadEntity.File{
|
||||
Url: "https://example.com/file.pdf",
|
||||
TosURI: "tos://bucket/file.pdf",
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_FileIDError(t *testing.T) {
|
||||
app, _, mockUpload, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with object_string content type and file ID that will fail
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: `[{"type": "file", "file_id": "99999"}]`,
|
||||
ContentType: run.ContentTypeMixApi,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock upload service to return error
|
||||
mockUpload.EXPECT().GetFile(ctx, gomock.Any()).Return(nil, errors.New("file not found"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file not found")
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_EmptyContent(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with empty text content (should be skipped)
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "", // Empty content
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Valid content",
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
|
||||
// Verify that only the non-empty message is included
|
||||
assert.Len(t, req.AdditionalMessages, 2) // Original request still has 2 messages
|
||||
}
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_NilMessage(t *testing.T) {
|
||||
app, _, _, mockAgentRun, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
// Create request with empty content message (should be skipped)
|
||||
req := &run.ChatV3Request{
|
||||
BotID: 67890,
|
||||
ConversationID: ptr.Of(int64(11111)),
|
||||
User: "test-user",
|
||||
AdditionalMessages: []*run.EnterMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "", // Empty content message
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Valid content",
|
||||
ContentType: run.ContentTypeText,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock agent check success
|
||||
mockAgent := &saEntity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
AgentID: 67890,
|
||||
SpaceID: 54321,
|
||||
},
|
||||
}
|
||||
mockSingleAgent.EXPECT().ObtainAgentByIdentity(ctx, gomock.Any()).Return(mockAgent, nil)
|
||||
|
||||
// Mock conversation check success
|
||||
mockConv := &convEntity.Conversation{
|
||||
ID: 11111,
|
||||
CreatorID: 12345,
|
||||
SectionID: 98765,
|
||||
}
|
||||
mockConversation.EXPECT().GetByID(ctx, int64(11111)).Return(mockConv, nil)
|
||||
|
||||
// Mock agent run failure to avoid pullStream complexity
|
||||
mockAgentRun.EXPECT().AgentRun(ctx, gomock.Any()).Return(nil, errors.New("mock stream error"))
|
||||
|
||||
err := app.OpenapiAgentRun(ctx, &sseImpl.SSenderImpl{}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mock stream error")
|
||||
}
|
||||
@ -209,7 +209,7 @@ type publishFn func(ctx context.Context, appContext *ServiceComponents, publishI
|
||||
|
||||
func publishAgentVariables(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
draftAgent := agent
|
||||
if draftAgent.VariablesMetaID == nil || *draftAgent.VariablesMetaID == 0 {
|
||||
if draftAgent.VariablesMetaID != nil || *draftAgent.VariablesMetaID == 0 {
|
||||
return draftAgent, nil
|
||||
}
|
||||
|
||||
|
||||
@ -221,7 +221,7 @@ const (
|
||||
"id": "5fJt3qKpSz",
|
||||
"name": "list",
|
||||
"defaultValue": [
|
||||
|
||||
|
||||
]
|
||||
}
|
||||
},
|
||||
@ -504,7 +504,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
workflowID = mustParseInt64(req.GetWorkflowID())
|
||||
isDebug = req.GetExecuteMode() == "DEBUG"
|
||||
appID, agentID *int64
|
||||
bizID int64
|
||||
resolveAppID int64
|
||||
conversationID int64
|
||||
sectionID int64
|
||||
version string
|
||||
@ -521,11 +521,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
|
||||
if req.IsSetAppID() {
|
||||
appID = ptr.Of(mustParseInt64(req.GetAppID()))
|
||||
bizID = mustParseInt64(req.GetAppID())
|
||||
resolveAppID = mustParseInt64(req.GetAppID())
|
||||
}
|
||||
if req.IsSetBotID() {
|
||||
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
|
||||
bizID = mustParseInt64(req.GetBotID())
|
||||
resolveAppID = mustParseInt64(req.GetBotID())
|
||||
}
|
||||
|
||||
if appID != nil && agentID != nil {
|
||||
@ -564,16 +564,16 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
sectionID = cInfo.SectionID
|
||||
|
||||
// only trust the conversation name under the app
|
||||
conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, conversationID)
|
||||
conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !existed {
|
||||
return nil, fmt.Errorf("conversation not found")
|
||||
}
|
||||
parameters[vo.ConversationNameKey] = conversationName
|
||||
parameters["CONVERSATION_NAME"] = conversationName
|
||||
} else if req.IsSetConversationID() && req.IsSetBotID() {
|
||||
parameters[vo.ConversationNameKey] = "Default"
|
||||
parameters["CONVERSATION_NAME"] = "Default"
|
||||
conversationID = mustParseInt64(req.GetConversationID())
|
||||
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, conversationID)
|
||||
if err != nil {
|
||||
@ -581,11 +581,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
}
|
||||
sectionID = cInfo.SectionID
|
||||
} else {
|
||||
conversationName, ok := parameters[vo.ConversationNameKey].(string)
|
||||
conversationName, ok := parameters["CONVERSATION_NAME"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("conversation name is requried")
|
||||
}
|
||||
cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, userID, conversationName)
|
||||
cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -594,7 +594,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
}
|
||||
|
||||
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: bizID,
|
||||
AgentID: resolveAppID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -606,7 +606,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
|
||||
roundID := runRecord.ID
|
||||
|
||||
userMessage, err := toConversationMessage(ctx, bizID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
|
||||
userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -648,7 +648,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
return nil, err
|
||||
}
|
||||
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
|
||||
bizID: bizID,
|
||||
appID: resolveAppID,
|
||||
conversationID: conversationID,
|
||||
roundID: roundID,
|
||||
workflowID: workflowID,
|
||||
@ -684,7 +684,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
Cancellable: isDebug,
|
||||
}
|
||||
|
||||
historyMessages, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
|
||||
historyMessages, err := makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -706,7 +706,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
logs.CtxWarnf(ctx, "create history message failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
parameters[vo.UserInputKey], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
|
||||
parameters["USER_INPUT"], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -717,7 +717,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
}
|
||||
|
||||
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
|
||||
bizID: bizID,
|
||||
appID: resolveAppID,
|
||||
conversationID: conversationID,
|
||||
roundID: roundID,
|
||||
workflowID: workflowID,
|
||||
@ -731,7 +731,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
|
||||
func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Context, info convertToChatFlowInfo) func(msg *entity.Message) (responses []*workflow.ChatFlowRunResponse, err error) {
|
||||
var (
|
||||
bizID = info.bizID
|
||||
appID = info.appID
|
||||
conversationID = info.conversationID
|
||||
roundID = info.roundID
|
||||
workflowID = info.workflowID
|
||||
@ -798,7 +798,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
Role: string(schema.Assistant),
|
||||
Type: "follow_up",
|
||||
ContentType: "text",
|
||||
@ -815,7 +815,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
Status: vo.Completed,
|
||||
ExecuteID: strconv.FormatInt(executeID, 10),
|
||||
Usage: &vo.Usage{
|
||||
@ -929,7 +929,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
}
|
||||
|
||||
_, err = crossmessage.DefaultSVC().Create(ctx, &message.Message{
|
||||
AgentID: bizID,
|
||||
AgentID: appID,
|
||||
RunID: roundID,
|
||||
SectionID: sectionID,
|
||||
Content: msgContent,
|
||||
@ -947,7 +947,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
Role: string(schema.Assistant),
|
||||
Type: string(entity.Answer),
|
||||
ContentType: string(contentType),
|
||||
@ -1046,7 +1046,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
}
|
||||
intermediateMessage = &message.Message{
|
||||
ID: id,
|
||||
AgentID: bizID,
|
||||
AgentID: appID,
|
||||
RunID: roundID,
|
||||
SectionID: sectionID,
|
||||
ConversationID: conversationID,
|
||||
@ -1066,7 +1066,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
Role: string(dataMessage.Role),
|
||||
Type: string(dataMessage.Type),
|
||||
ContentType: string(message.ContentTypeText),
|
||||
@ -1092,7 +1092,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
|
||||
ChatID: strconv.FormatInt(roundID, 10),
|
||||
ConversationID: strconv.FormatInt(conversationID, 10),
|
||||
SectionID: strconv.FormatInt(sectionID, 10),
|
||||
BotID: strconv.FormatInt(bizID, 10),
|
||||
BotID: strconv.FormatInt(appID, 10),
|
||||
Role: string(dataMessage.Role),
|
||||
Type: string(dataMessage.Type),
|
||||
ContentType: string(message.ContentTypeText),
|
||||
@ -1155,9 +1155,9 @@ func (w *ApplicationService) makeChatFlowUserInput(ctx context.Context, message
|
||||
} else {
|
||||
return "", fmt.Errorf("invalid message ccontent type %v", message.ContentType)
|
||||
}
|
||||
}
|
||||
|
||||
func makeChatFlowHistoryMessages(ctx context.Context, bizID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
|
||||
}
|
||||
func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
|
||||
|
||||
var (
|
||||
rID int64
|
||||
@ -1170,7 +1170,7 @@ func makeChatFlowHistoryMessages(ctx context.Context, bizID, conversationID, use
|
||||
for _, msg := range messages {
|
||||
if msg.Role == userRole {
|
||||
runRecord, err = crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: bizID,
|
||||
AgentID: appID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -1180,15 +1180,13 @@ func makeChatFlowHistoryMessages(ctx context.Context, bizID, conversationID, use
|
||||
return nil, err
|
||||
}
|
||||
rID = runRecord.ID
|
||||
} else if msg.Role == assistantRole {
|
||||
if rID == 0 {
|
||||
continue
|
||||
}
|
||||
} else if msg.Role == assistantRole && rID == 0 {
|
||||
continue
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid role type %v", msg.Role)
|
||||
}
|
||||
|
||||
m, err := toConversationMessage(ctx, bizID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
|
||||
m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1276,7 +1274,7 @@ func (w *ApplicationService) OpenAPICreateConversation(ctx context.Context, req
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toConversationMessage(ctx context.Context, bizID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
|
||||
func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
|
||||
type content struct {
|
||||
Type string `json:"type"`
|
||||
FileID *string `json:"file_id"`
|
||||
@ -1286,7 +1284,7 @@ func toConversationMessage(ctx context.Context, bizID, cid, userID, roundID, sec
|
||||
return &message.Message{
|
||||
Role: schema.User,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
AgentID: appID,
|
||||
RunID: roundID,
|
||||
Content: msg.Content,
|
||||
ContentType: message.ContentTypeText,
|
||||
@ -1306,7 +1304,7 @@ func toConversationMessage(ctx context.Context, bizID, cid, userID, roundID, sec
|
||||
Role: schema.User,
|
||||
MessageType: messageType,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
AgentID: appID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
RunID: roundID,
|
||||
ContentType: message.ContentTypeMix,
|
||||
@ -1434,7 +1432,7 @@ func toSchemaMessage(ctx context.Context, msg *workflow.EnterMessage) (*schema.M
|
||||
|
||||
type convertToChatFlowInfo struct {
|
||||
userMessage *schema.Message
|
||||
bizID int64
|
||||
appID int64
|
||||
conversationID int64
|
||||
roundID int64
|
||||
workflowID int64
|
||||
|
||||
@ -1,604 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
messageentity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun/agentrunmock"
|
||||
crossupload "github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload/uploadmock"
|
||||
agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
uploadentity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/upload/service"
|
||||
)
|
||||
|
||||
func TestApplicationService_makeChatFlowUserInput(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message *workflow.EnterMessage
|
||||
setupMock func()
|
||||
expected string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "content type text",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "text",
|
||||
Content: "hello",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: "hello",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "hello world"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: "hello world",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with file",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/file"},
|
||||
}, nil)
|
||||
},
|
||||
expected: "https://example.com/file",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text and file",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "see this file"}, {"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/file"},
|
||||
}, nil)
|
||||
},
|
||||
expected: "see this file,https://example.com/file",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "get file error",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: nil,
|
||||
}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid content type",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "invalid",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
message: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `invalid-json`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
w := &ApplicationService{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := w.makeChatFlowUserInput(ctx, tt.message)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toConversationMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
bizID, cid, userID, roundID, sectionID := int64(2), int64(1), int64(4), int64(3), int64(5)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *workflow.EnterMessage
|
||||
messageType messageentity.MessageType
|
||||
setupMock func()
|
||||
expected *messageentity.Message
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "content type text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "text",
|
||||
Content: "hello",
|
||||
},
|
||||
messageType: messageentity.MessageTypeQuestion,
|
||||
setupMock: func() {},
|
||||
expected: &messageentity.Message{
|
||||
Role: schema.User,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
RunID: roundID,
|
||||
Content: "hello",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "hello"}]`,
|
||||
},
|
||||
messageType: messageentity.MessageTypeQuestion,
|
||||
setupMock: func() {},
|
||||
expected: &messageentity.Message{
|
||||
Role: schema.User,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
RunID: roundID,
|
||||
ContentType: messageentity.ContentTypeMix,
|
||||
MultiContent: []*messageentity.InputMetaData{
|
||||
{Type: messageentity.InputTypeText, Text: "hello"},
|
||||
},
|
||||
SectionID: sectionID,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with file",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
messageType: messageentity.MessageTypeQuestion,
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/file", TosURI: "tos://uri", Name: "file.txt"},
|
||||
}, nil)
|
||||
},
|
||||
expected: &messageentity.Message{
|
||||
Role: schema.User,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
ConversationID: cid,
|
||||
AgentID: bizID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
RunID: roundID,
|
||||
ContentType: messageentity.ContentTypeMix,
|
||||
MultiContent: []*messageentity.InputMetaData{
|
||||
{
|
||||
Type: "file",
|
||||
FileData: []*messageentity.FileData{
|
||||
{Url: "https://example.com/file", URI: "tos://uri", Name: "file.txt"},
|
||||
},
|
||||
},
|
||||
},
|
||||
SectionID: sectionID,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "get file error",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid content type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "invalid",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: "invalid-json",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid input type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "invalid"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := toConversationMessage(ctx, bizID, cid, userID, roundID, sectionID, tt.messageType, tt.msg)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toSchemaMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *workflow.EnterMessage
|
||||
setupMock func()
|
||||
expected *schema.Message
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "content type text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "text",
|
||||
Content: "hello",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
Content: "hello",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with text",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "text", "text": "hello"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{Type: schema.ChatMessagePartTypeText, Text: "hello"},
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with image",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "image", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
|
||||
File: &uploadentity.File{Url: "https://example.com/image.png"},
|
||||
}, nil)
|
||||
},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{URL: "https://example.com/image.png"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "content type object_string with various file types",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "1"}, {"type": "audio", "file_id": "2"}, {"type": "video", "file_id": "3"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 1}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/file"}}, nil)
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 2}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/audio"}}, nil)
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 3}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/video"}}, nil)
|
||||
},
|
||||
expected: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{URL: "https://example.com/file"}},
|
||||
{Type: schema.ChatMessagePartTypeAudioURL, AudioURL: &schema.ChatMessageAudioURL{URL: "https://example.com/audio"}},
|
||||
{Type: schema.ChatMessagePartTypeVideoURL, VideoURL: &schema.ChatMessageVideoURL{URL: "https://example.com/video"}},
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "get file error",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "file", "file_id": "123"}]`,
|
||||
},
|
||||
setupMock: func() {
|
||||
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid content type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "invalid",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: "invalid-json",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid input type",
|
||||
msg: &workflow.EnterMessage{
|
||||
ContentType: "object_string",
|
||||
Content: `[{"type": "invalid"}]`,
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := toSchemaMessage(ctx, tt.msg)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_makeChatFlowHistoryMessages(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockAgentRun := agentrunmock.NewMockAgentRun(ctrl)
|
||||
crossagentrun.SetDefaultSVC(mockAgentRun)
|
||||
mockUpload := uploadmock.NewMockUploader(ctrl)
|
||||
crossupload.SetDefaultSVC(mockUpload)
|
||||
|
||||
bizID, conversationID, userID, sectionID, connectorID := int64(2), int64(1), int64(3), int64(4), int64(5)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []*workflow.EnterMessage
|
||||
setupMock func()
|
||||
expected []*messageentity.Message
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty messages",
|
||||
messages: []*workflow.EnterMessage{},
|
||||
setupMock: func() {},
|
||||
expected: []*messageentity.Message{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "one user message",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "text", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
|
||||
},
|
||||
expected: []*messageentity.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
ConversationID: conversationID,
|
||||
AgentID: bizID,
|
||||
RunID: 100,
|
||||
Content: "hello",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "user and assistant message",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "text", Content: "hello"},
|
||||
{Role: "assistant", ContentType: "text", Content: "hi"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
|
||||
},
|
||||
expected: []*messageentity.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
ConversationID: conversationID,
|
||||
AgentID: bizID,
|
||||
RunID: 100,
|
||||
Content: "hello",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeQuestion,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
{
|
||||
Role: schema.User,
|
||||
ConversationID: conversationID,
|
||||
AgentID: bizID,
|
||||
RunID: 100,
|
||||
Content: "hi",
|
||||
ContentType: messageentity.ContentTypeText,
|
||||
MessageType: messageentity.MessageTypeAnswer,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "only assistant message",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "assistant", ContentType: "text", Content: "hi"},
|
||||
},
|
||||
setupMock: func() {},
|
||||
expected: []*messageentity.Message{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "create run record error",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "text", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid role",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "system", ContentType: "text", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "toConversationMessage error",
|
||||
messages: []*workflow.EnterMessage{
|
||||
{Role: "user", ContentType: "invalid", Content: "hello"},
|
||||
},
|
||||
setupMock: func() {
|
||||
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMock()
|
||||
result, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, tt.messages)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -108,10 +108,5 @@ func InitService(_ context.Context, components *ServiceComponents) (*Application
|
||||
SVC.TosClient = components.Tos
|
||||
SVC.IDGenerator = components.IDGen
|
||||
|
||||
err = SVC.InitNodeIconURLCache(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return SVC, nil
|
||||
}
|
||||
|
||||
@ -23,12 +23,10 @@ import (
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
@ -77,56 +75,12 @@ type ApplicationService struct {
|
||||
IDGenerator idgen.IDGenerator
|
||||
}
|
||||
|
||||
var (
|
||||
SVC = &ApplicationService{}
|
||||
nodeIconURLCache = make(map[string]string)
|
||||
nodeIconURLCacheMu sync.Mutex
|
||||
)
|
||||
var SVC = &ApplicationService{}
|
||||
|
||||
func GetWorkflowDomainSVC() domainWorkflow.Service {
|
||||
return SVC.DomainSVC
|
||||
}
|
||||
|
||||
func (w *ApplicationService) InitNodeIconURLCache(ctx context.Context) error {
|
||||
category2NodeMetaList, _, err := GetWorkflowDomainSVC().ListNodeMeta(ctx, nil)
|
||||
if err != nil {
|
||||
logs.Errorf("failed to list node meta for icon url cache: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
eg, gCtx := errgroup.WithContext(ctx)
|
||||
for _, nodeMetaList := range category2NodeMetaList {
|
||||
for _, nodeMeta := range nodeMetaList {
|
||||
eg.Go(func() error {
|
||||
if len(nodeMeta.IconURI) == 0 {
|
||||
// For custom nodes, if IconURI is not set, there will be no icon.
|
||||
logs.Warnf("node '%s' has an empty IconURI, it will have no icon", nodeMeta.Name)
|
||||
return nil
|
||||
}
|
||||
url, err := w.TosClient.GetObjectUrl(gCtx, nodeMeta.IconURI)
|
||||
if err != nil {
|
||||
logs.Warnf("failed to get object url for node %s: %v", nodeMeta.Name, err)
|
||||
return err
|
||||
}
|
||||
nodeTypeStr := entity.IDStrToNodeType(strconv.FormatInt(nodeMeta.ID, 10))
|
||||
if len(nodeTypeStr) > 0 {
|
||||
nodeIconURLCacheMu.Lock()
|
||||
nodeIconURLCache[string(nodeTypeStr)] = url
|
||||
nodeIconURLCacheMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logs.Infof("node icon url cache initialized with %d entries", len(nodeIconURLCache))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workflow.NodeTemplateListRequest) (
|
||||
_ *workflow.NodeTemplateListResponse, err error,
|
||||
) {
|
||||
@ -165,22 +119,19 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf
|
||||
Name: category,
|
||||
}
|
||||
for _, nodeMeta := range nodeMetaList {
|
||||
nodeID := fmt.Sprintf("%d", nodeMeta.ID)
|
||||
nodeType := entity.IDStrToNodeType(nodeID)
|
||||
url := nodeIconURLCache[string(nodeType)]
|
||||
tpl := &workflow.NodeTemplate{
|
||||
ID: nodeID,
|
||||
ID: fmt.Sprintf("%d", nodeMeta.ID),
|
||||
Type: workflow.NodeTemplateType(nodeMeta.ID),
|
||||
Name: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSName, nodeMeta.Name),
|
||||
Desc: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSDescription, nodeMeta.Desc),
|
||||
IconURL: url,
|
||||
IconURL: nodeMeta.IconURL,
|
||||
SupportBatch: ternary.IFElse(nodeMeta.SupportBatch, workflow.SupportBatch_SUPPORT, workflow.SupportBatch_NOT_SUPPORT),
|
||||
NodeType: nodeID,
|
||||
NodeType: fmt.Sprintf("%d", nodeMeta.ID),
|
||||
Color: nodeMeta.Color,
|
||||
}
|
||||
|
||||
resp.Data.TemplateList = append(resp.Data.TemplateList, tpl)
|
||||
categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, nodeID)
|
||||
categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", nodeMeta.ID))
|
||||
}
|
||||
}
|
||||
|
||||
@ -799,13 +750,11 @@ func (w *ApplicationService) GetProcess(ctx context.Context, req *workflow.GetWo
|
||||
}
|
||||
}
|
||||
|
||||
iconURL := nodeIconURLCache[string(ie.NodeType)]
|
||||
|
||||
resp.Data.NodeEvents = append(resp.Data.NodeEvents, &workflow.NodeEvent{
|
||||
ID: strconv.FormatInt(ie.ID, 10),
|
||||
NodeID: string(ie.NodeKey),
|
||||
NodeTitle: ie.NodeTitle,
|
||||
NodeIcon: iconURL,
|
||||
NodeIcon: ie.NodeIcon,
|
||||
Data: ie.InterruptData,
|
||||
Type: ie.EventType,
|
||||
SchemaNodeID: string(ie.NodeKey),
|
||||
|
||||
@ -30,6 +30,7 @@ type Message interface {
|
||||
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error)
|
||||
PreCreate(ctx context.Context, msg *message.Message) (*message.Message, error)
|
||||
Create(ctx context.Context, msg *message.Message) (*message.Message, error)
|
||||
BatchCreate(ctx context.Context, msg []*message.Message) ([]*message.Message, error)
|
||||
List(ctx context.Context, meta *entity.ListMeta) (*entity.ListResult, error)
|
||||
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
|
||||
Edit(ctx context.Context, msg *message.Message) (*message.Message, error)
|
||||
@ -58,7 +59,7 @@ type MessageListRequest struct {
|
||||
BeforeID *string
|
||||
AfterID *string
|
||||
UserID int64
|
||||
BizID int64
|
||||
AppID int64
|
||||
OrderBy *string
|
||||
}
|
||||
|
||||
@ -88,7 +89,7 @@ type WfMessage struct {
|
||||
type GetLatestRunIDsRequest struct {
|
||||
ConversationID int64
|
||||
UserID int64
|
||||
BizID int64
|
||||
AppID int64
|
||||
Rounds int64
|
||||
SectionID int64
|
||||
InitRunID *int64
|
||||
|
||||
@ -59,6 +59,21 @@ func (m *MockMessage) EXPECT() *MockMessageMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BatchCreate mocks base method.
|
||||
func (m *MockMessage) BatchCreate(ctx context.Context, msg []*message.Message) ([]*message.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BatchCreate", ctx, msg)
|
||||
ret0, _ := ret[0].([]*message.Message)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BatchCreate indicates an expected call of BatchCreate.
|
||||
func (mr *MockMessageMockRecorder) BatchCreate(ctx, msg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreate", reflect.TypeOf((*MockMessage)(nil).BatchCreate), ctx, msg)
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockMessage) Create(ctx context.Context, msg *message.Message) (*message.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -24,7 +24,6 @@ import (
|
||||
|
||||
var defaultSVC Uploader
|
||||
|
||||
//go:generate mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
|
||||
type Uploader interface {
|
||||
GetFile(ctx context.Context, req *service.GetFileRequest) (resp *service.GetFileResponse, err error)
|
||||
}
|
||||
|
||||
@ -1,73 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: upload.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
|
||||
//
|
||||
|
||||
// Package uploadmock is a generated GoMock package.
|
||||
package uploadmock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
service "github.com/coze-dev/coze-studio/backend/domain/upload/service"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockUploader is a mock of Uploader interface.
|
||||
type MockUploader struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUploaderMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockUploaderMockRecorder is the mock recorder for MockUploader.
|
||||
type MockUploaderMockRecorder struct {
|
||||
mock *MockUploader
|
||||
}
|
||||
|
||||
// NewMockUploader creates a new mock instance.
|
||||
func NewMockUploader(ctrl *gomock.Controller) *MockUploader {
|
||||
mock := &MockUploader{ctrl: ctrl}
|
||||
mock.recorder = &MockUploaderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockUploader) EXPECT() *MockUploaderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetFile mocks base method.
|
||||
func (m *MockUploader) GetFile(ctx context.Context, req *service.GetFileRequest) (*service.GetFileResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFile", ctx, req)
|
||||
ret0, _ := ret[0].(*service.GetFileResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetFile indicates an expected call of GetFile.
|
||||
func (mr *MockUploaderMockRecorder) GetFile(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFile", reflect.TypeOf((*MockUploader)(nil).GetFile), ctx, req)
|
||||
}
|
||||
@ -54,7 +54,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
|
||||
ConversationID: req.ConversationID,
|
||||
Limit: int(req.Limit), // Since the value of limit is checked inside the node, the type cast here is safe
|
||||
UserID: strconv.FormatInt(req.UserID, 10),
|
||||
AgentID: req.BizID,
|
||||
AgentID: req.AppID,
|
||||
OrderBy: req.OrderBy,
|
||||
}
|
||||
if req.BeforeID != nil {
|
||||
@ -96,7 +96,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
|
||||
func (c *impl) GetLatestRunIDs(ctx context.Context, req *crossmessage.GetLatestRunIDsRequest) ([]int64, error) {
|
||||
listMeta := &agententity.ListRunRecordMeta{
|
||||
ConversationID: req.ConversationID,
|
||||
AgentID: req.BizID,
|
||||
AgentID: req.AppID,
|
||||
Limit: int32(req.Rounds),
|
||||
SectionID: req.SectionID,
|
||||
}
|
||||
@ -170,6 +170,9 @@ func (c *impl) GetMessageByID(ctx context.Context, id int64) (*entity.Message, e
|
||||
func (c *impl) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
|
||||
return c.DomainSVC.ListWithoutPair(ctx, req)
|
||||
}
|
||||
func (c *impl) BatchCreate(ctx context.Context, msgs []*entity.Message) ([]*entity.Message, error) {
|
||||
return c.DomainSVC.BatchCreate(ctx, msgs)
|
||||
}
|
||||
|
||||
func convertToConvAndSchemaMessage(ctx context.Context, msgs []*entity.Message) ([]*crossmessage.WfMessage, []*schema.Message, error) {
|
||||
messages := make([]*schema.Message, 0)
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../../internal/mock/domain/agent/singleagent/single_agent_mock.go --package singleagent -source single_agent.go
|
||||
type SingleAgent interface {
|
||||
// draft agent
|
||||
CreateSingleAgentDraft(ctx context.Context, creatorID int64, draft *entity.SingleAgent) (agentID int64, err error)
|
||||
|
||||
@ -106,24 +106,34 @@ type MetaInfo struct {
|
||||
}
|
||||
|
||||
type AgentRunMeta struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
SpaceID int64 `json:"space_id"`
|
||||
Scene common.Scene `json:"scene"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
Name string `json:"name"`
|
||||
UserID string `json:"user_id"`
|
||||
CozeUID int64 `json:"coze_uid"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
ContentType message.ContentType `json:"content_type"`
|
||||
Content []*message.InputMetaData `json:"content"`
|
||||
PreRetrieveTools []*Tool `json:"tools"`
|
||||
IsDraft bool `json:"is_draft"`
|
||||
CustomerConfig *CustomerConfig `json:"customer_config"`
|
||||
DisplayContent string `json:"display_content"`
|
||||
CustomVariables map[string]string `json:"custom_variables"`
|
||||
Version string `json:"version"`
|
||||
Ext map[string]string `json:"ext"`
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
SpaceID int64 `json:"space_id"`
|
||||
Scene common.Scene `json:"scene"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
Name string `json:"name"`
|
||||
UserID string `json:"user_id"`
|
||||
CozeUID int64 `json:"coze_uid"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
ContentType message.ContentType `json:"content_type"`
|
||||
Content []*message.InputMetaData `json:"content"`
|
||||
PreRetrieveTools []*Tool `json:"tools"`
|
||||
IsDraft bool `json:"is_draft"`
|
||||
CustomerConfig *CustomerConfig `json:"customer_config"`
|
||||
DisplayContent string `json:"display_content"`
|
||||
CustomVariables map[string]string `json:"custom_variables"`
|
||||
Version string `json:"version"`
|
||||
Ext map[string]string `json:"ext"`
|
||||
AdditionalMessages []*AdditionalMessage `json:"additional_messages"`
|
||||
}
|
||||
|
||||
type AdditionalMessage struct {
|
||||
Role schema.RoleType `json:"role"`
|
||||
Type message.MessageType `json:"type"`
|
||||
Content []*message.InputMetaData `json:"content"`
|
||||
ContentType message.ContentType `json:"content_type"`
|
||||
Name *string `json:"name"`
|
||||
Meta map[string]string `json:"meta"`
|
||||
}
|
||||
|
||||
type UpdateMeta struct {
|
||||
|
||||
@ -41,7 +41,7 @@ import (
|
||||
|
||||
func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX) (err error) {
|
||||
|
||||
mh := &MesssageEventHanlder{
|
||||
mh := &MessageEventHandler{
|
||||
sw: art.SW,
|
||||
messageEvent: art.MessageEvent,
|
||||
}
|
||||
@ -110,7 +110,7 @@ func concatWfInput(rtDependence *AgentRuntime) string {
|
||||
return strings.Trim(input, ",")
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MesssageEventHanlder) {
|
||||
func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MessageEventHandler) {
|
||||
|
||||
fullAnswerContent := bytes.NewBuffer([]byte{})
|
||||
var usage *msgEntity.UsageExt
|
||||
|
||||
@ -221,6 +221,51 @@ func preCreateAnswer(ctx context.Context, rtDependence *AgentRuntime) (*msgEntit
|
||||
return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
|
||||
}
|
||||
|
||||
func buildAdditionalMessage2Create(ctx context.Context, runRecord *entity.RunRecordMeta, additionalMessage *entity.AdditionalMessage, userID string) *message.Message {
|
||||
|
||||
msg := &msgEntity.Message{
|
||||
ConversationID: runRecord.ConversationID,
|
||||
RunID: runRecord.ID,
|
||||
AgentID: runRecord.AgentID,
|
||||
SectionID: runRecord.SectionID,
|
||||
UserID: userID,
|
||||
MessageType: additionalMessage.Type,
|
||||
}
|
||||
|
||||
switch additionalMessage.Type {
|
||||
case message.MessageTypeQuestion:
|
||||
msg.Role = schema.User
|
||||
msg.ContentType = additionalMessage.ContentType
|
||||
for _, content := range additionalMessage.Content {
|
||||
if content.Type == message.InputTypeText {
|
||||
msg.Content = content.Text
|
||||
break
|
||||
}
|
||||
}
|
||||
msg.MultiContent = additionalMessage.Content
|
||||
|
||||
case message.MessageTypeAnswer:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
for _, content := range additionalMessage.Content {
|
||||
if content.Type == message.InputTypeText {
|
||||
msg.Content = content.Text
|
||||
break
|
||||
}
|
||||
}
|
||||
modelContent := &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: msg.Content,
|
||||
}
|
||||
|
||||
jsonContent, err := json.Marshal(modelContent)
|
||||
if err == nil {
|
||||
msg.ModelContent = string(jsonContent)
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *AgentRuntime) *message.Message {
|
||||
arm := rtDependence.GetRunMeta()
|
||||
msg := &msgEntity.Message{
|
||||
|
||||
@ -98,12 +98,12 @@ func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResp
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
type MesssageEventHanlder struct {
|
||||
type MessageEventHandler struct {
|
||||
messageEvent *Event
|
||||
sw *schema.StreamWriter[*entity.AgentRunResponse]
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) {
|
||||
func (mh *MessageEventHandler) handlerErr(_ context.Context, err error) {
|
||||
|
||||
var errMsg string
|
||||
var statusErr errorx.StatusError
|
||||
@ -123,7 +123,7 @@ func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) {
|
||||
})
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgEntity.Message) error {
|
||||
func (mh *MessageEventHandler) handlerAckMessage(_ context.Context, input *msgEntity.Message) error {
|
||||
sendMsg := &entity.ChunkMessageItem{
|
||||
ID: input.ID,
|
||||
ConversationID: input.ConversationID,
|
||||
@ -142,7 +142,7 @@ func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgE
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
func (mh *MessageEventHandler) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence)
|
||||
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
@ -156,7 +156,7 @@ func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
|
||||
func (mh *MessageEventHandler) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
|
||||
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
|
||||
|
||||
@ -184,7 +184,7 @@ func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
func (mh *MessageEventHandler) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence)
|
||||
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
@ -199,7 +199,7 @@ func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entit
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
func (mh *MessageEventHandler) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
@ -212,7 +212,7 @@ func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *ent
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error {
|
||||
func (mh *MessageEventHandler) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error {
|
||||
|
||||
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
|
||||
return nil
|
||||
@ -265,7 +265,7 @@ func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.C
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error {
|
||||
func (mh *MessageEventHandler) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
@ -278,7 +278,7 @@ func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
func (mh *MessageEventHandler) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
@ -291,7 +291,7 @@ func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chu
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error {
|
||||
func (mh *MessageEventHandler) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error {
|
||||
|
||||
if msg.Ext == nil {
|
||||
msg.Ext = map[string]string{}
|
||||
@ -314,7 +314,7 @@ func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error {
|
||||
func (mh *MessageEventHandler) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error {
|
||||
interruptData, cType, err := parseInterruptData(ctx, chunk.Interrupt)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -366,7 +366,7 @@ func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *ent
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) {
|
||||
func (mh *MessageEventHandler) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) {
|
||||
interruptData, cType, err := handlerWfInterruptEvent(ctx, stateMsg.InterruptEvent)
|
||||
if err != nil {
|
||||
return
|
||||
@ -412,7 +412,7 @@ func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, state
|
||||
}
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
|
||||
func (mh *MessageEventHandler) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
|
||||
msgMeta := buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence)
|
||||
|
||||
cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta)
|
||||
@ -426,3 +426,21 @@ func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *
|
||||
}
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
func (mh *MessageEventHandler) ParseAdditionalMessages(ctx context.Context, rtDependence *AgentRuntime, runRecord *entity.RunRecordMeta) error {
|
||||
|
||||
if len(rtDependence.GetRunMeta().AdditionalMessages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
additionalMessages := make([]*message.Message, 0, len(rtDependence.GetRunMeta().AdditionalMessages))
|
||||
|
||||
for _, msg := range rtDependence.GetRunMeta().AdditionalMessages {
|
||||
cm := buildAdditionalMessage2Create(ctx, runRecord, msg, rtDependence.GetRunMeta().UserID)
|
||||
additionalMessages = append(additionalMessages, cm)
|
||||
}
|
||||
|
||||
_, err := crossmessage.DefaultSVC().BatchCreate(ctx, additionalMessages)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -107,6 +107,11 @@ func (rd *AgentRuntime) GetHistory() []*msgEntity.Message {
|
||||
|
||||
func (art *AgentRuntime) Run(ctx context.Context) (err error) {
|
||||
|
||||
mh := &MessageEventHandler{
|
||||
messageEvent: art.MessageEvent,
|
||||
sw: art.SW,
|
||||
}
|
||||
|
||||
agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft, art.GetRunMeta().ConnectorID)
|
||||
if err != nil {
|
||||
return
|
||||
@ -114,6 +119,18 @@ func (art *AgentRuntime) Run(ctx context.Context) (err error) {
|
||||
|
||||
art.SetAgentInfo(agentInfo)
|
||||
|
||||
if len(art.GetRunMeta().AdditionalMessages) > 0 {
|
||||
var additionalRunRecord *entity.RunRecordMeta
|
||||
additionalRunRecord, err = art.RunRecordRepo.Create(ctx, art.GetRunMeta())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = mh.ParseAdditionalMessages(ctx, art, additionalRunRecord)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
history, err := art.getHistory(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
@ -140,10 +157,7 @@ func (art *AgentRuntime) Run(ctx context.Context) (err error) {
|
||||
}
|
||||
art.RunProcess.StepToComplete(ctx, srRecord, art.SW, art.GetUsage())
|
||||
}()
|
||||
mh := &MesssageEventHanlder{
|
||||
messageEvent: art.MessageEvent,
|
||||
sw: art.SW,
|
||||
}
|
||||
|
||||
input, err := mh.HandlerInput(ctx, art)
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
@ -80,7 +80,7 @@ func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.I
|
||||
|
||||
func (art *AgentRuntime) push(ctx context.Context, mainChan chan *entity.AgentRespEvent) {
|
||||
|
||||
mh := &MesssageEventHanlder{
|
||||
mh := &MessageEventHandler{
|
||||
sw: art.SW,
|
||||
messageEvent: art.MessageEvent,
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../../internal/mock/domain/conversation/agentrun/agent_run_mock.go --package agentrun -source agent_run.go
|
||||
type Run interface {
|
||||
AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error)
|
||||
Delete(ctx context.Context, runID []int64) error
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../../internal/mock/domain/conversation/conversation/conversation_mock.go --package conversation -source conversation.go
|
||||
type Conversation interface {
|
||||
Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
|
||||
|
||||
@ -72,6 +72,25 @@ func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity
|
||||
return dao.messagePO2DO(poData), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) BatchCreate(ctx context.Context, msg []*entity.Message) ([]*entity.Message, error) {
|
||||
poList := make([]*model.Message, 0, len(msg))
|
||||
for _, m := range msg {
|
||||
po, err := dao.messageDO2PO(ctx, m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
poList = append(poList, po)
|
||||
}
|
||||
|
||||
do := dao.query.Message.WithContext(ctx).Debug()
|
||||
cErr := do.CreateInBatches(poList, len(poList))
|
||||
if cErr != nil {
|
||||
return nil, cErr
|
||||
}
|
||||
|
||||
return dao.batchMessagePO2DO(poList), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
|
||||
@ -34,6 +34,7 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
|
||||
type MessageRepo interface {
|
||||
PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
BatchCreate(ctx context.Context, msg []*entity.Message) ([]*entity.Message, error)
|
||||
List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error)
|
||||
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
|
||||
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)
|
||||
|
||||
@ -27,6 +27,7 @@ type Message interface {
|
||||
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
|
||||
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
BatchCreate(ctx context.Context, req []*entity.Message) ([]*entity.Message, error)
|
||||
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Message, error)
|
||||
Edit(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
|
||||
@ -124,6 +124,10 @@ func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, e
|
||||
return m.MessageRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (m *messageImpl) BatchCreate(ctx context.Context, req []*entity.Message) ([]*entity.Message, error) {
|
||||
return m.MessageRepo.BatchCreate(ctx, req)
|
||||
}
|
||||
|
||||
func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error {
|
||||
|
||||
_, err := m.MessageRepo.Edit(ctx, req.ID, &message.Message{
|
||||
|
||||
@ -494,3 +494,55 @@ func TestListWithoutPair(t *testing.T) {
|
||||
assert.Equal(t, "Answer message", resp.Messages[0].Content)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBatchCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Message{})
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
|
||||
t.Run("success_single_message", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
// 准备测试数据
|
||||
inputMsgs := []*entity.Message{
|
||||
{
|
||||
ID: 1,
|
||||
ConversationID: 100,
|
||||
RunID: 200,
|
||||
AgentID: 300,
|
||||
UserID: "user123",
|
||||
Content: "Hello World",
|
||||
Role: schema.User,
|
||||
ContentType: message.ContentTypeText,
|
||||
MessageType: message.MessageTypeQuestion,
|
||||
Status: message.MessageStatusAvailable,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
ConversationID: 100,
|
||||
RunID: 200,
|
||||
AgentID: 300,
|
||||
UserID: "user123",
|
||||
Content: "Hello World",
|
||||
Role: schema.Assistant,
|
||||
ContentType: message.ContentTypeText,
|
||||
MessageType: message.MessageTypeQuestion,
|
||||
Status: message.MessageStatusAvailable,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := NewService(components).BatchCreate(ctx, inputMsgs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, inputMsgs[1].ID, result[1].ID)
|
||||
})
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/domain/shortcutcmd/shortcut_cmd_mock.go --package shortcutcmd -source shortcut_cmd.go
|
||||
type ShortcutCmd interface {
|
||||
ListCMD(ctx context.Context, lm *entity.ListMeta) ([]*entity.ShortcutCmd, error)
|
||||
CreateCMD(ctx context.Context, shortcut *entity.ShortcutCmd) (*entity.ShortcutCmd, error)
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/upload/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/domain/upload/upload_service_mock.go --package upload -source interface.go
|
||||
type UploadService interface {
|
||||
UploadFile(ctx context.Context, req *UploadFileRequest) (resp *UploadFileResponse, err error)
|
||||
UploadFiles(ctx context.Context, req *UploadFilesRequest) (resp *UploadFilesResponse, err error)
|
||||
|
||||
@ -75,11 +75,11 @@ type Conversation interface {
|
||||
ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error)
|
||||
ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error
|
||||
InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error
|
||||
GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error)
|
||||
GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error)
|
||||
UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error)
|
||||
GetTemplateByName(ctx context.Context, env vo.Env, appID int64, templateName string) (*entity.ConversationTemplate, bool, error)
|
||||
GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error)
|
||||
GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
|
||||
GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
|
||||
}
|
||||
|
||||
type InterruptEventStore interface {
|
||||
@ -143,8 +143,8 @@ type ConversationRepository interface {
|
||||
UpdateStaticConversation(ctx context.Context, env vo.Env, templateID int64, connectorID int64, userID int64, newConversationID int64) error
|
||||
UpdateDynamicConversation(ctx context.Context, env vo.Env, conversationID, newConversationID int64) error
|
||||
CopyTemplateConversationByAppID(ctx context.Context, appID int64, toAppID int64) error
|
||||
GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
|
||||
GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
|
||||
GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
|
||||
GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
|
||||
}
|
||||
type WorkflowConfig interface {
|
||||
GetNodeOfCodeConfig() *config.NodeOfCodeConfig
|
||||
|
||||
@ -53,7 +53,7 @@ type NodeTypeMeta struct {
|
||||
Category string `json:"category"`
|
||||
Color string `json:"color"`
|
||||
Desc string `json:"desc"`
|
||||
IconURI string `json:"icon_uri"`
|
||||
IconURL string `json:"icon_url"`
|
||||
SupportBatch bool `json:"support_batch"`
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
EnUSName string `json:"en_us_name,omitempty"`
|
||||
@ -265,7 +265,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "input&output",
|
||||
Desc: "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
Color: "#5C62FF",
|
||||
IconURI: "default_icon/workflow_icon/icon-start.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
@ -281,7 +281,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "input&output",
|
||||
Desc: "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
Color: "#5C62FF",
|
||||
IconURI: "default_icon/workflow_icon/icon-end.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -299,7 +299,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "",
|
||||
Desc: "调用大语言模型,使用变量和提示词生成回复",
|
||||
Color: "#5C62FF",
|
||||
IconURI: "default_icon/workflow_icon/icon-llm.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
|
||||
SupportBatch: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -319,7 +319,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "",
|
||||
Desc: "通过添加工具访问实时数据和执行外部操作",
|
||||
Color: "#CA61FF",
|
||||
IconURI: "default_icon/workflow_icon/icon-plugin.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg",
|
||||
SupportBatch: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -337,7 +337,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "编写代码,处理输入变量来生成返回值",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-code.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -355,7 +355,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "data",
|
||||
Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回",
|
||||
Color: "#FF811A",
|
||||
IconURI: "default_icon/workflow_icon/icon-knowledge-query.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -374,7 +374,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-condition.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Condition",
|
||||
@ -388,7 +388,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "",
|
||||
Desc: "集成已发布工作流,可以执行嵌套子任务",
|
||||
Color: "#00B83E",
|
||||
IconURI: "default_icon/workflow_icon/icon-workflow.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg",
|
||||
SupportBatch: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
BlockEndStream: true,
|
||||
@ -404,7 +404,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "database",
|
||||
Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作",
|
||||
Color: "#FF811A",
|
||||
IconURI: "default_icon/workflow_icon/icon-database.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -422,7 +422,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "input&output",
|
||||
Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式",
|
||||
Color: "#5C62FF",
|
||||
IconURI: "default_icon/workflow_icon/icon-output.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -441,7 +441,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "utilities",
|
||||
Desc: "用于处理多个字符串类型变量的格式",
|
||||
Color: "#3071F2",
|
||||
IconURI: "default_icon/workflow_icon/icon-text.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -458,7 +458,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "utilities",
|
||||
Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
|
||||
Color: "#3071F2",
|
||||
IconURI: "default_icon/workflow_icon/icon-question.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -477,7 +477,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "用于立即终止当前所在的循环,跳出循环体",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-break.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Break",
|
||||
@ -491,7 +491,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "用于重置循环变量的值,使其下次循环使用重置后的值",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-loop-set-variable.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Set Variable",
|
||||
@ -505,7 +505,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-loop.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
IsComposite: true,
|
||||
@ -524,7 +524,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-intent.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -543,7 +543,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "data",
|
||||
Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
|
||||
Color: "#FF811A",
|
||||
IconURI: "default_icon/workflow_icon/icon-knowledge-write.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -561,7 +561,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-batch.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
IsComposite: true,
|
||||
@ -580,7 +580,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "用于终止当前循环,执行下次循环",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-continue.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Continue",
|
||||
@ -594,7 +594,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "input&output",
|
||||
Desc: "支持中间过程的信息输入",
|
||||
Color: "#5C62FF",
|
||||
IconURI: "default_icon/workflow_icon/icon_input.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
@ -609,6 +609,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "", // Not found in cate_list
|
||||
Desc: "comment_desc", // Placeholder from JSON
|
||||
Color: "",
|
||||
IconURL: "comment_icon", // Placeholder from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
EnUSName: "Comment",
|
||||
},
|
||||
@ -619,7 +620,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "logic",
|
||||
Desc: "对多个分支的输出进行聚合处理",
|
||||
Color: "#00B2B2",
|
||||
IconURI: "default_icon/workflow_icon/icon-variable-merge.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
@ -637,7 +638,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "message",
|
||||
Desc: "用于查询消息列表",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-query-message-list.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -653,7 +654,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "conversation_history", // Mapped from cate_list
|
||||
Desc: "用于清空会话历史,清空后LLM看到的会话历史为空",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-clear-context.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -669,7 +670,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "conversation_management",
|
||||
Desc: "用于创建会话",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-create-conversation.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -686,7 +687,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "data",
|
||||
Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量",
|
||||
Color: "#FF811A",
|
||||
IconURI: "default_icon/workflow_icon/icon-variable-assign.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Variable assign",
|
||||
@ -700,7 +701,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "database",
|
||||
Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-database-update.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -717,7 +718,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "database",
|
||||
Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-database-query.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -734,7 +735,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "database",
|
||||
Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-database-delete.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -751,7 +752,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "utilities",
|
||||
Desc: "用于发送API请求,从接口返回数据",
|
||||
Color: "#3071F2",
|
||||
IconURI: "default_icon/workflow_icon/icon-http.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -768,7 +769,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "database",
|
||||
Desc: "向表添加新数据记录,用户输入数据内容后插入数据库",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-database-create.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -784,7 +785,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "conversation_management",
|
||||
Desc: "用于修改会话的名字",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-update-conversation.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-编辑会话.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -801,7 +802,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "conversation_management",
|
||||
Desc: "用于删除会话",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-delete-conversation.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-删除会话.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -817,7 +818,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "conversation_management",
|
||||
Desc: "用于查询所有会话,包含静态会话、动态会话",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-query-conversation-list.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-查询会话.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
@ -832,7 +833,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "conversation_history", // Mapped from cate_list
|
||||
Desc: "用于查询会话历史,返回LLM可见的会话消息",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-query-conversation-history.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-查询会话历史.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -848,7 +849,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "message",
|
||||
Desc: "用于创建消息",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-create-message.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-创建消息.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -864,7 +865,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "message",
|
||||
Desc: "用于修改消息",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-update-message.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-修改消息.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -880,7 +881,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "message",
|
||||
Desc: "用于删除消息",
|
||||
Color: "#F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-delete-message.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-删除消息.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -915,8 +916,8 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
// Color is the color of the upper edge of the node displayed on Canvas.
|
||||
Color: "F2B600",
|
||||
|
||||
// IconURI is the resource identifier for the icon displayed on the Canvas. It's resolved into a full URL by the backend to support different deployment environments.
|
||||
IconURI: "default_icon/workflow_icon/icon-json-stringify.jpg",
|
||||
// IconURL is the URL of the icon displayed on Canvas.
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png",
|
||||
|
||||
// SupportBatch indicates whether this node can set batch mode.
|
||||
// NOTE: ultimately it's frontend that decides which node can enable batch mode.
|
||||
@ -942,7 +943,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "utilities",
|
||||
Desc: "用于将JSON字符串解析为变量",
|
||||
Color: "F2B600",
|
||||
IconURI: "default_icon/workflow_icon/icon-json-parser.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
@ -960,7 +961,7 @@ var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
Category: "data",
|
||||
Desc: "用于删除知识库中的文档",
|
||||
Color: "#FF811A",
|
||||
IconURI: "default_icon/workflow_icon/icon-knowledge-delete.jpg",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
|
||||
@ -32,11 +32,6 @@ const (
|
||||
ChatFlowMessageCompleted ChatFlowEvent = "conversation.message.completed"
|
||||
)
|
||||
|
||||
const (
|
||||
ConversationNameKey = "CONVERSATION_NAME"
|
||||
UserInputKey = "USER_INPUT"
|
||||
)
|
||||
|
||||
type Usage struct {
|
||||
TokenCount *int32 `form:"token_count" json:"token_count,omitempty"`
|
||||
OutputTokens *int32 `form:"output_count" json:"output_count,omitempty"`
|
||||
|
||||
@ -59,14 +59,14 @@ type ListConversationPolicy struct {
|
||||
}
|
||||
|
||||
type CreateStaticConversation struct {
|
||||
BizID int64
|
||||
AppID int64
|
||||
UserID int64
|
||||
ConnectorID int64
|
||||
|
||||
TemplateID int64
|
||||
}
|
||||
type CreateDynamicConversation struct {
|
||||
BizID int64
|
||||
AppID int64
|
||||
UserID int64
|
||||
ConnectorID int64
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ package vo
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
@ -43,14 +44,8 @@ type Reference struct {
|
||||
}
|
||||
|
||||
type FieldSource struct {
|
||||
Ref *Reference `json:"ref,omitempty"`
|
||||
Val any `json:"val,omitempty"`
|
||||
FileExtra *FileExtra `json:"file_extra,omitempty"`
|
||||
}
|
||||
|
||||
type FileExtra struct {
|
||||
FileName *string `json:"file_name,omitempty"`
|
||||
FileNames []string `json:"file_names,omitempty"`
|
||||
Ref *Reference `json:"ref,omitempty"`
|
||||
Val any `json:"val,omitempty"`
|
||||
}
|
||||
|
||||
type TypeInfo struct {
|
||||
|
||||
@ -731,20 +731,6 @@ func TestKnowledgeDeleter(t *testing.T) {
|
||||
UserID: 123,
|
||||
})
|
||||
|
||||
defer mockey.Mock(execute.GetExeCtx).Return(&execute.Context{
|
||||
RootCtx: execute.RootCtx{
|
||||
ExeCfg: workflowModel.ExecuteConfig{
|
||||
InputFileFields: map[string]*workflowModel.FileInfo{
|
||||
"https://p26-bot-workflow-sign.byteimg.com/tos-cn-i-mdko3gqilj/5264fa1295da4a6483cd236b1316c454.pdf~tplv-mdko3gqilj-image.image?rk3s=81d4c505&x-expires=1782379180&x-signature=mlaXPIk9VJjOXu87xGaRmNRg9%2BA%3D": &workflowModel.FileInfo{
|
||||
FileName: "1706.03762v7.pdf",
|
||||
FileURL: "https://p26-bot-workflow-sign.byteimg.com/tos-cn-i-mdko3gqilj/5264fa1295da4a6483cd236b1316c454.pdf~tplv-mdko3gqilj-image.image?rk3s=81d4c505&x-expires=1782379180&x-signature=mlaXPIk9VJjOXu87xGaRmNRg9%2BA%3D",
|
||||
FileExtension: ".pdf",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).Build().UnPatch()
|
||||
|
||||
workflowSC, err := CanvasToWorkflowSchema(ctx, c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
||||
@ -235,7 +235,6 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
|
||||
if enabled {
|
||||
trimmedSC.GeneratedNodes = append(trimmedSC.GeneratedNodes, ns.Key)
|
||||
}
|
||||
trimmedSC.Init()
|
||||
|
||||
return trimmedSC, nil
|
||||
}
|
||||
|
||||
@ -446,7 +446,7 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node)
|
||||
|
||||
func parseBatchMode(n *vo.Node) (
|
||||
batchN *vo.Node, // the new batch node
|
||||
enabled bool, // whether the node has enabled batch mode
|
||||
enabled bool, // whether the node has enabled batch mode
|
||||
err error) {
|
||||
if n.Data == nil || n.Data.Inputs == nil {
|
||||
return nil, false, nil
|
||||
|
||||
@ -18,7 +18,6 @@ package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -26,7 +25,6 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
@ -182,10 +180,7 @@ func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, p
|
||||
if value == nil {
|
||||
return nil, fmt.Errorf("input %v has no value, type= %s", path, b.Type)
|
||||
}
|
||||
var fileExtra *vo.FileExtra
|
||||
isFileAssistType := func(assistType vo.AssistType) bool {
|
||||
return assistType >= vo.AssistTypeDefault && assistType <= vo.AssistTypeVoice
|
||||
}
|
||||
|
||||
switch value.Type {
|
||||
case vo.BlockInputValueTypeObjectRef:
|
||||
sc := b.Schema
|
||||
@ -219,6 +214,7 @@ func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, p
|
||||
if content == nil {
|
||||
return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type)
|
||||
}
|
||||
|
||||
switch b.Type {
|
||||
case vo.VariableTypeObject:
|
||||
m := make(map[string]any)
|
||||
@ -227,43 +223,11 @@ func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, p
|
||||
}
|
||||
content = m
|
||||
case vo.VariableTypeList:
|
||||
switch content.(type) {
|
||||
case string:
|
||||
if _, ok := content.(string); ok {
|
||||
l := make([]any, 0)
|
||||
if err = sonic.UnmarshalString(content.(string), &l); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = l
|
||||
}
|
||||
case []string:
|
||||
content = content.([]string)
|
||||
case []any:
|
||||
content = content.([]any)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type fot list: %s", b.Type)
|
||||
}
|
||||
eleSchema, err := vo.ParseVariable(b.Schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can not parse schema from %v", b.Schema)
|
||||
}
|
||||
|
||||
if isFileAssistType(eleSchema.AssistType) {
|
||||
rawMeta, ok := b.Value.RawMeta.(map[string]any)
|
||||
if ok {
|
||||
filenames, ok := rawMeta["fileName"].([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can not get filename from %v", rawMeta)
|
||||
}
|
||||
fileExtra = &vo.FileExtra{
|
||||
FileNames: make([]string, 0, len(filenames)),
|
||||
}
|
||||
for _, filename := range filenames {
|
||||
fileExtra.FileNames = append(fileExtra.FileNames, filename.(string))
|
||||
}
|
||||
}
|
||||
|
||||
l := make([]any, 0)
|
||||
if err = sonic.UnmarshalString(content.(string), &l); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = l
|
||||
case vo.VariableTypeInteger:
|
||||
switch content.(type) {
|
||||
case string:
|
||||
@ -304,27 +268,13 @@ func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, p
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type for boolean: %s", b.Type)
|
||||
}
|
||||
case vo.VariableTypeString:
|
||||
if isFileAssistType(b.AssistType) {
|
||||
rawMeta, ok := b.Value.RawMeta.(map[string]any)
|
||||
if ok {
|
||||
filename, ok := rawMeta["fileName"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can not get filename from %v", rawMeta)
|
||||
}
|
||||
fileExtra = &vo.FileExtra{
|
||||
FileName: ptr.Of(filename),
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
return []*vo.FieldInfo{
|
||||
{
|
||||
Path: path,
|
||||
Source: vo.FieldSource{
|
||||
Val: content,
|
||||
FileExtra: fileExtra,
|
||||
Val: content,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
@ -516,8 +466,8 @@ func SetInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -1,275 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 180,
|
||||
"y": 79.2
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "USER_INPUT",
|
||||
"required": false
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"defaultValue": "dhl"
|
||||
}
|
||||
],
|
||||
"nodeMeta": {
|
||||
"title": "开始",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"subTitle": ""
|
||||
},
|
||||
"trigger_parameters": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 2020,
|
||||
"y": 66.2
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"title": "结束",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"subTitle": ""
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "useAnswerContent",
|
||||
"streamingOutput": true,
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "142077",
|
||||
"name": "optionContent"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": "{{output}}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "190196",
|
||||
"type": "30",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 640,
|
||||
"y": 78.5
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"nodeMeta": {
|
||||
"title": "输入",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
|
||||
"description": "支持中间过程的信息输入",
|
||||
"mainColor": "#5C62FF",
|
||||
"subTitle": "输入"
|
||||
},
|
||||
"inputs": {
|
||||
"outputSchema": "[{\"type\":\"string\",\"name\":\"input\",\"required\":true}]"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "133775",
|
||||
"type": "18",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1100,
|
||||
"y": 39
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"inputs": {
|
||||
"llmParam": {
|
||||
"modelType": 1001,
|
||||
"modelName": "Doubao-Seed-1.6",
|
||||
"generationDiversity": "balance",
|
||||
"temperature": 0.8,
|
||||
"maxTokens": 4096,
|
||||
"topP": 0.7,
|
||||
"responseFormat": 2,
|
||||
"systemPrompt": ""
|
||||
},
|
||||
"inputParameters": [],
|
||||
"extra_output": false,
|
||||
"answer_type": "text",
|
||||
"option_type": "static",
|
||||
"dynamic_option": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"question": "你好",
|
||||
"options": [
|
||||
{
|
||||
"name": ""
|
||||
},
|
||||
{
|
||||
"name": ""
|
||||
}
|
||||
],
|
||||
"limit": 3
|
||||
},
|
||||
"nodeMeta": {
|
||||
"title": "问答",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
|
||||
"description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
|
||||
"mainColor": "#3071F2",
|
||||
"subTitle": "问答"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "USER_RESPONSE",
|
||||
"required": true,
|
||||
"description": "用户本轮对话输入内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "142077",
|
||||
"type": "18",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1560,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"inputs": {
|
||||
"llmParam": {
|
||||
"modelType": 1001,
|
||||
"modelName": "Doubao-Seed-1.6",
|
||||
"generationDiversity": "balance",
|
||||
"temperature": 0.8,
|
||||
"maxTokens": 4096,
|
||||
"topP": 0.7,
|
||||
"responseFormat": 2,
|
||||
"systemPrompt": ""
|
||||
},
|
||||
"inputParameters": [],
|
||||
"extra_output": false,
|
||||
"answer_type": "option",
|
||||
"option_type": "static",
|
||||
"dynamic_option": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"question": "请选择",
|
||||
"options": [
|
||||
{
|
||||
"name": "A"
|
||||
},
|
||||
{
|
||||
"name": "B"
|
||||
}
|
||||
],
|
||||
"limit": 3
|
||||
},
|
||||
"nodeMeta": {
|
||||
"title": "问答_1",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
|
||||
"description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
|
||||
"mainColor": "#3071F2",
|
||||
"subTitle": "问答"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "optionId",
|
||||
"required": false
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "optionContent",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "190196"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "142077",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": "branch_0"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "142077",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": "branch_1"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "142077",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": "default"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "190196",
|
||||
"targetNodeID": "133775"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "133775",
|
||||
"targetNodeID": "142077"
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -1,397 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{output}}",
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "123887",
|
||||
"name": "output",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "output"
|
||||
}
|
||||
],
|
||||
"streamingOutput": true,
|
||||
"terminatePlan": "useAnswerContent"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 926,
|
||||
"y": -13
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"fcParamVar": {
|
||||
"knowledgeFCParam": {}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "USER_INPUT",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "input"
|
||||
}
|
||||
],
|
||||
"llmParam": [
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "1737521813",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modelType"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "豆包·1.5·Pro·32k",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modleName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "balance",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "generationDiversity"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "float",
|
||||
"value": {
|
||||
"content": "0.8",
|
||||
"rawMeta": {
|
||||
"type": 4
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "temperature"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "4096",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "maxTokens"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spCurrentTime"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spAntiLeak"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prefixCache"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "2",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "responseFormat"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{input}}",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "enableChatHistory"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "3",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "chatHistoryRound"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "systemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "stableSystemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "canContinue"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptVersion"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptId"
|
||||
}
|
||||
],
|
||||
"settingOnError": {
|
||||
"processType": 1,
|
||||
"retryTimes": 0,
|
||||
"timeoutMs": 180000
|
||||
}
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "调用大语言模型,使用变量和提示词生成回复",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
|
||||
"mainColor": "#5C62FF",
|
||||
"subTitle": "大模型",
|
||||
"title": "大模型"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "output",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"version": "3"
|
||||
},
|
||||
"edges": null,
|
||||
"id": "123887",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 463,
|
||||
"y": -39
|
||||
}
|
||||
},
|
||||
"type": "3"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "123887",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "123887",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -1,397 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{output}}",
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "123887",
|
||||
"name": "output",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "output"
|
||||
}
|
||||
],
|
||||
"streamingOutput": true,
|
||||
"terminatePlan": "useAnswerContent"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 926,
|
||||
"y": -13
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"fcParamVar": {
|
||||
"knowledgeFCParam": {}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "USER_INPUT",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "input"
|
||||
}
|
||||
],
|
||||
"llmParam": [
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "1737521813",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modelType"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "豆包·1.5·Pro·32k",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "modleName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "balance",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "generationDiversity"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "float",
|
||||
"value": {
|
||||
"content": "0.8",
|
||||
"rawMeta": {
|
||||
"type": 4
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "temperature"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "4096",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "maxTokens"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spCurrentTime"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "spAntiLeak"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prefixCache"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "2",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "responseFormat"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "{{input}}",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "prompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": true,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "enableChatHistory"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"content": "3",
|
||||
"rawMeta": {
|
||||
"type": 2
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "chatHistoryRound"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "systemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "stableSystemPrompt"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": false,
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "canContinue"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptVersion"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "loopPromptId"
|
||||
}
|
||||
],
|
||||
"settingOnError": {
|
||||
"processType": 1,
|
||||
"retryTimes": 0,
|
||||
"timeoutMs": 180000
|
||||
}
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "调用大语言模型,使用变量和提示词生成回复",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
|
||||
"mainColor": "#5C62FF",
|
||||
"subTitle": "大模型",
|
||||
"title": "大模型"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "output",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"version": "3"
|
||||
},
|
||||
"edges": null,
|
||||
"id": "123887",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 463,
|
||||
"y": -39
|
||||
}
|
||||
},
|
||||
"type": "3"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "123887",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "123887",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -1,344 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 180,
|
||||
"y": 26.700000000000003
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"assistType": 6,
|
||||
"name": "f",
|
||||
"required": false
|
||||
},
|
||||
{
|
||||
"type": "list",
|
||||
"name": "fs",
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"assistType": 2
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 2020,
|
||||
"y": 13.700000000000003
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "returnVariables",
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "list",
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"assistType": 2
|
||||
},
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "100001",
|
||||
"name": "fs"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 104
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "filename",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "129362",
|
||||
"name": "fileName"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "191034",
|
||||
"type": "5",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1532.6813186813188,
|
||||
"y": -19.285714285714285
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"title": "代码",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
|
||||
"description": "编写代码,处理输入变量来生成返回值",
|
||||
"mainColor": "#00B2B2",
|
||||
"subTitle": "代码"
|
||||
},
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "input",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "129362",
|
||||
"name": "fileName"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"code": "# 在这里,您可以通过 'args' 获取节点中的输入变量,并通过 'ret' 输出结果\n# 'args' 已经被正确地注入到环境中\n# 下面是一个示例,首先获取节点的全部输入参数params,其次获取其中参数名为'input'的值:\n# params = args.params; \n# input = params['input'];\n# 下面是一个示例,输出一个包含多种数据类型的 'ret' 对象:\n# ret: Output = { \"name\": '小明', \"hobbies\": [\"看书\", \"旅游\"] };\n\nasync def main(args: Args) -> Output:\n params = args.params\n # 构建输出对象\n ret: Output = {\n \"key0\": params['input'] + params['input'], # 拼接两次入参 input 的值\n \"key1\": [\"hello\", \"world\"], # 输出一个数组\n \"key2\": { # 输出一个Object \n \"key21\": \"hi\"\n },\n }\n return ret",
|
||||
"language": 3,
|
||||
"settingOnError": {
|
||||
"processType": 1,
|
||||
"timeoutMs": 60000,
|
||||
"retryTimes": 0
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "key0"
|
||||
},
|
||||
{
|
||||
"type": "list",
|
||||
"name": "key1",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"name": "key2",
|
||||
"schema": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "key21"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "129362",
|
||||
"type": "27",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 640,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"title": "知识库写入",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
|
||||
"description": "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
|
||||
"mainColor": "#FF811A",
|
||||
"subTitle": "知识库写入"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "documentId"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "fileName"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "fileUrl"
|
||||
}
|
||||
],
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "knowledge",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"assistType": 1,
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": "http://coze.fanlv.fun:8889/opencoze/tos-cn-i-v4nquku3lp/f34d246d-7179-4f0d-856d-5d4c135ffee5.txt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250910%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250910T074330Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=2208774d8ffc6f116c3f9851a184d851fbcb745ff78cd001359a873f19f6dad4&x-wf-file_name=%E5%8C%97%E4%BA%AC%E6%97%85%E6%B8%B8%E6%99%AF%E7%82%B9.txt",
|
||||
"rawMeta": {
|
||||
"fileName": "北京旅游景点.txt",
|
||||
"type": 8
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"datasetParam": [
|
||||
{
|
||||
"name": "datasetList",
|
||||
"input": {
|
||||
"type": "list",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": ["7548363002819379200"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"strategyParam": {
|
||||
"parsingStrategy": {
|
||||
"parsingType": "fast"
|
||||
},
|
||||
"chunkStrategy": {
|
||||
"chunkType": "default"
|
||||
},
|
||||
"indexStrategy": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "164528",
|
||||
"type": "27",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1100,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"title": "知识库写入_1",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
|
||||
"description": "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
|
||||
"mainColor": "#FF811A",
|
||||
"subTitle": "知识库写入"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "documentId"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "fileName"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"name": "fileUrl"
|
||||
}
|
||||
],
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "knowledge",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"assistType": 1,
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": "http://coze.fanlv.fun:8889/opencoze/tos-cn-i-v4nquku3lp/b84062e4-d79e-4b3d-bf42-eaab754250df.txt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250910%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250910T075358Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=346934ac110db0153afc1759c372e856afc9c5298c8c72e1210bba0c979f84cb&x-wf-file_name=%E5%8C%97%E4%BA%AC%E6%97%85%E6%B8%B8%E6%99%AF%E7%82%B9.txt",
|
||||
"rawMeta": {
|
||||
"fileName": "北京旅游景点.txt",
|
||||
"type": 8
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"datasetParam": [
|
||||
{
|
||||
"name": "datasetList",
|
||||
"input": {
|
||||
"type": "list",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": ["7548363002819379200"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"strategyParam": {
|
||||
"parsingStrategy": {
|
||||
"parsingType": "fast"
|
||||
},
|
||||
"chunkStrategy": {
|
||||
"chunkType": "default"
|
||||
},
|
||||
"indexStrategy": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "129362"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "191034",
|
||||
"targetNodeID": "900001"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "164528",
|
||||
"targetNodeID": "191034"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "129362",
|
||||
"targetNodeID": "164528"
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -85,7 +85,6 @@ func init() {
|
||||
_ = compose.RegisterSerializableType[*vo.TypeInfo]("type_info")
|
||||
_ = compose.RegisterSerializableType[vo.DataType]("data_type")
|
||||
_ = compose.RegisterSerializableType[vo.FileSubType]("file_sub_type")
|
||||
_ = compose.RegisterSerializableType[*workflowModel.FileInfo]("file_info")
|
||||
}
|
||||
|
||||
func (s *State) GetNodeCtx(key vo.NodeKey) (*execute.Context, bool, error) {
|
||||
|
||||
@ -148,7 +148,7 @@ func (ch *ConversationHistory) Invoke(ctx context.Context, input map[string]any)
|
||||
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
BizID: *appID,
|
||||
AppID: *appID,
|
||||
Rounds: rounds,
|
||||
InitRunID: initRunID,
|
||||
SectionID: sectionID,
|
||||
|
||||
@ -109,7 +109,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
|
||||
|
||||
if existed {
|
||||
cID, _, existed, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
BizID: ptr.From(appID),
|
||||
AppID: ptr.From(appID),
|
||||
TemplateID: template.TemplateID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
@ -125,7 +125,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
|
||||
}
|
||||
|
||||
cID, _, existed, err := workflow.GetRepository().GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
BizID: ptr.From(appID),
|
||||
AppID: ptr.From(appID),
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Name: conversationName,
|
||||
|
||||
@ -98,7 +98,7 @@ func (c *CreateMessage) getConversationIDByName(ctx context.Context, env vo.Env,
|
||||
var conversationID int64
|
||||
if isExist {
|
||||
cID, _, _, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
BizID: ptr.From(appID),
|
||||
AppID: ptr.From(appID),
|
||||
TemplateID: template.TemplateID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
@ -150,7 +150,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
|
||||
var conversationID int64
|
||||
var err error
|
||||
var bizID int64
|
||||
var resolvedAppID int64
|
||||
if appID == nil {
|
||||
if conversationName != "Default" {
|
||||
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
|
||||
@ -167,13 +167,13 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
}, nil
|
||||
}
|
||||
conversationID = *execCtx.ExeCfg.ConversationID
|
||||
bizID = *agentID
|
||||
resolvedAppID = *agentID
|
||||
} else {
|
||||
conversationID, err = c.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bizID = *appID
|
||||
resolvedAppID = *appID
|
||||
}
|
||||
|
||||
if conversationID == 0 {
|
||||
@ -209,7 +209,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
if role == "user" {
|
||||
// For user messages, always create a new run and store the ID in the context.
|
||||
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: bizID,
|
||||
AgentID: resolvedAppID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -244,7 +244,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
BizID: bizID,
|
||||
AppID: resolvedAppID,
|
||||
Rounds: 1,
|
||||
})
|
||||
if err != nil {
|
||||
@ -254,7 +254,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
runID = runIDs[0]
|
||||
} else {
|
||||
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
|
||||
AgentID: bizID,
|
||||
AgentID: resolvedAppID,
|
||||
ConversationID: conversationID,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
ConnectorID: connectorID,
|
||||
@ -273,7 +273,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
Content: content,
|
||||
ContentType: model.ContentType("text"),
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
AgentID: bizID,
|
||||
AgentID: resolvedAppID,
|
||||
RunID: runID,
|
||||
SectionID: sectionID,
|
||||
}
|
||||
|
||||
@ -115,7 +115,7 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
|
||||
|
||||
var conversationID int64
|
||||
var err error
|
||||
var bizID int64
|
||||
var resolvedAppID int64
|
||||
if appID == nil {
|
||||
if conversationName != "Default" {
|
||||
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
|
||||
@ -129,18 +129,18 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
|
||||
}, nil
|
||||
}
|
||||
conversationID = *execCtx.ExeCfg.ConversationID
|
||||
bizID = *agentID
|
||||
resolvedAppID = *agentID
|
||||
} else {
|
||||
conversationID, err = m.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bizID = *appID
|
||||
resolvedAppID = *appID
|
||||
}
|
||||
|
||||
req := &crossmessage.MessageListRequest{
|
||||
UserID: userID,
|
||||
BizID: bizID,
|
||||
AppID: resolvedAppID,
|
||||
ConversationID: conversationID,
|
||||
}
|
||||
|
||||
|
||||
@ -20,9 +20,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -130,23 +127,9 @@ func ConvertInputs(ctx context.Context, in map[string]any, tInfo map[string]*vo.
|
||||
}
|
||||
|
||||
type convertOptions struct {
|
||||
skipUnknownFields bool
|
||||
failFast bool
|
||||
skipRequireCheck bool
|
||||
collectFileFields map[string]*workflowModel.FileInfo
|
||||
notNeedTrimQueryFileName bool
|
||||
}
|
||||
|
||||
func WithCollectFileFields(fs map[string]*workflowModel.FileInfo) ConvertOption {
|
||||
return func(o *convertOptions) {
|
||||
o.collectFileFields = fs
|
||||
}
|
||||
}
|
||||
|
||||
func WithNotNeedTrimQueryFileName(b bool) ConvertOption {
|
||||
return func(o *convertOptions) {
|
||||
o.notNeedTrimQueryFileName = b
|
||||
}
|
||||
skipUnknownFields bool
|
||||
failFast bool
|
||||
skipRequireCheck bool
|
||||
}
|
||||
|
||||
type ConvertOption func(*convertOptions)
|
||||
@ -178,23 +161,6 @@ func Convert(ctx context.Context, in any, path string, t *vo.TypeInfo, opts ...C
|
||||
return convert(ctx, in, path, t, options)
|
||||
}
|
||||
|
||||
func adaptorFileURL(in string) (string, *workflowModel.FileInfo, error) {
|
||||
u, err := url.Parse(in)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
query := u.Query()
|
||||
fileName := query.Get("x-wf-file_name")
|
||||
fileInfo := &workflowModel.FileInfo{
|
||||
FileName: fileName,
|
||||
FileExtension: filepath.Ext(fileName),
|
||||
}
|
||||
query.Del("x-wf-file_name")
|
||||
u.RawQuery = query.Encode()
|
||||
fileInfo.FileURL = u.String()
|
||||
return u.String(), fileInfo, nil
|
||||
}
|
||||
|
||||
func convert(ctx context.Context, in any, path string, t *vo.TypeInfo, options *convertOptions) (
|
||||
any, *ConversionWarnings, error) {
|
||||
if in == nil { // nil is valid for ALL types
|
||||
@ -202,28 +168,8 @@ func convert(ctx context.Context, in any, path string, t *vo.TypeInfo, options *
|
||||
}
|
||||
|
||||
switch t.Type {
|
||||
case vo.DataTypeString, vo.DataTypeTime:
|
||||
case vo.DataTypeString, vo.DataTypeFile, vo.DataTypeTime:
|
||||
return convertToString(ctx, in, path, options)
|
||||
case vo.DataTypeFile:
|
||||
ret, warns, err := convertToString(ctx, in, path, options)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if warns != nil {
|
||||
return ret, warns, nil
|
||||
}
|
||||
|
||||
fileURL, fileInfo, err := adaptorFileURL(ret.(string))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if options.collectFileFields != nil {
|
||||
options.collectFileFields[fileInfo.FileURL] = fileInfo
|
||||
}
|
||||
if options.notNeedTrimQueryFileName {
|
||||
return ret, nil, nil
|
||||
}
|
||||
return fileURL, nil, nil
|
||||
case vo.DataTypeInteger:
|
||||
return convertToInt64(ctx, in, path, options)
|
||||
case vo.DataTypeNumber:
|
||||
|
||||
@ -82,9 +82,7 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*
|
||||
return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
|
||||
}
|
||||
|
||||
if content.Value.Content != nil {
|
||||
c.Template = content.Value.Content.(string)
|
||||
}
|
||||
c.Template = content.Value.Content.(string)
|
||||
}
|
||||
|
||||
if n.Data.Inputs.TerminatePlan == nil {
|
||||
|
||||
@ -20,9 +20,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
@ -32,7 +31,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
@ -127,7 +125,7 @@ func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]
|
||||
return nil, errors.New("knowledge is required")
|
||||
}
|
||||
|
||||
fileName, ext, err := parseToFileNameAndFileExtension(ctx, fileURL)
|
||||
fileName, ext, err := parseToFileNameAndFileExtension(fileURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -155,25 +153,23 @@ func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func parseToFileNameAndFileExtension(ctx context.Context, fileURL string) (string, parser.FileExtension, error) {
|
||||
inputFileFields := execute.GetExeCtx(ctx).ExeCfg.InputFileFields
|
||||
fileInfo, ok := inputFileFields[fileURL]
|
||||
if !ok {
|
||||
u, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
fileExt := filepath.Ext(strings.ToLower(strings.TrimPrefix(u.Path, ".")))
|
||||
ext, support := parser.ValidateFileExtension(fileExt)
|
||||
if !support {
|
||||
return "", "", fmt.Errorf("unsupported file type: %s", fileExt)
|
||||
}
|
||||
return u.Path, ext, nil
|
||||
}
|
||||
ext, support := parser.ValidateFileExtension(strings.ToLower(strings.TrimPrefix(fileInfo.FileExtension, ".")))
|
||||
if !support {
|
||||
return "", "", fmt.Errorf("unsupported file type: %s", fileInfo.FileExtension)
|
||||
}
|
||||
return fileInfo.FileName, ext, nil
|
||||
func parseToFileNameAndFileExtension(fileURL string) (string, parser.FileExtension, error) {
|
||||
|
||||
u, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
fileName := u.Query().Get("x-wf-file_name")
|
||||
if len(fileName) == 0 {
|
||||
return "", "", errors.New("file name is required")
|
||||
}
|
||||
|
||||
fileExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(fileName), "."))
|
||||
|
||||
ext, support := parser.ValidateFileExtension(fileExt)
|
||||
if !support {
|
||||
return "", "", fmt.Errorf("unsupported file type: %s", fileExt)
|
||||
}
|
||||
return fileName, ext, nil
|
||||
}
|
||||
|
||||
@ -100,9 +100,9 @@ const (
|
||||
ReasoningOutputKey = "reasoning_content"
|
||||
)
|
||||
|
||||
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
|
||||
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
|
||||
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
|
||||
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
|
||||
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
|
||||
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
|
||||
例如:
|
||||
如果内容为<img src="https://example.com/image.jpg">一只小猫,你的输出应为:。
|
||||
如果内容为<img src="https://example.com/image1.jpg">一只小猫 和 <img src="https://example.com/image2.jpg">一只小狗 和 <img src="https://example.com/image3.jpg">一只小牛,你的输出应为: 和  和 
|
||||
@ -290,7 +290,7 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*
|
||||
c.AssociateStartNodeUserInputFields = make(map[string]struct{})
|
||||
for _, info := range ns.InputSources {
|
||||
if len(info.Path) == 1 && info.Source.Ref != nil && info.Source.Ref.FromNodeKey == entity.EntryNodeKey {
|
||||
if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField(vo.UserInputKey)) {
|
||||
if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField("USER_INPUT")) {
|
||||
c.AssociateStartNodeUserInputFields[info.Path[0]] = struct{}{}
|
||||
}
|
||||
}
|
||||
@ -1115,7 +1115,7 @@ func (l *LLM) handleInterrupt(ctx context.Context, err error, resumingEvent *ent
|
||||
NodeKey: c.NodeKey,
|
||||
NodeType: entity.NodeTypeLLM,
|
||||
NodeTitle: c.NodeName,
|
||||
NodeIcon: entity.NodeMetaByNodeType(entity.NodeTypeLLM).IconURI,
|
||||
NodeIcon: entity.NodeMetaByNodeType(entity.NodeTypeLLM).IconURL,
|
||||
EventType: entity.InterruptEventLLM,
|
||||
}
|
||||
|
||||
|
||||
@ -192,3 +192,8 @@ type StreamGenerator interface {
|
||||
FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
|
||||
sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
|
||||
}
|
||||
|
||||
type ChatHistoryAware interface {
|
||||
ChatHistoryEnabled() bool
|
||||
ChatHistoryRounds() int64
|
||||
}
|
||||
|
||||
@ -343,14 +343,14 @@ const (
|
||||
你是一个参数提取 agent,你的工作是从用户的回答中提取出多个字段的值,每个字段遵循以下规则
|
||||
# 字段说明
|
||||
%s
|
||||
## 输出要求
|
||||
- 严格以 json 格式返回答案。
|
||||
- 严格确保答案采用有效的 JSON 格式。
|
||||
- 按照字段说明提取出字段的值,将已经提取到的字段放在 fields 字段
|
||||
- 对于未提取到的<必填字段>生成一个新的追问问题question
|
||||
- 确保在追问问题中只包含所有未提取的<必填字段>
|
||||
- 不要重复问之前问过的问题
|
||||
- 问题的语种请和用户的输入保持一致,如英文、中文等
|
||||
## 输出要求
|
||||
- 严格以 json 格式返回答案。
|
||||
- 严格确保答案采用有效的 JSON 格式。
|
||||
- 按照字段说明提取出字段的值,将已经提取到的字段放在 fields 字段
|
||||
- 对于未提取到的<必填字段>生成一个新的追问问题question
|
||||
- 确保在追问问题中只包含所有未提取的<必填字段>
|
||||
- 不要重复问之前问过的问题
|
||||
- 问题的语种请和用户的输入保持一致,如英文、中文等
|
||||
- 输出按照下面结构体格式返回,包含提取到的字段或者追问的问题
|
||||
- 不要回复和提取无关的问题
|
||||
type Output struct {
|
||||
@ -359,7 +359,7 @@ question string // Follow-up question for the next round
|
||||
}`
|
||||
extractUserPromptSuffix = `
|
||||
- 严格以 json 格式返回答案。
|
||||
- 严格确保答案采用有效的 JSON 格式。
|
||||
- 严格确保答案采用有效的 JSON 格式。
|
||||
- - 必填字段没有获取全则继续追问
|
||||
- 必填字段: %s
|
||||
%s
|
||||
@ -728,7 +728,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
|
||||
NodeKey: q.nodeKey,
|
||||
NodeType: entity.NodeTypeQuestionAnswer,
|
||||
NodeTitle: q.nodeMeta.Name,
|
||||
NodeIcon: q.nodeMeta.IconURI,
|
||||
NodeIcon: q.nodeMeta.IconURL,
|
||||
InterruptData: interruptData,
|
||||
EventType: entity.InterruptEventQuestion,
|
||||
}
|
||||
|
||||
@ -140,7 +140,7 @@ func (i *InputReceiver) Invoke(ctx context.Context, _ map[string]any) (map[strin
|
||||
NodeKey: i.nodeKey,
|
||||
NodeType: entity.NodeTypeInputReceiver,
|
||||
NodeTitle: i.nodeMeta.Name,
|
||||
NodeIcon: i.nodeMeta.IconURI,
|
||||
NodeIcon: i.nodeMeta.IconURL,
|
||||
InterruptData: i.interruptData,
|
||||
EventType: entity.InterruptEventInput,
|
||||
})
|
||||
|
||||
@ -432,7 +432,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
|
||||
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
|
||||
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
|
||||
appDynamicConversationDraft.AppID.Eq(meta.BizID),
|
||||
appDynamicConversationDraft.AppID.Eq(meta.AppID),
|
||||
appDynamicConversationDraft.ConnectorID.Eq(meta.ConnectorID),
|
||||
appDynamicConversationDraft.UserID.Eq(meta.UserID),
|
||||
appDynamicConversationDraft.Name.Eq(meta.Name),
|
||||
@ -452,7 +452,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -464,7 +464,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
|
||||
err = r.query.AppDynamicConversationDraft.WithContext(ctx).Create(&model.AppDynamicConversationDraft{
|
||||
ID: id,
|
||||
AppID: meta.BizID,
|
||||
AppID: meta.AppID,
|
||||
Name: meta.Name,
|
||||
UserID: meta.UserID,
|
||||
ConnectorID: meta.ConnectorID,
|
||||
@ -479,7 +479,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
} else if env == vo.Online {
|
||||
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
|
||||
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
|
||||
appDynamicConversationOnline.AppID.Eq(meta.BizID),
|
||||
appDynamicConversationOnline.AppID.Eq(meta.AppID),
|
||||
appDynamicConversationOnline.ConnectorID.Eq(meta.ConnectorID),
|
||||
appDynamicConversationOnline.UserID.Eq(meta.UserID),
|
||||
appDynamicConversationOnline.Name.Eq(meta.Name),
|
||||
@ -498,7 +498,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -509,7 +509,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
|
||||
err = r.query.AppDynamicConversationOnline.WithContext(ctx).Create(&model.AppDynamicConversationOnline{
|
||||
ID: id,
|
||||
AppID: meta.BizID,
|
||||
AppID: meta.AppID,
|
||||
Name: meta.Name,
|
||||
UserID: meta.UserID,
|
||||
ConnectorID: meta.ConnectorID,
|
||||
@ -586,7 +586,7 @@ func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context,
|
||||
return cs[0].ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -627,7 +627,7 @@ func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context
|
||||
return cs[0].ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
|
||||
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
|
||||
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, 0, false, err
|
||||
}
|
||||
@ -841,7 +841,7 @@ func (r *RepositoryImpl) CopyTemplateConversationByAppID(ctx context.Context, ap
|
||||
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
|
||||
func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
|
||||
if env == vo.Draft {
|
||||
appStaticConversationDraft := r.query.AppStaticConversationDraft
|
||||
ret, err := appStaticConversationDraft.WithContext(ctx).Where(
|
||||
@ -857,7 +857,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
|
||||
appConversationTemplateDraft := r.query.AppConversationTemplateDraft
|
||||
template, err := appConversationTemplateDraft.WithContext(ctx).Where(
|
||||
appConversationTemplateDraft.TemplateID.Eq(ret.TemplateID),
|
||||
appConversationTemplateDraft.AppID.Eq(bizID),
|
||||
appConversationTemplateDraft.AppID.Eq(appID),
|
||||
).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@ -881,7 +881,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
|
||||
appConversationTemplateOnline := r.query.AppConversationTemplateOnline
|
||||
template, err := appConversationTemplateOnline.WithContext(ctx).Where(
|
||||
appConversationTemplateOnline.TemplateID.Eq(ret.TemplateID),
|
||||
appConversationTemplateOnline.AppID.Eq(bizID),
|
||||
appConversationTemplateOnline.AppID.Eq(appID),
|
||||
).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@ -894,11 +894,11 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
|
||||
return "", false, fmt.Errorf("unknown env %v", env)
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
|
||||
func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
|
||||
if env == vo.Draft {
|
||||
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
|
||||
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
|
||||
appDynamicConversationDraft.AppID.Eq(bizID),
|
||||
appDynamicConversationDraft.AppID.Eq(appID),
|
||||
appDynamicConversationDraft.ConnectorID.Eq(connectorID),
|
||||
appDynamicConversationDraft.ConversationID.Eq(conversationID),
|
||||
).First()
|
||||
@ -918,7 +918,7 @@ func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.
|
||||
} else if env == vo.Online {
|
||||
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
|
||||
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
|
||||
appDynamicConversationOnline.AppID.Eq(bizID),
|
||||
appDynamicConversationOnline.AppID.Eq(appID),
|
||||
appDynamicConversationOnline.ConnectorID.Eq(connectorID),
|
||||
appDynamicConversationOnline.ConversationID.Eq(conversationID),
|
||||
).First()
|
||||
|
||||
@ -129,8 +129,3 @@ func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
|
||||
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
|
||||
s.OutputSources = append(s.OutputSources, info...)
|
||||
}
|
||||
|
||||
type ChatHistoryAware interface {
|
||||
ChatHistoryEnabled() bool
|
||||
ChatHistoryRounds() int64
|
||||
}
|
||||
|
||||
@ -17,19 +17,13 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type WorkflowSchema struct {
|
||||
@ -44,7 +38,6 @@ type WorkflowSchema struct {
|
||||
compositeNodes []*CompositeNode // won't serialize this
|
||||
requireCheckPoint bool // won't serialize this
|
||||
requireStreaming bool
|
||||
historyRounds int64
|
||||
|
||||
once sync.Once
|
||||
}
|
||||
@ -76,22 +69,15 @@ func (w *WorkflowSchema) Init() {
|
||||
|
||||
w.doGetCompositeNodes()
|
||||
|
||||
historyRounds := int64(0)
|
||||
for _, node := range w.Nodes {
|
||||
if node.Type == entity.NodeTypeSubWorkflow {
|
||||
node.SubWorkflowSchema.Init()
|
||||
historyRounds = max(historyRounds, node.SubWorkflowSchema.HistoryRounds())
|
||||
if node.SubWorkflowSchema.requireCheckPoint {
|
||||
w.requireCheckPoint = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
chatHistoryAware, ok := node.Configs.(ChatHistoryAware)
|
||||
if ok && chatHistoryAware.ChatHistoryEnabled() {
|
||||
historyRounds = max(historyRounds, chatHistoryAware.ChatHistoryRounds())
|
||||
}
|
||||
|
||||
if rc, ok := node.Configs.(RequireCheckpoint); ok {
|
||||
if rc.RequireCheckpoint() {
|
||||
w.requireCheckPoint = true
|
||||
@ -100,7 +86,6 @@ func (w *WorkflowSchema) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
w.historyRounds = historyRounds
|
||||
w.requireStreaming = w.doRequireStreaming()
|
||||
})
|
||||
}
|
||||
@ -137,12 +122,6 @@ func (w *WorkflowSchema) RequireStreaming() bool {
|
||||
return w.requireStreaming
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) HistoryRounds() int64 { return w.historyRounds }
|
||||
|
||||
func (w *WorkflowSchema) SetHistoryRounds(historyRounds int64) {
|
||||
w.historyRounds = historyRounds
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
|
||||
if w.Hierarchy == nil {
|
||||
return nil
|
||||
@ -328,65 +307,3 @@ func (w *WorkflowSchema) doRequireStreaming() bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetAllNodesInputFileFields(ctx context.Context) []*workflowModel.FileInfo {
|
||||
|
||||
adaptorURL := func(s string) (string, error) {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := u.Query()
|
||||
query.Del("x-wf-file_name")
|
||||
u.RawQuery = query.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
result := make([]*workflowModel.FileInfo, 0)
|
||||
for _, node := range w.Nodes {
|
||||
for _, source := range node.InputSources {
|
||||
if source.Source.Val != nil && source.Source.FileExtra != nil {
|
||||
fileExtra := source.Source.FileExtra
|
||||
if fileExtra.FileName != nil {
|
||||
fileURL, err := adaptorURL(source.Source.Val.(string))
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "failed to parse adaptorURL for node %v: %v", node.Key, err)
|
||||
continue
|
||||
}
|
||||
result = append(result, &workflowModel.FileInfo{
|
||||
FileName: *fileExtra.FileName,
|
||||
FileURL: fileURL,
|
||||
FileExtension: filepath.Ext(strings.TrimSpace(*fileExtra.FileName)),
|
||||
})
|
||||
source.Source.Val = fileURL
|
||||
|
||||
}
|
||||
if fileExtra.FileNames != nil {
|
||||
vals := source.Source.Val.([]any)
|
||||
for idx, fileName := range fileExtra.FileNames {
|
||||
fileURL := vals[idx].(string)
|
||||
fileURL, err := adaptorURL(fileURL)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "failed to parse adaptorURL for node %v: %v", node.Key, err)
|
||||
continue
|
||||
}
|
||||
result = append(result, &workflowModel.FileInfo{
|
||||
FileName: fileName,
|
||||
FileURL: fileURL,
|
||||
FileExtension: filepath.Ext(strings.TrimSpace(fileName)),
|
||||
})
|
||||
vals[idx] = fileURL
|
||||
}
|
||||
source.Source.Val = vals
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
if node.SubWorkflowSchema != nil {
|
||||
result = append(result, node.SubWorkflowSchema.GetAllNodesInputFileFields(ctx)...)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@ -248,7 +248,7 @@ func (c *conversationImpl) findReplaceWorkflowByConversationName(ctx context.Con
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if v.Name == vo.ConversationNameKey && v.DefaultValue == name {
|
||||
if v.Name == "CONVERSATION_NAME" && v.DefaultValue == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
@ -296,7 +296,7 @@ func (c *conversationImpl) replaceWorkflowsConversationName(ctx context.Context,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if v.Name == vo.ConversationNameKey {
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
v.DefaultValue = conversionName
|
||||
}
|
||||
startNode.Data.Outputs[idx] = v
|
||||
@ -351,18 +351,18 @@ func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo
|
||||
return c.repo.DeleteDynamicConversation(ctx, env, templateID)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
|
||||
AppID: ptr.Of(bizID),
|
||||
AppID: ptr.Of(appID),
|
||||
Name: ptr.Of(conversationName),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, bizID int64, userID, connectorID int64) (*conventity.Conversation, error) {
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (*conventity.Conversation, error) {
|
||||
return crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
|
||||
AgentID: bizID,
|
||||
AgentID: appID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneWorkflow,
|
||||
@ -371,7 +371,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
|
||||
|
||||
if existed {
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
BizID: bizID,
|
||||
AppID: appID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
TemplateID: t.TemplateID,
|
||||
@ -383,7 +383,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
|
||||
}
|
||||
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
BizID: bizID,
|
||||
AppID: appID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
Name: conversationName,
|
||||
@ -465,8 +465,8 @@ func (c *conversationImpl) GetDynamicConversationByName(ctx context.Context, env
|
||||
return c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, name)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
|
||||
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, bizID, connectorID, conversationID)
|
||||
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
|
||||
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, appID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
@ -474,7 +474,7 @@ func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.E
|
||||
return sc, true, nil
|
||||
}
|
||||
|
||||
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, bizID, connectorID, conversationID)
|
||||
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
@ -22,8 +22,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
@ -86,9 +84,6 @@ func (i *impl) SyncExecute(ctx context.Context, config workflowModel.ExecuteConf
|
||||
return nil, "", fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
config.InputFileFields = slices.ToMap(workflowSC.GetAllNodesInputFileFields(ctx), func(e *workflowModel.FileInfo) (string, *workflowModel.FileInfo) {
|
||||
return e.FileURL, e
|
||||
})
|
||||
var wfOpts []compose.WorkflowOption
|
||||
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
@ -105,8 +100,6 @@ func (i *impl) SyncExecute(ctx context.Context, config workflowModel.ExecuteConf
|
||||
}
|
||||
|
||||
var cOpts []nodes.ConvertOption
|
||||
inputFileFields := make(map[string]*workflowModel.FileInfo)
|
||||
cOpts = append(cOpts, nodes.WithCollectFileFields(inputFileFields), nodes.WithNotNeedTrimQueryFileName(true))
|
||||
if config.InputFailFast {
|
||||
cOpts = append(cOpts, nodes.FailFast())
|
||||
}
|
||||
@ -118,10 +111,6 @@ func (i *impl) SyncExecute(ctx context.Context, config workflowModel.ExecuteConf
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
|
||||
for k, v := range inputFileFields {
|
||||
config.InputFileFields[k] = v
|
||||
}
|
||||
|
||||
inStr, err := sonic.MarshalString(input)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@ -242,10 +231,6 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
|
||||
return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
config.InputFileFields = slices.ToMap(workflowSC.GetAllNodesInputFileFields(ctx), func(e *workflowModel.FileInfo) (string, *workflowModel.FileInfo) {
|
||||
return e.FileURL, e
|
||||
})
|
||||
|
||||
var wfOpts []compose.WorkflowOption
|
||||
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
@ -264,8 +249,6 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
|
||||
config.CommitID = wfEntity.CommitID
|
||||
|
||||
var cOpts []nodes.ConvertOption
|
||||
inputFileFields := make(map[string]*workflowModel.FileInfo)
|
||||
cOpts = append(cOpts, nodes.WithCollectFileFields(inputFileFields), nodes.WithNotNeedTrimQueryFileName(true))
|
||||
if config.InputFailFast {
|
||||
cOpts = append(cOpts, nodes.FailFast())
|
||||
}
|
||||
@ -277,10 +260,6 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
|
||||
for k, v := range inputFileFields {
|
||||
config.InputFileFields[k] = v
|
||||
}
|
||||
|
||||
inStr, err := sonic.MarshalString(input)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@ -303,50 +282,6 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
|
||||
return executeID, nil
|
||||
}
|
||||
|
||||
func (i *impl) handleHistory(ctx context.Context, config *workflowModel.ExecuteConfig, input map[string]any, historyRounds int64, shouldFetchConversationByName bool) error {
|
||||
if historyRounds <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if shouldFetchConversationByName {
|
||||
var cID, sID, bizID int64
|
||||
var err error
|
||||
if config.AppID != nil {
|
||||
bizID = *config.AppID
|
||||
} else if config.AgentID != nil {
|
||||
bizID = *config.AgentID
|
||||
}
|
||||
for k, v := range input {
|
||||
if k == vo.ConversationNameKey {
|
||||
cName, ok := v.(string)
|
||||
if !ok {
|
||||
return errors.New("CONVERSATION_NAME must be string")
|
||||
}
|
||||
cID, sID, err = i.GetOrCreateConversation(ctx, vo.Draft, bizID, consts.CozeConnectorID, config.Operator, cName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config.ConversationID = ptr.Of(cID)
|
||||
config.SectionID = ptr.Of(sID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, *config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error) {
|
||||
var (
|
||||
err error
|
||||
@ -373,6 +308,30 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
|
||||
historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@ -383,27 +342,12 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds = workflowSC.HistoryRounds()
|
||||
}
|
||||
if historyRounds > 0 {
|
||||
if err = i.handleHistory(ctx, &config, input, historyRounds, true); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
config.InputFileFields = slices.ToMap(workflowSC.GetAllNodesInputFileFields(ctx), func(e *workflowModel.FileInfo) (string, *workflowModel.FileInfo) {
|
||||
return e.FileURL, e
|
||||
})
|
||||
|
||||
wf, err := compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID), einoCompose.WithGraphName(fmt.Sprintf("%d", wfEntity.ID)))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create workflow: %w", err)
|
||||
}
|
||||
|
||||
var cOpts []nodes.ConvertOption
|
||||
inputFileFields := make(map[string]*workflowModel.FileInfo)
|
||||
cOpts = append(cOpts, nodes.WithCollectFileFields(inputFileFields), nodes.WithNotNeedTrimQueryFileName(true))
|
||||
if config.InputFailFast {
|
||||
cOpts = append(cOpts, nodes.FailFast())
|
||||
}
|
||||
@ -414,9 +358,6 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
} else if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
for k, v := range inputFileFields {
|
||||
config.InputFileFields[k] = v
|
||||
}
|
||||
|
||||
if wfEntity.AppID != nil && config.AppID == nil {
|
||||
config.AppID = wfEntity.AppID
|
||||
@ -476,6 +417,29 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@ -486,23 +450,7 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
return nil, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds = workflowSC.HistoryRounds()
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
if err = i.handleHistory(ctx, &config, input, historyRounds, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
config.InputFileFields = slices.ToMap(workflowSC.GetAllNodesInputFileFields(ctx), func(e *workflowModel.FileInfo) (string, *workflowModel.FileInfo) {
|
||||
return e.FileURL, e
|
||||
})
|
||||
|
||||
var wfOpts []compose.WorkflowOption
|
||||
|
||||
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
wfOpts = append(wfOpts, compose.WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
|
||||
@ -520,8 +468,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
config.CommitID = wfEntity.CommitID
|
||||
|
||||
var cOpts []nodes.ConvertOption
|
||||
inputFileFields := make(map[string]*workflowModel.FileInfo)
|
||||
cOpts = append(cOpts, nodes.WithCollectFileFields(inputFileFields), nodes.WithNotNeedTrimQueryFileName(true))
|
||||
if config.InputFailFast {
|
||||
cOpts = append(cOpts, nodes.FailFast())
|
||||
}
|
||||
@ -532,9 +478,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
} else if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
for k, v := range inputFileFields {
|
||||
config.InputFileFields[k] = v
|
||||
}
|
||||
|
||||
inStr, err := sonic.MarshalString(input)
|
||||
if err != nil {
|
||||
@ -1054,6 +997,20 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxHistoryRounds int64 = 30
|
||||
|
||||
func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
|
||||
if wfEntity == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return min(maxRounds, maxHistoryRounds), nil
|
||||
}
|
||||
|
||||
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
|
||||
convID := config.ConversationID
|
||||
agentID := config.AgentID
|
||||
@ -1070,11 +1027,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
var bizID int64
|
||||
var resolvedAppID int64
|
||||
if appID != nil {
|
||||
bizID = *appID
|
||||
resolvedAppID = *appID
|
||||
} else if agentID != nil {
|
||||
bizID = *agentID
|
||||
resolvedAppID = *agentID
|
||||
} else {
|
||||
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
|
||||
return nil, nil, nil
|
||||
@ -1082,7 +1039,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
|
||||
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: *convID,
|
||||
BizID: bizID,
|
||||
AppID: resolvedAppID,
|
||||
UserID: userID,
|
||||
Rounds: historyRounds + 1,
|
||||
SectionID: *sectionID,
|
||||
@ -1091,7 +1048,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
|
||||
return nil, nil, err
|
||||
return nil, nil, nil
|
||||
}
|
||||
if len(runIds) <= 1 {
|
||||
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
|
||||
@ -1104,7 +1061,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
|
||||
return nil, nil, err
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
return response.Messages, response.SchemaMessages, nil
|
||||
|
||||
@ -1,286 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
messagemock "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message/messagemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
mock_workflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestImpl_handleHistory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
|
||||
defer ctrl.Finish()
|
||||
|
||||
// Setup for cross-domain service mock
|
||||
mockMessage := messagemock.NewMockMessage(ctrl)
|
||||
crossmessage.SetDefaultSVC(mockMessage)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository)
|
||||
config *workflowModel.ExecuteConfig
|
||||
input map[string]any
|
||||
historyRounds int64
|
||||
shouldFetch bool
|
||||
expectErr bool
|
||||
expectedHistory []*crossmessage.WfMessage
|
||||
expectedSchemaHistory []*schema.Message
|
||||
}{
|
||||
{
|
||||
name: "historyRounds is zero",
|
||||
historyRounds: 0,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "shouldFetch is false",
|
||||
historyRounds: 5,
|
||||
shouldFetch: false,
|
||||
config: &workflowModel.ExecuteConfig{
|
||||
AppID: ptr.Of(int64(1)),
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2}, nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
|
||||
Messages: []*crossmessage.WfMessage{{ID: 1}},
|
||||
SchemaMessages: []*schema.Message{{
|
||||
Role: schema.User,
|
||||
Content: "123",
|
||||
}},
|
||||
}, nil).AnyTimes()
|
||||
},
|
||||
expectErr: false,
|
||||
expectedHistory: []*crossmessage.WfMessage{{ID: 1}},
|
||||
expectedSchemaHistory: []*schema.Message{{
|
||||
Role: schema.User,
|
||||
Content: "123",
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "fetch conversation by name - conversation exists",
|
||||
historyRounds: 3,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
|
||||
input: map[string]any{"CONVERSATION_NAME": "test-conv"},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "test-conv").Return(int64(200), int64(201), nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{3, 4}, nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
|
||||
Messages: []*crossmessage.WfMessage{{ID: 2}},
|
||||
SchemaMessages: []*schema.Message{{
|
||||
Role: schema.Assistant,
|
||||
Content: "123",
|
||||
}},
|
||||
}, nil).AnyTimes()
|
||||
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
|
||||
TemplateID: int64(202),
|
||||
SpaceID: int64(203),
|
||||
AppID: int64(204),
|
||||
}, true, nil).AnyTimes()
|
||||
repo.EXPECT().GetOrCreateStaticConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
|
||||
},
|
||||
expectErr: false,
|
||||
expectedHistory: []*crossmessage.WfMessage{{ID: 2}},
|
||||
expectedSchemaHistory: []*schema.Message{{
|
||||
Role: schema.Assistant,
|
||||
Content: "123",
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "fetch conversation by name - conversation not exists",
|
||||
historyRounds: 3,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AgentID: ptr.Of(int64(2))},
|
||||
input: map[string]any{"CONVERSATION_NAME": "new-conv"},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "new-conv").Return(int64(300), int64(301), nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{5, 6}, nil).AnyTimes()
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
|
||||
Messages: []*crossmessage.WfMessage{{ID: 3}},
|
||||
}, nil).AnyTimes()
|
||||
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
|
||||
TemplateID: int64(202),
|
||||
SpaceID: int64(203),
|
||||
AppID: int64(204),
|
||||
}, false, nil).AnyTimes()
|
||||
repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
|
||||
},
|
||||
expectErr: false,
|
||||
expectedHistory: []*crossmessage.WfMessage{{ID: 3}},
|
||||
},
|
||||
{
|
||||
name: "input with wrong type for conversation name",
|
||||
historyRounds: 5,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
|
||||
input: map[string]any{"CONVERSATION_NAME": 12345},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "GetOrCreateConversation returns error",
|
||||
historyRounds: 5,
|
||||
shouldFetch: true,
|
||||
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
|
||||
input: map[string]any{"CONVERSATION_NAME": "fail-conv"},
|
||||
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
|
||||
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "fail-conv").Return(int64(0), int64(0), errors.New("db error")).AnyTimes()
|
||||
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
|
||||
TemplateID: int64(202),
|
||||
SpaceID: int64(203),
|
||||
AppID: int64(204),
|
||||
}, false, nil).AnyTimes()
|
||||
repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, errors.New("db error")).AnyTimes()
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockService := mock_workflow.NewMockService(ctrl)
|
||||
mockRepo := mock_workflow.NewMockRepository(ctrl)
|
||||
testImpl := &impl{repo: mockRepo, conversationImpl: &conversationImpl{repo: mockRepo}}
|
||||
|
||||
tt.setupMock(mockService, mockMessage, mockRepo)
|
||||
|
||||
err := testImpl.handleHistory(ctx, tt.config, tt.input, tt.historyRounds, tt.shouldFetch)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedHistory != nil {
|
||||
assert.Equal(t, tt.expectedHistory, tt.config.ConversationHistory)
|
||||
} else if tt.historyRounds == 0 {
|
||||
assert.Nil(t, tt.config.ConversationHistory)
|
||||
} else if tt.expectedSchemaHistory != nil {
|
||||
assert.Equal(t, tt.expectedSchemaHistory, tt.config.ConversationHistorySchemaMessages)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImpl_prefetchChatHistory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockMessage := messagemock.NewMockMessage(ctrl)
|
||||
crossmessage.SetDefaultSVC(mockMessage)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(msgSvc *messagemock.MockMessage)
|
||||
config workflowModel.ExecuteConfig
|
||||
historyRounds int64
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "SectionID is nil",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
AppID: ptr.Of(int64(1)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "ConversationID is nil",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
AppID: ptr.Of(int64(1)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "AppID and AgentID are both nil",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "GetLatestRunIDs returns error",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
AppID: ptr.Of(int64(1)),
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "GetMessagesByRunIDs returns error",
|
||||
config: workflowModel.ExecuteConfig{
|
||||
AppID: ptr.Of(int64(1)),
|
||||
ConversationID: ptr.Of(int64(100)),
|
||||
SectionID: ptr.Of(int64(101)),
|
||||
},
|
||||
historyRounds: 5,
|
||||
setupMock: func(msgSvc *messagemock.MockMessage) {
|
||||
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2, 3}, nil)
|
||||
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testImpl := &impl{}
|
||||
tt.setupMock(mockMessage)
|
||||
|
||||
_, _, err := testImpl.prefetchChatHistory(ctx, tt.config, tt.historyRounds)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -39,6 +39,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
@ -521,7 +522,7 @@ func isEnableChatHistory(s *schema.NodeSchema) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
chatHistoryAware, ok := s.Configs.(schema.ChatHistoryAware)
|
||||
chatHistoryAware, ok := s.Configs.(nodes.ChatHistoryAware)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@ -2170,15 +2171,15 @@ func (i *impl) adaptToChatFlow(ctx context.Context, wID int64) error {
|
||||
vMap[v.Name] = true
|
||||
}
|
||||
|
||||
if _, ok := vMap[vo.UserInputKey]; !ok {
|
||||
if _, ok := vMap["USER_INPUT"]; !ok {
|
||||
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
|
||||
Name: vo.UserInputKey,
|
||||
Name: "USER_INPUT",
|
||||
Type: vo.VariableTypeString,
|
||||
})
|
||||
}
|
||||
if _, ok := vMap[vo.ConversationNameKey]; !ok {
|
||||
if _, ok := vMap["CONVERSATION_NAME"]; !ok {
|
||||
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
|
||||
Name: vo.ConversationNameKey,
|
||||
Name: "CONVERSATION_NAME",
|
||||
Type: vo.VariableTypeString,
|
||||
DefaultValue: "Default",
|
||||
})
|
||||
|
||||
@ -22,11 +22,15 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
@ -197,3 +201,214 @@ func isIncremental(prev version, next version) bool {
|
||||
|
||||
return next.Patch > prev.Patch
|
||||
}
|
||||
|
||||
func getMaxHistoryRoundsRecursively(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository) (int64, error) {
|
||||
visited := make(map[string]struct{})
|
||||
maxRounds := int64(0)
|
||||
err := getMaxHistoryRoundsRecursiveHelper(ctx, wfEntity, repo, visited, &maxRounds)
|
||||
return maxRounds, err
|
||||
}
|
||||
|
||||
func getMaxHistoryRoundsRecursiveHelper(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
|
||||
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
return nil
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(wfEntity.Canvas, &canvas); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wfEntity.ID, err)
|
||||
}
|
||||
|
||||
return collectMaxHistoryRounds(ctx, canvas.Nodes, repo, visited, maxRounds)
|
||||
}
|
||||
|
||||
func collectMaxHistoryRounds(ctx context.Context, nodes []*vo.Node, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
|
||||
for _, node := range nodes {
|
||||
if node == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.ChatHistorySetting != nil && node.Data.Inputs.ChatHistorySetting.EnableChatHistory {
|
||||
if node.Data.Inputs.ChatHistorySetting.ChatHistoryRound > *maxRounds {
|
||||
*maxRounds = node.Data.Inputs.ChatHistorySetting.ChatHistoryRound
|
||||
}
|
||||
} else if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLMParam != nil {
|
||||
param := node.Data.Inputs.LLMParam
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return err
|
||||
}
|
||||
var chatHistoryEnabled bool
|
||||
var chatHistoryRound int64
|
||||
for _, param := range llmParam {
|
||||
switch param.Name {
|
||||
case "enableChatHistory":
|
||||
if val, ok := param.Input.Value.Content.(bool); ok {
|
||||
b := val
|
||||
chatHistoryEnabled = b
|
||||
}
|
||||
case "chatHistoryRound":
|
||||
if strVal, ok := param.Input.Value.Content.(string); ok {
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatHistoryRound = int64Val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chatHistoryEnabled {
|
||||
if chatHistoryRound > *maxRounds {
|
||||
*maxRounds = chatHistoryRound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isSubWorkflow := node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil
|
||||
if isSubWorkflow {
|
||||
workflowIDStr := node.Data.Inputs.WorkflowID
|
||||
if workflowIDStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
|
||||
}
|
||||
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: ternary.IFElse(len(node.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
|
||||
Version: node.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
|
||||
}
|
||||
|
||||
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, maxRounds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(node.Blocks) > 0 {
|
||||
if err := collectMaxHistoryRounds(ctx, node.Blocks, repo, visited, maxRounds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHistoryRoundsFromNode(ctx context.Context, wfEntity *entity.Workflow, nodeID string, repo wf.Repository) (int64, error) {
|
||||
if wfEntity == nil {
|
||||
return 0, nil
|
||||
}
|
||||
visited := make(map[string]struct{})
|
||||
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
return 0, nil
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
maxRounds := int64(0)
|
||||
c := &vo.Canvas{}
|
||||
if err := sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
}
|
||||
var (
|
||||
n *vo.Node
|
||||
nodeFinder func(nodes []*vo.Node) *vo.Node
|
||||
)
|
||||
nodeFinder = func(nodes []*vo.Node) *vo.Node {
|
||||
for i := range nodes {
|
||||
if nodes[i].ID == nodeID {
|
||||
return nodes[i]
|
||||
}
|
||||
if len(nodes[i].Blocks) > 0 {
|
||||
if n := nodeFinder(nodes[i].Blocks); n != nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
n = nodeFinder(c.Nodes)
|
||||
if n.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if n.Data == nil || n.Data.Inputs == nil {
|
||||
return 0, nil
|
||||
}
|
||||
param := n.Data.Inputs.LLMParam
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var chatHistoryEnabled bool
|
||||
var chatHistoryRound int64
|
||||
for _, param := range llmParam {
|
||||
switch param.Name {
|
||||
case "enableChatHistory":
|
||||
if val, ok := param.Input.Value.Content.(bool); ok {
|
||||
b := val
|
||||
chatHistoryEnabled = b
|
||||
}
|
||||
case "chatHistoryRound":
|
||||
if strVal, ok := param.Input.Value.Content.(string); ok {
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
chatHistoryRound = int64Val
|
||||
}
|
||||
}
|
||||
}
|
||||
if chatHistoryEnabled {
|
||||
return chatHistoryRound, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if n.Type == entity.NodeTypeIntentDetector.IDStr() || n.Type == entity.NodeTypeKnowledgeRetriever.IDStr() {
|
||||
if n.Data != nil && n.Data.Inputs != nil && n.Data.Inputs.ChatHistorySetting != nil && n.Data.Inputs.ChatHistorySetting.EnableChatHistory {
|
||||
return n.Data.Inputs.ChatHistorySetting.ChatHistoryRound, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
if n.Data != nil && n.Data.Inputs != nil {
|
||||
workflowIDStr := n.Data.Inputs.WorkflowID
|
||||
if workflowIDStr == "" {
|
||||
return 0, nil
|
||||
}
|
||||
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", n.ID, err)
|
||||
}
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: ternary.IFElse(len(n.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
|
||||
Version: n.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
|
||||
}
|
||||
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, &maxRounds); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return maxRounds, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(n.Blocks) > 0 {
|
||||
if err := collectMaxHistoryRounds(ctx, n.Blocks, repo, visited, &maxRounds); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return maxRounds, nil
|
||||
}
|
||||
|
||||
@ -289,5 +289,6 @@ require (
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/eino-contrib/jsonschema v1.0.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
)
|
||||
|
||||
@ -18,15 +18,10 @@ package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrObjectNotFound = errors.New("object not found")
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/storage/storage_mock.go -package mock -source storage.go Factory
|
||||
type Storage interface {
|
||||
// PutObject puts the object with the specified key.
|
||||
@ -73,10 +68,10 @@ type ListObjectsPaginatedOutput struct {
|
||||
}
|
||||
|
||||
type FileInfo struct {
|
||||
Key string `json:"key"`
|
||||
LastModified time.Time `json:"last_modified"`
|
||||
ETag string `json:"etag"`
|
||||
Size int64 `json:"size"`
|
||||
URL string `json:"url"`
|
||||
Tagging map[string]string `json:"tagging"`
|
||||
Key string
|
||||
LastModified time.Time
|
||||
ETag string
|
||||
Size int64
|
||||
URL string
|
||||
Tagging map[string]string
|
||||
}
|
||||
|
||||
@ -64,16 +64,15 @@ func (c *OceanBaseClient) BatchInsertVectors(ctx context.Context, collectionName
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) DeleteVector(ctx context.Context, collectionName string, vectorID string) error {
|
||||
return c.official.GetDB().WithContext(ctx).Table(collectionName).Where("vector_id = ?", vectorID).Delete(nil).Error
|
||||
return c.official.GetDB().WithContext(ctx).Exec("DELETE FROM "+collectionName+" WHERE vector_id = ?", vectorID).Error
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) InitDatabase(ctx context.Context) error {
|
||||
var result int
|
||||
return c.official.GetDB().WithContext(ctx).Raw("SELECT 1").Scan(&result).Error
|
||||
return c.official.GetDB().WithContext(ctx).Exec("SELECT 1").Error
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) DropCollection(ctx context.Context, collectionName string) error {
|
||||
return c.official.GetDB().WithContext(ctx).Migrator().DropTable(collectionName)
|
||||
return c.official.GetDB().WithContext(ctx).Exec("DROP TABLE IF EXISTS " + collectionName).Error
|
||||
}
|
||||
|
||||
type SearchStrategy interface {
|
||||
|
||||
@ -43,15 +43,6 @@ type VectorResult struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type VectorRecord struct {
|
||||
VectorID string `gorm:"column:vector_id;primaryKey"`
|
||||
Content string `gorm:"column:content;type:text;not null"`
|
||||
Metadata string `gorm:"column:metadata;type:json"`
|
||||
Embedding string `gorm:"column:embedding;type:vector;not null"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:CURRENT_TIMESTAMP"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;default:CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"`
|
||||
}
|
||||
|
||||
type CollectionInfo struct {
|
||||
Name string `json:"name"`
|
||||
Dimension int `json:"dimension"`
|
||||
@ -92,23 +83,21 @@ func (c *OceanBaseOfficialClient) setVectorParameters() error {
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) CreateCollection(ctx context.Context, collectionName string, dimension int) error {
|
||||
if !c.db.WithContext(ctx).Migrator().HasTable(collectionName) {
|
||||
createTableSQL := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
vector_id VARCHAR(255) PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSON,
|
||||
embedding VECTOR(%d) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_created_at (created_at),
|
||||
INDEX idx_content (content(100))
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
|
||||
`, collectionName, dimension)
|
||||
createTableSQL := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
vector_id VARCHAR(255) PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSON,
|
||||
embedding VECTOR(%d) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_created_at (created_at),
|
||||
INDEX idx_content (content(100))
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
|
||||
`, collectionName, dimension)
|
||||
|
||||
if err := c.db.WithContext(ctx).Exec(createTableSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create table: %v", err)
|
||||
}
|
||||
if err := c.db.WithContext(ctx).Exec(createTableSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create table: %v", err)
|
||||
}
|
||||
|
||||
createIndexSQL := fmt.Sprintf(`
|
||||
@ -147,19 +136,30 @@ func (c *OceanBaseOfficialClient) InsertVectors(ctx context.Context, collectionN
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) insertBatch(ctx context.Context, collectionName string, batch []VectorResult) error {
|
||||
records := make([]VectorRecord, len(batch))
|
||||
for i, vector := range batch {
|
||||
records[i] = VectorRecord{
|
||||
VectorID: vector.VectorID,
|
||||
Content: vector.Content,
|
||||
Metadata: vector.Metadata,
|
||||
Embedding: c.vectorToString(vector.Embedding),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
placeholders := make([]string, len(batch))
|
||||
values := make([]interface{}, 0, len(batch)*5)
|
||||
|
||||
for j, vector := range batch {
|
||||
placeholders[j] = "(?, ?, ?, ?, NOW())"
|
||||
values = append(values,
|
||||
vector.VectorID,
|
||||
vector.Content,
|
||||
vector.Metadata,
|
||||
c.vectorToString(vector.Embedding),
|
||||
)
|
||||
}
|
||||
|
||||
return c.db.WithContext(ctx).Table(collectionName).Save(&records).Error
|
||||
sql := fmt.Sprintf(`
|
||||
INSERT INTO %s (vector_id, content, metadata, embedding, created_at)
|
||||
VALUES %s
|
||||
ON DUPLICATE KEY UPDATE
|
||||
content = VALUES(content),
|
||||
metadata = VALUES(metadata),
|
||||
embedding = VALUES(embedding),
|
||||
updated_at = NOW()
|
||||
`, collectionName, strings.Join(placeholders, ","))
|
||||
|
||||
return c.db.WithContext(ctx).Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) SearchVectors(
|
||||
@ -341,28 +341,24 @@ func (c *OceanBaseOfficialClient) DebugCollectionData(ctx context.Context, colle
|
||||
log.Printf("[Debug] Collection '%s' exists with %d vectors", collectionName, count)
|
||||
|
||||
log.Printf("[Debug] Sample data from collection '%s':", collectionName)
|
||||
var samples []struct {
|
||||
VectorID string `gorm:"column:vector_id"`
|
||||
Content string `gorm:"column:content"`
|
||||
CreatedAt time.Time `gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
err := c.db.WithContext(ctx).Table(collectionName).
|
||||
Select("vector_id, content, created_at").
|
||||
Order("created_at DESC").
|
||||
Limit(5).
|
||||
Find(&samples).Error
|
||||
|
||||
rows, err := c.db.WithContext(ctx).Raw(`
|
||||
SELECT vector_id, content, created_at
|
||||
FROM ` + collectionName + `
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 5
|
||||
`).Rows()
|
||||
if err != nil {
|
||||
log.Printf("[Debug] Failed to get sample data: %v", err)
|
||||
} else {
|
||||
for _, sample := range samples {
|
||||
contentPreview := sample.Content
|
||||
if len(contentPreview) > 50 {
|
||||
contentPreview = contentPreview[:50]
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var vectorID, content string
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&vectorID, &content, &createdAt); err != nil {
|
||||
log.Printf("[Debug] Failed to scan sample row: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("[Debug] Sample: ID=%s, Content=%s, Created=%s",
|
||||
sample.VectorID, contentPreview, sample.CreatedAt)
|
||||
log.Printf("[Debug] Sample: ID=%s, Content=%s, Created=%s", vectorID, content[:min(50, len(content))], createdAt)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ func AssembleFileUrl(ctx context.Context, urlExpire *int64, files []*storage.Fil
|
||||
taskGroup := taskgroup.NewTaskGroup(ctx, 5)
|
||||
for idx := range files {
|
||||
f := files[idx]
|
||||
expire := int64(7 * 60 * 60 * 24)
|
||||
expire := int64(60 * 60 * 24)
|
||||
if urlExpire != nil && *urlExpire > 0 {
|
||||
expire = *urlExpire
|
||||
}
|
||||
|
||||
63
backend/infra/impl/storage/internal/proxy/proxy.go
Normal file
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func CheckIfNeedReplaceHost(ctx context.Context, originURLStr string) (ok bool, proxyURL string) {
|
||||
// url parse
|
||||
originURL, err := url.Parse(originURLStr)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "[CheckIfNeedReplaceHost] url parse failed, err: %v", err)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
proxyPort := os.Getenv(consts.MinIOProxyEndpoint) // :8889
|
||||
if proxyPort == "" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
currentHost, ok := ctxcache.Get[string](ctx, consts.HostKeyInCtx)
|
||||
if !ok {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
currentScheme, ok := ctxcache.Get[string](ctx, consts.RequestSchemeKeyInCtx)
|
||||
if !ok {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(currentHost)
|
||||
if err != nil {
|
||||
host = currentHost
|
||||
}
|
||||
|
||||
minioProxyHost := host + proxyPort
|
||||
originURL.Host = minioProxyHost
|
||||
originURL.Scheme = currentScheme
|
||||
logs.CtxDebugf(ctx, "[CheckIfNeedReplaceHost] reset originURL.String = %s", originURL.String())
|
||||
return true, originURL.String()
|
||||
}
|
||||
@ -30,6 +30,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/internal/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/internal/proxy"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
@ -231,6 +232,12 @@ func (m *minioClient) GetObjectUrl(ctx context.Context, objectKey string, opts .
|
||||
return "", fmt.Errorf("GetObjectUrl failed: %v", err)
|
||||
}
|
||||
|
||||
// logs.CtxDebugf(ctx, "[GetObjectUrl] origin presignedURL.String = %s", presignedURL.String())
|
||||
ok, proxyURL := proxy.CheckIfNeedReplaceHost(ctx, presignedURL.String())
|
||||
if ok {
|
||||
return proxyURL, nil
|
||||
}
|
||||
|
||||
return presignedURL.String(), nil
|
||||
}
|
||||
|
||||
@ -310,7 +317,7 @@ func (m *minioClient) HeadObject(ctx context.Context, objectKey string, opts ...
|
||||
stat, err := m.client.StatObject(ctx, m.bucketName, objectKey, minio.StatObjectOptions{})
|
||||
if err != nil {
|
||||
if minio.ToErrorResponse(err).Code == "NoSuchKey" {
|
||||
return nil, storage.ErrObjectNotFound
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("HeadObject failed for key %s: %w", objectKey, err)
|
||||
|
||||
@ -32,6 +32,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/internal/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/internal/proxy"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
@ -229,26 +230,21 @@ func (t *s3Client) GetObjectUrl(ctx context.Context, objectKey string, opts ...s
|
||||
bucket := t.bucketName
|
||||
presignClient := s3.NewPresignClient(client)
|
||||
|
||||
opt := storage.GetOption{}
|
||||
for _, optFn := range opts {
|
||||
optFn(&opt)
|
||||
}
|
||||
|
||||
expire := int64(60 * 60 * 24)
|
||||
if opt.Expire > 0 {
|
||||
expire = opt.Expire
|
||||
}
|
||||
|
||||
req, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(objectKey),
|
||||
}, func(options *s3.PresignOptions) {
|
||||
options.Expires = time.Duration(expire) * time.Second
|
||||
options.Expires = time.Duration(60*60*24) * time.Second
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get object presigned url failed: %v", err)
|
||||
}
|
||||
|
||||
ok, proxyURL := proxy.CheckIfNeedReplaceHost(ctx, req.URL)
|
||||
if ok {
|
||||
return proxyURL, nil
|
||||
}
|
||||
|
||||
return req.URL, nil
|
||||
}
|
||||
|
||||
@ -385,7 +381,7 @@ func (t *s3Client) HeadObject(ctx context.Context, objectKey string, opts ...sto
|
||||
if err != nil {
|
||||
var nsk *types.NotFound
|
||||
if errors.As(err, &nsk) {
|
||||
return nil, storage.ErrObjectNotFound
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -30,6 +30,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/internal/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/internal/proxy"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
@ -246,19 +247,9 @@ func (t *tosClient) GetObjectUrl(ctx context.Context, objectKey string, opts ...
|
||||
client := t.client
|
||||
bucketName := t.bucketName
|
||||
|
||||
opt := storage.GetOption{}
|
||||
for _, optFn := range opts {
|
||||
optFn(&opt)
|
||||
}
|
||||
|
||||
expire := int64(7 * 24 * 60 * 60)
|
||||
if opt.Expire > 0 {
|
||||
expire = opt.Expire
|
||||
}
|
||||
|
||||
output, err := client.PreSignedURL(&tos.PreSignedURLInput{
|
||||
HTTPMethod: enum.HttpMethodGet,
|
||||
Expires: expire,
|
||||
Expires: 60 * 60 * 24,
|
||||
Bucket: bucketName,
|
||||
Key: objectKey,
|
||||
})
|
||||
@ -266,6 +257,11 @@ func (t *tosClient) GetObjectUrl(ctx context.Context, objectKey string, opts ...
|
||||
return "", err
|
||||
}
|
||||
|
||||
ok, proxyURL := proxy.CheckIfNeedReplaceHost(ctx, output.SignedUrl)
|
||||
if ok {
|
||||
return proxyURL, nil
|
||||
}
|
||||
|
||||
return output.SignedUrl, nil
|
||||
}
|
||||
|
||||
@ -393,7 +389,7 @@ func (t *tosClient) HeadObject(ctx context.Context, objectKey string, opts ...st
|
||||
if err != nil {
|
||||
if serverErr, ok := err.(*tos.TosServerError); ok {
|
||||
if serverErr.StatusCode == http.StatusNotFound {
|
||||
return nil, storage.ErrObjectNotFound
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
|
||||
@ -0,0 +1,355 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: domain/agent/singleagent/service/single_agent.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination internal/mock/domain/agent/singleagent/single_agent_mock.go --package mock -source domain/agent/singleagent/service/single_agent.go
|
||||
//
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
schema "github.com/cloudwego/eino/schema"
|
||||
playground "github.com/coze-dev/coze-studio/backend/api/model/playground"
|
||||
entity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockSingleAgent is a mock of SingleAgent interface.
|
||||
type MockSingleAgent struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSingleAgentMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockSingleAgentMockRecorder is the mock recorder for MockSingleAgent.
|
||||
type MockSingleAgentMockRecorder struct {
|
||||
mock *MockSingleAgent
|
||||
}
|
||||
|
||||
// NewMockSingleAgent creates a new mock instance.
|
||||
func NewMockSingleAgent(ctrl *gomock.Controller) *MockSingleAgent {
|
||||
mock := &MockSingleAgent{ctrl: ctrl}
|
||||
mock.recorder = &MockSingleAgentMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockSingleAgent) EXPECT() *MockSingleAgentMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateSingleAgent mocks base method.
|
||||
func (m *MockSingleAgent) CreateSingleAgent(ctx context.Context, connectorID int64, version string, e *entity.SingleAgent) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateSingleAgent", ctx, connectorID, version, e)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateSingleAgent indicates an expected call of CreateSingleAgent.
|
||||
func (mr *MockSingleAgentMockRecorder) CreateSingleAgent(ctx, connectorID, version, e any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSingleAgent", reflect.TypeOf((*MockSingleAgent)(nil).CreateSingleAgent), ctx, connectorID, version, e)
|
||||
}
|
||||
|
||||
// CreateSingleAgentDraft mocks base method.
|
||||
func (m *MockSingleAgent) CreateSingleAgentDraft(ctx context.Context, creatorID int64, draft *entity.SingleAgent) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateSingleAgentDraft", ctx, creatorID, draft)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateSingleAgentDraft indicates an expected call of CreateSingleAgentDraft.
|
||||
func (mr *MockSingleAgentMockRecorder) CreateSingleAgentDraft(ctx, creatorID, draft any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).CreateSingleAgentDraft), ctx, creatorID, draft)
|
||||
}
|
||||
|
||||
// CreateSingleAgentDraftWithID mocks base method.
|
||||
func (m *MockSingleAgent) CreateSingleAgentDraftWithID(ctx context.Context, creatorID, agentID int64, draft *entity.SingleAgent) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateSingleAgentDraftWithID", ctx, creatorID, agentID, draft)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateSingleAgentDraftWithID indicates an expected call of CreateSingleAgentDraftWithID.
|
||||
func (mr *MockSingleAgentMockRecorder) CreateSingleAgentDraftWithID(ctx, creatorID, agentID, draft any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSingleAgentDraftWithID", reflect.TypeOf((*MockSingleAgent)(nil).CreateSingleAgentDraftWithID), ctx, creatorID, agentID, draft)
|
||||
}
|
||||
|
||||
// DeleteAgentDraft mocks base method.
|
||||
func (m *MockSingleAgent) DeleteAgentDraft(ctx context.Context, spaceID, agentID int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAgentDraft", ctx, spaceID, agentID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAgentDraft indicates an expected call of DeleteAgentDraft.
|
||||
func (mr *MockSingleAgentMockRecorder) DeleteAgentDraft(ctx, spaceID, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).DeleteAgentDraft), ctx, spaceID, agentID)
|
||||
}
|
||||
|
||||
// DuplicateInMemory mocks base method.
|
||||
func (m *MockSingleAgent) DuplicateInMemory(ctx context.Context, req *entity.DuplicateInfo) (*entity.SingleAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DuplicateInMemory", ctx, req)
|
||||
ret0, _ := ret[0].(*entity.SingleAgent)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DuplicateInMemory indicates an expected call of DuplicateInMemory.
|
||||
func (mr *MockSingleAgentMockRecorder) DuplicateInMemory(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DuplicateInMemory", reflect.TypeOf((*MockSingleAgent)(nil).DuplicateInMemory), ctx, req)
|
||||
}
|
||||
|
||||
// GetAgentDraftDisplayInfo mocks base method.
|
||||
func (m *MockSingleAgent) GetAgentDraftDisplayInfo(ctx context.Context, userID, agentID int64) (*entity.AgentDraftDisplayInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentDraftDisplayInfo", ctx, userID, agentID)
|
||||
ret0, _ := ret[0].(*entity.AgentDraftDisplayInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentDraftDisplayInfo indicates an expected call of GetAgentDraftDisplayInfo.
|
||||
func (mr *MockSingleAgentMockRecorder) GetAgentDraftDisplayInfo(ctx, userID, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentDraftDisplayInfo", reflect.TypeOf((*MockSingleAgent)(nil).GetAgentDraftDisplayInfo), ctx, userID, agentID)
|
||||
}
|
||||
|
||||
// GetAgentPopupCount mocks base method.
|
||||
func (m *MockSingleAgent) GetAgentPopupCount(ctx context.Context, uid, agentID int64, agentPopupType playground.BotPopupType) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentPopupCount", ctx, uid, agentID, agentPopupType)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentPopupCount indicates an expected call of GetAgentPopupCount.
|
||||
func (mr *MockSingleAgentMockRecorder) GetAgentPopupCount(ctx, uid, agentID, agentPopupType any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentPopupCount", reflect.TypeOf((*MockSingleAgent)(nil).GetAgentPopupCount), ctx, uid, agentID, agentPopupType)
|
||||
}
|
||||
|
||||
// GetPublishConnectorList mocks base method.
|
||||
func (m *MockSingleAgent) GetPublishConnectorList(ctx context.Context, agentID int64) (*entity.PublishConnectorData, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPublishConnectorList", ctx, agentID)
|
||||
ret0, _ := ret[0].(*entity.PublishConnectorData)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPublishConnectorList indicates an expected call of GetPublishConnectorList.
|
||||
func (mr *MockSingleAgentMockRecorder) GetPublishConnectorList(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublishConnectorList", reflect.TypeOf((*MockSingleAgent)(nil).GetPublishConnectorList), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetPublishedInfo mocks base method.
|
||||
func (m *MockSingleAgent) GetPublishedInfo(ctx context.Context, agentID int64) (*entity.PublishInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPublishedInfo", ctx, agentID)
|
||||
ret0, _ := ret[0].(*entity.PublishInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPublishedInfo indicates an expected call of GetPublishedInfo.
|
||||
func (mr *MockSingleAgentMockRecorder) GetPublishedInfo(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublishedInfo", reflect.TypeOf((*MockSingleAgent)(nil).GetPublishedInfo), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetPublishedTime mocks base method.
|
||||
func (m *MockSingleAgent) GetPublishedTime(ctx context.Context, agentID int64) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPublishedTime", ctx, agentID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPublishedTime indicates an expected call of GetPublishedTime.
|
||||
func (mr *MockSingleAgentMockRecorder) GetPublishedTime(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublishedTime", reflect.TypeOf((*MockSingleAgent)(nil).GetPublishedTime), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetSingleAgent mocks base method.
|
||||
func (m *MockSingleAgent) GetSingleAgent(ctx context.Context, agentID int64, version string) (*entity.SingleAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSingleAgent", ctx, agentID, version)
|
||||
ret0, _ := ret[0].(*entity.SingleAgent)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSingleAgent indicates an expected call of GetSingleAgent.
|
||||
func (mr *MockSingleAgentMockRecorder) GetSingleAgent(ctx, agentID, version any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSingleAgent", reflect.TypeOf((*MockSingleAgent)(nil).GetSingleAgent), ctx, agentID, version)
|
||||
}
|
||||
|
||||
// GetSingleAgentDraft mocks base method.
|
||||
func (m *MockSingleAgent) GetSingleAgentDraft(ctx context.Context, agentID int64) (*entity.SingleAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSingleAgentDraft", ctx, agentID)
|
||||
ret0, _ := ret[0].(*entity.SingleAgent)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSingleAgentDraft indicates an expected call of GetSingleAgentDraft.
|
||||
func (mr *MockSingleAgentMockRecorder) GetSingleAgentDraft(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).GetSingleAgentDraft), ctx, agentID)
|
||||
}
|
||||
|
||||
// IncrAgentPopupCount mocks base method.
|
||||
func (m *MockSingleAgent) IncrAgentPopupCount(ctx context.Context, uid, agentID int64, agentPopupType playground.BotPopupType) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IncrAgentPopupCount", ctx, uid, agentID, agentPopupType)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IncrAgentPopupCount indicates an expected call of IncrAgentPopupCount.
|
||||
func (mr *MockSingleAgentMockRecorder) IncrAgentPopupCount(ctx, uid, agentID, agentPopupType any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrAgentPopupCount", reflect.TypeOf((*MockSingleAgent)(nil).IncrAgentPopupCount), ctx, uid, agentID, agentPopupType)
|
||||
}
|
||||
|
||||
// ListAgentPublishHistory mocks base method.
|
||||
func (m *MockSingleAgent) ListAgentPublishHistory(ctx context.Context, agentID int64, pageIndex, pageSize int32, connectorID *int64) ([]*entity.SingleAgentPublish, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAgentPublishHistory", ctx, agentID, pageIndex, pageSize, connectorID)
|
||||
ret0, _ := ret[0].([]*entity.SingleAgentPublish)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAgentPublishHistory indicates an expected call of ListAgentPublishHistory.
|
||||
func (mr *MockSingleAgentMockRecorder) ListAgentPublishHistory(ctx, agentID, pageIndex, pageSize, connectorID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAgentPublishHistory", reflect.TypeOf((*MockSingleAgent)(nil).ListAgentPublishHistory), ctx, agentID, pageIndex, pageSize, connectorID)
|
||||
}
|
||||
|
||||
// MGetSingleAgentDraft mocks base method.
|
||||
func (m *MockSingleAgent) MGetSingleAgentDraft(ctx context.Context, agentIDs []int64) ([]*entity.SingleAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetSingleAgentDraft", ctx, agentIDs)
|
||||
ret0, _ := ret[0].([]*entity.SingleAgent)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetSingleAgentDraft indicates an expected call of MGetSingleAgentDraft.
|
||||
func (mr *MockSingleAgentMockRecorder) MGetSingleAgentDraft(ctx, agentIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).MGetSingleAgentDraft), ctx, agentIDs)
|
||||
}
|
||||
|
||||
// ObtainAgentByIdentity mocks base method.
|
||||
func (m *MockSingleAgent) ObtainAgentByIdentity(ctx context.Context, identity *entity.AgentIdentity) (*entity.SingleAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ObtainAgentByIdentity", ctx, identity)
|
||||
ret0, _ := ret[0].(*entity.SingleAgent)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ObtainAgentByIdentity indicates an expected call of ObtainAgentByIdentity.
|
||||
func (mr *MockSingleAgentMockRecorder) ObtainAgentByIdentity(ctx, identity any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObtainAgentByIdentity", reflect.TypeOf((*MockSingleAgent)(nil).ObtainAgentByIdentity), ctx, identity)
|
||||
}
|
||||
|
||||
// SavePublishRecord mocks base method.
|
||||
func (m *MockSingleAgent) SavePublishRecord(ctx context.Context, p *entity.SingleAgentPublish, e *entity.SingleAgent) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SavePublishRecord", ctx, p, e)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SavePublishRecord indicates an expected call of SavePublishRecord.
|
||||
func (mr *MockSingleAgentMockRecorder) SavePublishRecord(ctx, p, e any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePublishRecord", reflect.TypeOf((*MockSingleAgent)(nil).SavePublishRecord), ctx, p, e)
|
||||
}
|
||||
|
||||
// StreamExecute mocks base method.
|
||||
func (m *MockSingleAgent) StreamExecute(ctx context.Context, req *entity.ExecuteRequest) (*schema.StreamReader[*entity.AgentEvent], error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StreamExecute", ctx, req)
|
||||
ret0, _ := ret[0].(*schema.StreamReader[*entity.AgentEvent])
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// StreamExecute indicates an expected call of StreamExecute.
|
||||
func (mr *MockSingleAgentMockRecorder) StreamExecute(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamExecute", reflect.TypeOf((*MockSingleAgent)(nil).StreamExecute), ctx, req)
|
||||
}
|
||||
|
||||
// UpdateAgentDraftDisplayInfo mocks base method.
|
||||
func (m *MockSingleAgent) UpdateAgentDraftDisplayInfo(ctx context.Context, userID int64, e *entity.AgentDraftDisplayInfo) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAgentDraftDisplayInfo", ctx, userID, e)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAgentDraftDisplayInfo indicates an expected call of UpdateAgentDraftDisplayInfo.
|
||||
func (mr *MockSingleAgentMockRecorder) UpdateAgentDraftDisplayInfo(ctx, userID, e any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAgentDraftDisplayInfo", reflect.TypeOf((*MockSingleAgent)(nil).UpdateAgentDraftDisplayInfo), ctx, userID, e)
|
||||
}
|
||||
|
||||
// UpdateSingleAgentDraft mocks base method.
|
||||
func (m *MockSingleAgent) UpdateSingleAgentDraft(ctx context.Context, agentInfo *entity.SingleAgent) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateSingleAgentDraft", ctx, agentInfo)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateSingleAgentDraft indicates an expected call of UpdateSingleAgentDraft.
|
||||
func (mr *MockSingleAgentMockRecorder) UpdateSingleAgentDraft(ctx, agentInfo any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSingleAgentDraft", reflect.TypeOf((*MockSingleAgent)(nil).UpdateSingleAgentDraft), ctx, agentInfo)
|
||||
}
|
||||
@ -0,0 +1,148 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: domain/conversation/agentrun/service/agent_run.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination internal/mock/domain/conversation/agentrun/agent_run_mock.go --package mock_agentrun -source domain/conversation/agentrun/service/agent_run.go
|
||||
//
|
||||
|
||||
// Package mock_agentrun is a generated GoMock package.
|
||||
package mock_agentrun
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
schema "github.com/cloudwego/eino/schema"
|
||||
entity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockRun is a mock of Run interface.
|
||||
type MockRun struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockRunMockRecorder is the mock recorder for MockRun.
|
||||
type MockRunMockRecorder struct {
|
||||
mock *MockRun
|
||||
}
|
||||
|
||||
// NewMockRun creates a new mock instance.
|
||||
func NewMockRun(ctrl *gomock.Controller) *MockRun {
|
||||
mock := &MockRun{ctrl: ctrl}
|
||||
mock.recorder = &MockRunMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRun) EXPECT() *MockRunMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AgentRun mocks base method.
|
||||
func (m *MockRun) AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AgentRun", ctx, req)
|
||||
ret0, _ := ret[0].(*schema.StreamReader[*entity.AgentRunResponse])
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AgentRun indicates an expected call of AgentRun.
|
||||
func (mr *MockRunMockRecorder) AgentRun(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentRun", reflect.TypeOf((*MockRun)(nil).AgentRun), ctx, req)
|
||||
}
|
||||
|
||||
// Cancel mocks base method.
|
||||
func (m *MockRun) Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Cancel", ctx, req)
|
||||
ret0, _ := ret[0].(*entity.RunRecordMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Cancel indicates an expected call of Cancel.
|
||||
func (mr *MockRunMockRecorder) Cancel(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockRun)(nil).Cancel), ctx, req)
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockRun) Create(ctx context.Context, runRecord *entity.AgentRunMeta) (*entity.RunRecordMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Create", ctx, runRecord)
|
||||
ret0, _ := ret[0].(*entity.RunRecordMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create.
|
||||
func (mr *MockRunMockRecorder) Create(ctx, runRecord any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockRun)(nil).Create), ctx, runRecord)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockRun) Delete(ctx context.Context, runID []int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, runID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockRunMockRecorder) Delete(ctx, runID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockRun)(nil).Delete), ctx, runID)
|
||||
}
|
||||
|
||||
// GetByID mocks base method.
|
||||
func (m *MockRun) GetByID(ctx context.Context, runID int64) (*entity.RunRecordMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByID", ctx, runID)
|
||||
ret0, _ := ret[0].(*entity.RunRecordMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByID indicates an expected call of GetByID.
|
||||
func (mr *MockRunMockRecorder) GetByID(ctx, runID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockRun)(nil).GetByID), ctx, runID)
|
||||
}
|
||||
|
||||
// List mocks base method.
|
||||
func (m *MockRun) List(ctx context.Context, ListMeta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List", ctx, ListMeta)
|
||||
ret0, _ := ret[0].([]*entity.RunRecordMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List.
|
||||
func (mr *MockRunMockRecorder) List(ctx, ListMeta any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockRun)(nil).List), ctx, ListMeta)
|
||||
}
|
||||
@ -0,0 +1,163 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: domain/conversation/conversation/service/conversation.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination internal/mock/domain/conversation/conversation/conversation_mock.go --package mock_conversation -source domain/conversation/conversation/service/conversation.go
|
||||
//
|
||||
|
||||
// Package mock_conversation is a generated GoMock package.
|
||||
package mock_conversation
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
entity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockConversation is a mock of Conversation interface.
|
||||
type MockConversation struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConversationMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockConversationMockRecorder is the mock recorder for MockConversation.
|
||||
type MockConversationMockRecorder struct {
|
||||
mock *MockConversation
|
||||
}
|
||||
|
||||
// NewMockConversation creates a new mock instance.
|
||||
func NewMockConversation(ctrl *gomock.Controller) *MockConversation {
|
||||
mock := &MockConversation{ctrl: ctrl}
|
||||
mock.recorder = &MockConversationMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockConversation) EXPECT() *MockConversationMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockConversation) Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Create", ctx, req)
|
||||
ret0, _ := ret[0].(*entity.Conversation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create.
|
||||
func (mr *MockConversationMockRecorder) Create(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockConversation)(nil).Create), ctx, req)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockConversation) Delete(ctx context.Context, id int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockConversationMockRecorder) Delete(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockConversation)(nil).Delete), ctx, id)
|
||||
}
|
||||
|
||||
// GetByID mocks base method.
|
||||
func (m *MockConversation) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByID", ctx, id)
|
||||
ret0, _ := ret[0].(*entity.Conversation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByID indicates an expected call of GetByID.
|
||||
func (mr *MockConversationMockRecorder) GetByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockConversation)(nil).GetByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetCurrentConversation mocks base method.
|
||||
func (m *MockConversation) GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCurrentConversation", ctx, req)
|
||||
ret0, _ := ret[0].(*entity.Conversation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetCurrentConversation indicates an expected call of GetCurrentConversation.
|
||||
func (mr *MockConversationMockRecorder) GetCurrentConversation(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentConversation", reflect.TypeOf((*MockConversation)(nil).GetCurrentConversation), ctx, req)
|
||||
}
|
||||
|
||||
// List mocks base method.
|
||||
func (m *MockConversation) List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List", ctx, req)
|
||||
ret0, _ := ret[0].([]*entity.Conversation)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// List indicates an expected call of List.
|
||||
func (mr *MockConversationMockRecorder) List(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockConversation)(nil).List), ctx, req)
|
||||
}
|
||||
|
||||
// NewConversationCtx mocks base method.
|
||||
func (m *MockConversation) NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "NewConversationCtx", ctx, req)
|
||||
ret0, _ := ret[0].(*entity.NewConversationCtxResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// NewConversationCtx indicates an expected call of NewConversationCtx.
|
||||
func (mr *MockConversationMockRecorder) NewConversationCtx(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConversationCtx", reflect.TypeOf((*MockConversation)(nil).NewConversationCtx), ctx, req)
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockConversation) Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", ctx, req)
|
||||
ret0, _ := ret[0].(*entity.Conversation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockConversationMockRecorder) Update(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockConversation)(nil).Update), ctx, req)
|
||||
}
|
||||
132
backend/internal/mock/domain/shortcutcmd/shortcut_cmd_mock.go
Normal file
@ -0,0 +1,132 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: domain/shortcutcmd/service/shortcut_cmd.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination internal/mock/domain/shortcutcmd/shortcut_cmd_mock.go --package mock_shortcutcmd -source domain/shortcutcmd/service/shortcut_cmd.go
|
||||
//
|
||||
|
||||
// Package mock_shortcutcmd is a generated GoMock package.
|
||||
package mock_shortcutcmd
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
entity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockShortcutCmd is a mock of ShortcutCmd interface.
|
||||
type MockShortcutCmd struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockShortcutCmdMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockShortcutCmdMockRecorder is the mock recorder for MockShortcutCmd.
|
||||
type MockShortcutCmdMockRecorder struct {
|
||||
mock *MockShortcutCmd
|
||||
}
|
||||
|
||||
// NewMockShortcutCmd creates a new mock instance.
|
||||
func NewMockShortcutCmd(ctrl *gomock.Controller) *MockShortcutCmd {
|
||||
mock := &MockShortcutCmd{ctrl: ctrl}
|
||||
mock.recorder = &MockShortcutCmdMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockShortcutCmd) EXPECT() *MockShortcutCmdMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateCMD mocks base method.
|
||||
func (m *MockShortcutCmd) CreateCMD(ctx context.Context, shortcut *entity.ShortcutCmd) (*entity.ShortcutCmd, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateCMD", ctx, shortcut)
|
||||
ret0, _ := ret[0].(*entity.ShortcutCmd)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateCMD indicates an expected call of CreateCMD.
|
||||
func (mr *MockShortcutCmdMockRecorder) CreateCMD(ctx, shortcut any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCMD", reflect.TypeOf((*MockShortcutCmd)(nil).CreateCMD), ctx, shortcut)
|
||||
}
|
||||
|
||||
// GetByCmdID mocks base method.
|
||||
func (m *MockShortcutCmd) GetByCmdID(ctx context.Context, cmdID int64, isOnline int32) (*entity.ShortcutCmd, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByCmdID", ctx, cmdID, isOnline)
|
||||
ret0, _ := ret[0].(*entity.ShortcutCmd)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByCmdID indicates an expected call of GetByCmdID.
|
||||
func (mr *MockShortcutCmdMockRecorder) GetByCmdID(ctx, cmdID, isOnline any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByCmdID", reflect.TypeOf((*MockShortcutCmd)(nil).GetByCmdID), ctx, cmdID, isOnline)
|
||||
}
|
||||
|
||||
// ListCMD mocks base method.
|
||||
func (m *MockShortcutCmd) ListCMD(ctx context.Context, lm *entity.ListMeta) ([]*entity.ShortcutCmd, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListCMD", ctx, lm)
|
||||
ret0, _ := ret[0].([]*entity.ShortcutCmd)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListCMD indicates an expected call of ListCMD.
|
||||
func (mr *MockShortcutCmdMockRecorder) ListCMD(ctx, lm any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListCMD", reflect.TypeOf((*MockShortcutCmd)(nil).ListCMD), ctx, lm)
|
||||
}
|
||||
|
||||
// PublishCMDs mocks base method.
|
||||
func (m *MockShortcutCmd) PublishCMDs(ctx context.Context, objID int64, cmdIDs []int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PublishCMDs", ctx, objID, cmdIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PublishCMDs indicates an expected call of PublishCMDs.
|
||||
func (mr *MockShortcutCmdMockRecorder) PublishCMDs(ctx, objID, cmdIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishCMDs", reflect.TypeOf((*MockShortcutCmd)(nil).PublishCMDs), ctx, objID, cmdIDs)
|
||||
}
|
||||
|
||||
// UpdateCMD mocks base method.
|
||||
func (m *MockShortcutCmd) UpdateCMD(ctx context.Context, shortcut *entity.ShortcutCmd) (*entity.ShortcutCmd, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateCMD", ctx, shortcut)
|
||||
ret0, _ := ret[0].(*entity.ShortcutCmd)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateCMD indicates an expected call of UpdateCMD.
|
||||
func (mr *MockShortcutCmdMockRecorder) UpdateCMD(ctx, shortcut any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCMD", reflect.TypeOf((*MockShortcutCmd)(nil).UpdateCMD), ctx, shortcut)
|
||||
}
|
||||
118
backend/internal/mock/domain/upload/upload_service_mock.go
Normal file
@ -0,0 +1,118 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: domain/upload/service/interface.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination internal/mock/domain/upload/upload_service_mock.go --package mock_upload -source domain/upload/service/interface.go UploadService
|
||||
//
|
||||
|
||||
// Package mock_upload is a generated GoMock package.
|
||||
package mock_upload
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
service "github.com/coze-dev/coze-studio/backend/domain/upload/service"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockUploadService is a mock of UploadService interface.
|
||||
type MockUploadService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUploadServiceMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockUploadServiceMockRecorder is the mock recorder for MockUploadService.
|
||||
type MockUploadServiceMockRecorder struct {
|
||||
mock *MockUploadService
|
||||
}
|
||||
|
||||
// NewMockUploadService creates a new mock instance.
|
||||
func NewMockUploadService(ctrl *gomock.Controller) *MockUploadService {
|
||||
mock := &MockUploadService{ctrl: ctrl}
|
||||
mock.recorder = &MockUploadServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockUploadService) EXPECT() *MockUploadServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetFile mocks base method.
|
||||
func (m *MockUploadService) GetFile(ctx context.Context, req *service.GetFileRequest) (*service.GetFileResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFile", ctx, req)
|
||||
ret0, _ := ret[0].(*service.GetFileResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetFile indicates an expected call of GetFile.
|
||||
func (mr *MockUploadServiceMockRecorder) GetFile(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFile", reflect.TypeOf((*MockUploadService)(nil).GetFile), ctx, req)
|
||||
}
|
||||
|
||||
// GetFiles mocks base method.
|
||||
func (m *MockUploadService) GetFiles(ctx context.Context, req *service.GetFilesRequest) (*service.GetFilesResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFiles", ctx, req)
|
||||
ret0, _ := ret[0].(*service.GetFilesResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetFiles indicates an expected call of GetFiles.
|
||||
func (mr *MockUploadServiceMockRecorder) GetFiles(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFiles", reflect.TypeOf((*MockUploadService)(nil).GetFiles), ctx, req)
|
||||
}
|
||||
|
||||
// UploadFile mocks base method.
|
||||
func (m *MockUploadService) UploadFile(ctx context.Context, req *service.UploadFileRequest) (*service.UploadFileResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UploadFile", ctx, req)
|
||||
ret0, _ := ret[0].(*service.UploadFileResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UploadFile indicates an expected call of UploadFile.
|
||||
func (mr *MockUploadServiceMockRecorder) UploadFile(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFile", reflect.TypeOf((*MockUploadService)(nil).UploadFile), ctx, req)
|
||||
}
|
||||
|
||||
// UploadFiles mocks base method.
|
||||
func (m *MockUploadService) UploadFiles(ctx context.Context, req *service.UploadFilesRequest) (*service.UploadFilesResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UploadFiles", ctx, req)
|
||||
ret0, _ := ret[0].(*service.UploadFilesResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UploadFiles indicates an expected call of UploadFiles.
|
||||
func (mr *MockUploadServiceMockRecorder) UploadFiles(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadFiles", reflect.TypeOf((*MockUploadService)(nil).UploadFiles), ctx, req)
|
||||
}
|
||||
@ -22,6 +22,10 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
@ -37,6 +41,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
@ -55,6 +60,7 @@ func main() {
|
||||
panic("InitializeInfra failed, err=" + err.Error())
|
||||
}
|
||||
|
||||
asyncStartMinioProxyServer(ctx)
|
||||
startHttpServer()
|
||||
}
|
||||
|
||||
@ -154,3 +160,56 @@ func setCrashOutput() {
|
||||
crashFile, _ := os.Create("crash.log")
|
||||
debug.SetCrashOutput(crashFile, debug.CrashOptions{})
|
||||
}
|
||||
|
||||
// TODO: remove me later
|
||||
func asyncStartMinioProxyServer(ctx context.Context) {
|
||||
storageType := getEnv(consts.StorageType, "minio")
|
||||
proxyURL := getEnv(consts.MinIOAPIHost, "http://localhost:9000")
|
||||
|
||||
if storageType == "tos" {
|
||||
proxyURL = getEnv(consts.TOSBucketEndpoint, "https://opencoze.tos-cn-beijing.volces.com")
|
||||
}
|
||||
|
||||
if storageType == "s3" {
|
||||
proxyURL = getEnv(consts.S3BucketEndpoint, "")
|
||||
}
|
||||
|
||||
minioProxyEndpoint := getEnv(consts.MinIOProxyEndpoint, "")
|
||||
if len(minioProxyEndpoint) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
target, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
originDirector := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
q := req.URL.Query()
|
||||
q.Del("x-wf-file_name")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
originDirector(req)
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
useSSL := getEnv(consts.UseSSL, "0")
|
||||
if useSSL == "1" {
|
||||
logs.Infof("Minio proxy server is listening on %s with SSL", minioProxyEndpoint)
|
||||
err := http.ListenAndServeTLS(minioProxyEndpoint,
|
||||
getEnv(consts.SSLCertFile, ""),
|
||||
getEnv(consts.SSLKeyFile, ""), proxy)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
logs.Infof("Minio proxy server is listening on %s", minioProxyEndpoint)
|
||||
err := http.ListenAndServe(minioProxyEndpoint, proxy)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -20,7 +20,7 @@ services:
|
||||
- ./data/mysql:/var/lib/mysql
|
||||
- ./volumes/mysql/schema.sql:/docker-entrypoint-initdb.d/init.sql
|
||||
- ./atlas/opencoze_latest_schema.hcl:/opencoze_latest_schema.hcl:ro
|
||||
entrypoint:
|
||||
entrypoint:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
@ -43,7 +43,7 @@ services:
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
echo 'MySQL is ready, installing Atlas CLI...'
|
||||
|
||||
if ! command -v atlas >/dev/null 2>&1; then
|
||||
@ -53,7 +53,7 @@ services:
|
||||
else
|
||||
echo 'Atlas CLI already installed'
|
||||
fi
|
||||
|
||||
|
||||
if [ -f '/opencoze_latest_schema.hcl' ]; then
|
||||
echo 'Running Atlas migrations...'
|
||||
ATLAS_URL="mysql://$${MYSQL_USER}:$${MYSQL_PASSWORD}@localhost:3306/$${MYSQL_DATABASE}"
|
||||
@ -274,7 +274,7 @@ services:
|
||||
|
||||
# Download plugin package locally
|
||||
echo 'Copying smartcn plugin...';
|
||||
cp /opt/bitnami/elasticsearch/analysis-smartcn.zip /tmp/analysis-smartcn.zip
|
||||
cp /opt/bitnami/elasticsearch/analysis-smartcn.zip /tmp/analysis-smartcn.zip
|
||||
|
||||
elasticsearch-plugin install file:///tmp/analysis-smartcn.zip
|
||||
if [[ "$$?" != "0" ]]; then
|
||||
|
||||
@ -252,7 +252,6 @@ services:
|
||||
OB_DATAFILE_SIZE: 1G
|
||||
OB_SYS_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_TENANT_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-cozeAi}
|
||||
ports:
|
||||
- '2881:2881'
|
||||
volumes:
|
||||
|
||||
@ -345,7 +345,6 @@ services:
|
||||
OB_DATAFILE_SIZE: 1G
|
||||
OB_SYS_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_TENANT_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-cozeAi}
|
||||
profiles: ['middleware']
|
||||
env_file: *env_file
|
||||
ports:
|
||||
|
||||
@ -19,7 +19,7 @@ services:
|
||||
- ./data/mysql:/var/lib/mysql
|
||||
- ./volumes/mysql/schema.sql:/docker-entrypoint-initdb.d/init.sql
|
||||
- ./atlas/opencoze_latest_schema.hcl:/opencoze_latest_schema.hcl:ro
|
||||
entrypoint:
|
||||
entrypoint:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
@ -42,7 +42,7 @@ services:
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
echo 'MySQL is ready, installing Atlas CLI...'
|
||||
|
||||
if ! command -v atlas >/dev/null 2>&1; then
|
||||
@ -52,7 +52,7 @@ services:
|
||||
else
|
||||
echo 'Atlas CLI already installed'
|
||||
fi
|
||||
|
||||
|
||||
if [ -f '/opencoze_latest_schema.hcl' ]; then
|
||||
echo 'Running Atlas migrations...'
|
||||
ATLAS_URL="mysql://$${MYSQL_USER}:$${MYSQL_PASSWORD}@localhost:3306/$${MYSQL_DATABASE}"
|
||||
@ -161,7 +161,7 @@ services:
|
||||
|
||||
# Download plugin package locally
|
||||
echo 'Copying smartcn plugin...';
|
||||
cp /opt/bitnami/elasticsearch/analysis-smartcn.zip /tmp/analysis-smartcn.zip
|
||||
cp /opt/bitnami/elasticsearch/analysis-smartcn.zip /tmp/analysis-smartcn.zip
|
||||
|
||||
elasticsearch-plugin install file:///tmp/analysis-smartcn.zip
|
||||
if [[ "$$?" != "0" ]]; then
|
||||
|
||||
|
Before Width: | Height: | Size: 37 KiB |
|
Before Width: | Height: | Size: 28 KiB |
|
Before Width: | Height: | Size: 24 KiB |
|
Before Width: | Height: | Size: 26 KiB |
|
Before Width: | Height: | Size: 12 KiB |
|
Before Width: | Height: | Size: 37 KiB |
|
Before Width: | Height: | Size: 22 KiB |
|
Before Width: | Height: | Size: 22 KiB |
|
Before Width: | Height: | Size: 27 KiB |
|
Before Width: | Height: | Size: 30 KiB |
|
Before Width: | Height: | Size: 28 KiB |
|
Before Width: | Height: | Size: 35 KiB |
|
Before Width: | Height: | Size: 25 KiB |
|
Before Width: | Height: | Size: 25 KiB |
|
Before Width: | Height: | Size: 25 KiB |
|
Before Width: | Height: | Size: 20 KiB |
|
Before Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 25 KiB |