Compare commits
2 Commits
fix/coderu
...
feat/sync_
| Author | SHA1 | Date | |
|---|---|---|---|
| b0c6dd2046 | |||
| b05004f188 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -60,4 +60,3 @@ values-dev.yaml
|
||||
|
||||
*.tsbuildinfo
|
||||
|
||||
.coda/
|
||||
|
||||
@ -1459,6 +1459,37 @@ func TestTestResumeWithInputNode(t *testing.T) {
|
||||
resp := post[workflow.OpenAPIRunFlowResponse](r, syncRunReq)
|
||||
assert.Equal(t, int64(errno.ErrOpenAPIInterruptNotSupported), resp.Code)
|
||||
})
|
||||
|
||||
mockey.PatchConvey("test run, then sync resume", func() {
|
||||
ctx := t.Context()
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input": "unused initial input",
|
||||
})
|
||||
e := r.getProcess(id, exeID)
|
||||
assert.NotNil(t, e.event) // interrupted
|
||||
|
||||
exeInt64ID, _ := strconv.ParseInt(exeID, 10, 64)
|
||||
eventInt64ID, _ := strconv.ParseInt(e.event.ID, 10, 64)
|
||||
|
||||
result, _, err := appworkflow.GetWorkflowDomainSVC().SyncResume(ctx, &entity.ResumeRequest{
|
||||
ExecuteID: exeInt64ID,
|
||||
EventID: eventInt64ID,
|
||||
ResumeData: userInputStr,
|
||||
}, workflowModel.ExecuteConfig{
|
||||
Operator: 123,
|
||||
Mode: workflowModel.ExecuteModeDebug,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"input": "user input",
|
||||
"inputArr": nil,
|
||||
"field1": `["1","2"]`,
|
||||
}, mustUnmarshalToMap(t, *result.Output))
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@ -1823,64 +1854,173 @@ func TestInterruptWithinBatch(t *testing.T) {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
id := r.load("batch/batch_with_inner_interrupt.json")
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input_array": `["a","b"]`,
|
||||
"batch_concurrency": "2",
|
||||
mockey.PatchConvey("test run with async resume", func() {
|
||||
id := r.load("batch/batch_with_inner_interrupt.json")
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input_array": `["a","b"]`,
|
||||
"batch_concurrency": "2",
|
||||
})
|
||||
|
||||
e := r.getProcess(id, exeID)
|
||||
assert.Equal(t, workflow.EventType_InputNode, e.event.Type)
|
||||
|
||||
exeIDInt, _ := strconv.ParseInt(exeID, 0, 64)
|
||||
storeIEs, _ := workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 2, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e.event.ID, map[string]any{
|
||||
"input": "input 1",
|
||||
})
|
||||
|
||||
e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID))
|
||||
assert.Equal(t, workflow.EventType_InputNode, e2.event.Type)
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 2, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e2.event.ID, map[string]any{
|
||||
"input": "input 2",
|
||||
})
|
||||
|
||||
e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID))
|
||||
assert.Equal(t, workflow.EventType_Question, e3.event.Type)
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 2, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e3.event.ID, "answer 1")
|
||||
|
||||
e4 := r.getProcess(id, exeID, withPreviousEventID(e3.event.ID))
|
||||
assert.Equal(t, workflow.EventType_Question, e4.event.Type)
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 1, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e4.event.ID, "answer 2")
|
||||
|
||||
e5 := r.getProcess(id, exeID, withPreviousEventID(e4.event.ID))
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 0, len(storeIEs))
|
||||
e5.assertSuccess()
|
||||
|
||||
outputMap := mustUnmarshalToMap(t, e5.output)
|
||||
|
||||
if !reflect.DeepEqual(outputMap, map[string]any{
|
||||
"output": []any{"answer 1", "answer 2"},
|
||||
}) && !reflect.DeepEqual(outputMap, map[string]any{
|
||||
"output": []any{"answer 2", "answer 1"},
|
||||
}) {
|
||||
t.Errorf("output map not equal: %v", outputMap)
|
||||
}
|
||||
})
|
||||
mockey.PatchConvey("test run with sync resume", func() {
|
||||
id := r.load("batch/batch_with_inner_interrupt_for_debug_run.json")
|
||||
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input_array": `["a","b"]`,
|
||||
"batch_concurrency": "2",
|
||||
})
|
||||
|
||||
e := r.getProcess(id, exeID)
|
||||
assert.Equal(t, workflow.EventType_InputNode, e.event.Type)
|
||||
|
||||
data := map[string]any{
|
||||
"input": "123",
|
||||
}
|
||||
bs, _ := sonic.Marshal(data)
|
||||
|
||||
exeInt64ID, _ := strconv.ParseInt(exeID, 10, 64)
|
||||
eventInt64ID, _ := strconv.ParseInt(e.event.ID, 10, 64)
|
||||
|
||||
result, _, err := appworkflow.GetWorkflowDomainSVC().SyncResume(t.Context(), &entity.ResumeRequest{
|
||||
ExecuteID: exeInt64ID,
|
||||
EventID: eventInt64ID,
|
||||
ResumeData: string(bs),
|
||||
}, workflowModel.ExecuteConfig{
|
||||
Operator: 123,
|
||||
Mode: workflowModel.ExecuteModeDebug,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result.Status, entity.WorkflowInterrupted)
|
||||
assert.NotNil(t, result.InterruptEvents)
|
||||
data = map[string]any{
|
||||
"input": "456",
|
||||
}
|
||||
bs, _ = sonic.Marshal(data)
|
||||
|
||||
result, _, err = appworkflow.GetWorkflowDomainSVC().SyncResume(t.Context(), &entity.ResumeRequest{
|
||||
ExecuteID: exeInt64ID,
|
||||
EventID: result.InterruptEvents[0].ID,
|
||||
ResumeData: string(bs),
|
||||
}, workflowModel.ExecuteConfig{
|
||||
Operator: 123,
|
||||
Mode: workflowModel.ExecuteModeDebug,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": []any{"123", "456"},
|
||||
}, mustUnmarshalToMap(t, *result.Output))
|
||||
|
||||
})
|
||||
mockey.PatchConvey("node debug run with sync resume", func() {
|
||||
id := r.load("batch/batch_with_inner_interrupt_for_debug_run.json")
|
||||
|
||||
exeID := r.nodeDebug(id, "105709", withNDBatch(map[string]string{}))
|
||||
|
||||
e := r.getProcess(id, exeID)
|
||||
assert.Equal(t, workflow.EventType_InputNode, e.event.Type)
|
||||
|
||||
data := map[string]any{
|
||||
"input": "123",
|
||||
}
|
||||
bs, _ := sonic.Marshal(data)
|
||||
|
||||
exeInt64ID, _ := strconv.ParseInt(exeID, 10, 64)
|
||||
eventInt64ID, _ := strconv.ParseInt(e.event.ID, 10, 64)
|
||||
|
||||
result, _, err := appworkflow.GetWorkflowDomainSVC().SyncResume(t.Context(), &entity.ResumeRequest{
|
||||
ExecuteID: exeInt64ID,
|
||||
EventID: eventInt64ID,
|
||||
ResumeData: string(bs),
|
||||
}, workflowModel.ExecuteConfig{
|
||||
Operator: 123,
|
||||
Mode: workflowModel.ExecuteModeNodeDebug,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result.Status, entity.WorkflowInterrupted)
|
||||
assert.NotNil(t, result.InterruptEvents)
|
||||
data = map[string]any{
|
||||
"input": "456",
|
||||
}
|
||||
bs, _ = sonic.Marshal(data)
|
||||
|
||||
result, _, err = appworkflow.GetWorkflowDomainSVC().SyncResume(t.Context(), &entity.ResumeRequest{
|
||||
ExecuteID: exeInt64ID,
|
||||
EventID: result.InterruptEvents[0].ID,
|
||||
ResumeData: string(bs),
|
||||
}, workflowModel.ExecuteConfig{
|
||||
Operator: 123,
|
||||
Mode: workflowModel.ExecuteModeNodeDebug,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": []any{"123", "456"},
|
||||
}, mustUnmarshalToMap(t, *result.Output))
|
||||
fmt.Println(result, err)
|
||||
|
||||
})
|
||||
|
||||
e := r.getProcess(id, exeID)
|
||||
assert.Equal(t, workflow.EventType_InputNode, e.event.Type)
|
||||
|
||||
exeIDInt, _ := strconv.ParseInt(exeID, 0, 64)
|
||||
storeIEs, _ := workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 2, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e.event.ID, map[string]any{
|
||||
"input": "input 1",
|
||||
})
|
||||
|
||||
e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID))
|
||||
assert.Equal(t, workflow.EventType_InputNode, e2.event.Type)
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 2, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e2.event.ID, map[string]any{
|
||||
"input": "input 2",
|
||||
})
|
||||
|
||||
e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID))
|
||||
assert.Equal(t, workflow.EventType_Question, e3.event.Type)
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 2, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e3.event.ID, "answer 1")
|
||||
|
||||
e4 := r.getProcess(id, exeID, withPreviousEventID(e3.event.ID))
|
||||
assert.Equal(t, workflow.EventType_Question, e4.event.Type)
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 1, len(storeIEs))
|
||||
|
||||
r.testResume(id, exeID, e4.event.ID, "answer 2")
|
||||
|
||||
e5 := r.getProcess(id, exeID, withPreviousEventID(e4.event.ID))
|
||||
|
||||
storeIEs, _ = workflow2.GetRepository().ListInterruptEvents(t.Context(), exeIDInt)
|
||||
assert.Equal(t, 0, len(storeIEs))
|
||||
e5.assertSuccess()
|
||||
|
||||
outputMap := mustUnmarshalToMap(t, e5.output)
|
||||
|
||||
if !reflect.DeepEqual(outputMap, map[string]any{
|
||||
"output": []any{"answer 1", "answer 2"},
|
||||
}) && !reflect.DeepEqual(outputMap, map[string]any{
|
||||
"output": []any{"answer 2", "answer 1"},
|
||||
}) {
|
||||
t.Errorf("output map not equal: %v", outputMap)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -6292,3 +6432,5 @@ func TestChatFlowRun(t *testing.T) {
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@ -19,6 +19,7 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -84,7 +85,7 @@ func AccessLogMW() app.HandlerFunc {
|
||||
func SetLogIDMW() app.HandlerFunc {
|
||||
return func(ctx context.Context, c *app.RequestContext) {
|
||||
logID := uuid.New().String()
|
||||
ctx = context.WithValue(ctx, "log-id", logID)
|
||||
ctx = context.WithValue(ctx, consts.CtxLogIDKey, logID)
|
||||
|
||||
c.Header("X-Log-ID", logID)
|
||||
c.Next(ctx)
|
||||
|
||||
@ -24,6 +24,9 @@ type ExecuteToolOption struct {
|
||||
ToolVersion string
|
||||
Operation *Openapi3Operation
|
||||
InvalidRespProcessStrategy InvalidResponseProcessStrategy
|
||||
|
||||
AgentID int64
|
||||
ConversationID int64
|
||||
}
|
||||
|
||||
type ExecuteToolOpt func(o *ExecuteToolOption)
|
||||
@ -65,3 +68,10 @@ func WithAutoGenRespSchema() ExecuteToolOpt {
|
||||
o.AutoGenRespSchema = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithPluginHTTPHeader(agentID, conversationID int64) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.AgentID = agentID
|
||||
o.ConversationID = conversationID
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,6 +112,8 @@ type ExecuteRequest struct {
|
||||
History []*schema.Message
|
||||
ResumeInfo *InterruptInfo
|
||||
PreCallTools []*agentrun.ToolsRetriever
|
||||
|
||||
ConversationID int64
|
||||
}
|
||||
|
||||
type AgentIdentity struct {
|
||||
|
||||
@ -51,6 +51,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
|
||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
|
||||
@ -345,43 +346,40 @@ func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
|
||||
}
|
||||
|
||||
func initCodeRunner() coderunner.Runner {
|
||||
// 为了安全考虑,移除不安全的direct runner,强制使用sandbox
|
||||
getAndSplit := func(key string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return nil
|
||||
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
||||
case "sandbox":
|
||||
getAndSplit := func(key string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(v, ",")
|
||||
}
|
||||
return strings.Split(v, ",")
|
||||
config := &sandbox.Config{
|
||||
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv),
|
||||
AllowRead: getAndSplit(consts.CodeRunnerAllowRead),
|
||||
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite),
|
||||
AllowNet: getAndSplit(consts.CodeRunnerAllowNet),
|
||||
AllowRun: getAndSplit(consts.CodeRunnerAllowRun),
|
||||
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI),
|
||||
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
||||
TimeoutSeconds: 0,
|
||||
MemoryLimitMB: 0,
|
||||
}
|
||||
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil {
|
||||
config.TimeoutSeconds = f
|
||||
} else {
|
||||
config.TimeoutSeconds = 60.0
|
||||
}
|
||||
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil {
|
||||
config.MemoryLimitMB = mem
|
||||
} else {
|
||||
config.MemoryLimitMB = 100
|
||||
}
|
||||
return sandbox.NewRunner(config)
|
||||
default:
|
||||
return direct.NewRunner()
|
||||
}
|
||||
|
||||
// 使用安全的默认配置
|
||||
config := &sandbox.Config{
|
||||
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), // 默认为空,禁止环境变量访问
|
||||
AllowRead: getAndSplit(consts.CodeRunnerAllowRead), // 默认为空,禁止文件读取
|
||||
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), // 默认为空,禁止文件写入
|
||||
AllowNet: getAndSplit(consts.CodeRunnerAllowNet), // 默认为空,禁止网络访问
|
||||
AllowRun: getAndSplit(consts.CodeRunnerAllowRun), // 默认为空,禁止运行外部程序
|
||||
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), // 默认为空,禁止FFI调用
|
||||
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
||||
TimeoutSeconds: 0,
|
||||
MemoryLimitMB: 0,
|
||||
}
|
||||
|
||||
// 设置安全的超时时间,最大30秒
|
||||
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil && f > 0 && f <= 30 {
|
||||
config.TimeoutSeconds = f
|
||||
} else {
|
||||
config.TimeoutSeconds = 30.0 // 默认30秒超时
|
||||
}
|
||||
|
||||
// 设置安全的内存限制,最大100MB
|
||||
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil && mem > 0 && mem <= 100 {
|
||||
config.MemoryLimitMB = mem
|
||||
} else {
|
||||
config.MemoryLimitMB = 100 // 默认100MB内存限制
|
||||
}
|
||||
|
||||
return sandbox.NewRunner(config)
|
||||
}
|
||||
|
||||
func initOCR() ocr.OCR {
|
||||
@ -800,4 +798,4 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,6 +37,7 @@ type AgentRuntime struct {
|
||||
AgentVersion string
|
||||
UserID string
|
||||
AgentID int64
|
||||
ConversationId int64
|
||||
IsDraft bool
|
||||
SpaceID int64
|
||||
ConnectorID int64
|
||||
|
||||
@ -80,6 +80,8 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, agentRuntim
|
||||
}
|
||||
}),
|
||||
ResumeInfo: agentRuntime.ResumeInfo,
|
||||
|
||||
ConversationID: agentRuntime.ConversationId,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -42,6 +42,8 @@ type Config struct {
|
||||
ModelMgr modelmgr.Manager
|
||||
ModelFactory chatmodel.Factory
|
||||
CPStore compose.CheckPointStore
|
||||
|
||||
ConversationID int64
|
||||
}
|
||||
|
||||
const (
|
||||
@ -108,6 +110,8 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||
userID: conf.UserID,
|
||||
agentIdentity: conf.Identity,
|
||||
toolConf: conf.Agent.Plugin,
|
||||
|
||||
conversationID: conf.ConversationID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -38,6 +38,8 @@ type toolConfig struct {
|
||||
userID string
|
||||
agentIdentity *entity.AgentIdentity
|
||||
toolConf []*bot_common.PluginInfo
|
||||
|
||||
conversationID int64
|
||||
}
|
||||
|
||||
func newPluginTools(ctx context.Context, conf *toolConfig) ([]tool.InvokableTool, error) {
|
||||
@ -71,6 +73,9 @@ func newPluginTools(ctx context.Context, conf *toolConfig) ([]tool.InvokableTool
|
||||
isDraft: conf.agentIdentity.IsDraft,
|
||||
projectInfo: projectInfo,
|
||||
toolInfo: ti,
|
||||
|
||||
agentID: conf.agentIdentity.AgentID,
|
||||
conversationID: conf.conversationID,
|
||||
})
|
||||
}
|
||||
|
||||
@ -82,6 +87,9 @@ type pluginInvokableTool struct {
|
||||
isDraft bool
|
||||
toolInfo *pluginEntity.ToolInfo
|
||||
projectInfo *plugin.ProjectInfo
|
||||
|
||||
agentID int64
|
||||
conversationID int64
|
||||
}
|
||||
|
||||
func (p *pluginInvokableTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
@ -124,6 +132,7 @@ func (p *pluginInvokableTool) InvokableRun(ctx context.Context, argumentsInJSON
|
||||
plugin.WithInvalidRespProcessStrategy(plugin.InvalidResponseProcessStrategyOfReturnDefault),
|
||||
plugin.WithToolVersion(p.toolInfo.GetVersion()),
|
||||
plugin.WithProjectInfo(p.projectInfo),
|
||||
plugin.WithPluginHTTPHeader(p.agentID, p.conversationID),
|
||||
}
|
||||
|
||||
resp, err := crossplugin.DefaultSVC().ExecuteTool(ctx, req, opts...)
|
||||
|
||||
@ -111,6 +111,8 @@ func (s *singleAgentImpl) StreamExecute(ctx context.Context, req *entity.Execute
|
||||
ModelMgr: s.ModelMgr,
|
||||
ModelFactory: s.ModelFactory,
|
||||
CPStore: s.CPStore,
|
||||
|
||||
ConversationID: req.ConversationID,
|
||||
}
|
||||
rn, err := agentflow.BuildAgent(ctx, conf)
|
||||
if err != nil {
|
||||
|
||||
@ -49,6 +49,7 @@ func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.I
|
||||
AgentID: art.GetRunMeta().AgentID,
|
||||
IsDraft: art.GetRunMeta().IsDraft,
|
||||
UserID: art.GetRunMeta().UserID,
|
||||
ConversationId: art.GetRunMeta().ConversationID,
|
||||
ConnectorID: art.GetRunMeta().ConnectorID,
|
||||
PreRetrieveTools: art.GetRunMeta().PreRetrieveTools,
|
||||
Input: transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0],
|
||||
|
||||
@ -21,6 +21,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -123,6 +125,8 @@ func (p *pluginServiceImpl) buildToolExecutor(ctx context.Context, req *ExecuteT
|
||||
impl = &toolExecutor{
|
||||
execScene: req.ExecScene,
|
||||
userID: req.UserID,
|
||||
agentID: execOpt.AgentID,
|
||||
conversationID: execOpt.ConversationID,
|
||||
plugin: pl,
|
||||
tool: tl,
|
||||
projectInfo: execOpt.ProjectInfo,
|
||||
@ -457,10 +461,13 @@ type ExecuteResponse struct {
|
||||
}
|
||||
|
||||
type toolExecutor struct {
|
||||
execScene model.ExecuteScene
|
||||
userID string
|
||||
plugin *entity.PluginInfo
|
||||
tool *entity.ToolInfo
|
||||
execScene model.ExecuteScene
|
||||
userID string
|
||||
agentID int64
|
||||
conversationID int64
|
||||
|
||||
plugin *entity.PluginInfo
|
||||
tool *entity.ToolInfo
|
||||
|
||||
projectInfo *entity.ProjectInfo
|
||||
invalidRespProcessStrategy model.InvalidResponseProcessStrategy
|
||||
@ -738,6 +745,12 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logId, _ := ctx.Value(consts.CtxLogIDKey).(string)
|
||||
header.Set("X-Tt-Logid", logId)
|
||||
header.Set("X-Aiplugin-Connector-Identifier", t.userID)
|
||||
header.Set("X-AIPlugin-Bot-ID", conv.Int64ToStr(t.agentID))
|
||||
header.Set("X-AIPlugin-Conversation-ID", conv.Int64ToStr(t.conversationID))
|
||||
|
||||
httpReq.Header = header
|
||||
|
||||
if len(bodyBytes) > 0 {
|
||||
|
||||
@ -35,6 +35,7 @@ type Executable interface {
|
||||
AsyncExecute(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (int64, error)
|
||||
AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error)
|
||||
AsyncResume(ctx context.Context, req *entity.ResumeRequest, config workflowModel.ExecuteConfig) error
|
||||
SyncResume(ctx context.Context, req *entity.ResumeRequest, config workflowModel.ExecuteConfig) (*entity.WorkflowExecution, vo.TerminatePlan, error)
|
||||
StreamExecute(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*schema.StreamReader[*entity.Message], error)
|
||||
StreamResume(ctx context.Context, req *entity.ResumeRequest, config workflowModel.ExecuteConfig) (
|
||||
*schema.StreamReader[*entity.Message], error)
|
||||
|
||||
@ -0,0 +1,209 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 180,
|
||||
"y": 13.700000000000003
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "http://10.37.46.247:9000/opencoze/default_icon/workflow_icon/icon-start.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250917%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250917T100753Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=cc901caf2ac7105e181a549398596a3d77a45b1735d75082da7924bcd6ee4a56",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1300,
|
||||
"y": 0.7000000000000028
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "http://10.37.46.247:9000/opencoze/default_icon/workflow_icon/icon-end.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250917%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250917T100753Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=61259d932108c153d5470e7fe21984f0961007508b01b1e884250744ede25dde",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "returnVariables",
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "list",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "105709",
|
||||
"name": "output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 99
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "105709",
|
||||
"type": "28",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 740,
|
||||
"y": 0
|
||||
},
|
||||
"canvasPosition": {
|
||||
"x": 560,
|
||||
"y": 293.4
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"title": "批处理",
|
||||
"icon": "http://10.37.46.247:9000/opencoze/default_icon/workflow_icon/icon-batch.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250917%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250917T100753Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=ea39a384a6d4ddb2dfe873ab7277b996747e42e129da23367c62068bf8980068",
|
||||
"description": "通过设定批量运行次数和逻辑,运行批处理体内的任务",
|
||||
"mainColor": "#00B2B2",
|
||||
"subTitle": "批处理"
|
||||
},
|
||||
"inputs": {
|
||||
"concurrentSize": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": "1"
|
||||
}
|
||||
},
|
||||
"batchSize": {
|
||||
"type": "integer",
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": "100"
|
||||
}
|
||||
},
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "input",
|
||||
"input": {
|
||||
"type": "list",
|
||||
"value": {
|
||||
"type": "literal",
|
||||
"content": "[\"12\",\"2\"]",
|
||||
"rawMeta": {
|
||||
"type": 99
|
||||
}
|
||||
},
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "list",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "136577",
|
||||
"name": "input"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"blocks": [
|
||||
{
|
||||
"id": "136577",
|
||||
"type": "30",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 180,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"nodeMeta": {
|
||||
"title": "输入",
|
||||
"icon": "http://10.37.46.247:9000/opencoze/default_icon/workflow_icon/icon-input.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=minioadmin%2F20250917%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250917T100753Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=d0af4e55a091037e82c025f29e925fbdb0592ca1319d03948a1833b5559ab2c4",
|
||||
"description": "支持中间过程的信息输入",
|
||||
"mainColor": "#5C62FF",
|
||||
"subTitle": "输入"
|
||||
},
|
||||
"inputs": {
|
||||
"outputSchema": "[{\"type\":\"string\",\"name\":\"input\",\"required\":true}]"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "105709",
|
||||
"targetNodeID": "136577",
|
||||
"sourcePortID": "batch-function-inline-output"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "136577",
|
||||
"targetNodeID": "105709",
|
||||
"targetPortID": "batch-function-inline-input"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "105709"
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "105709",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": "batch-output"
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -19,6 +19,7 @@ package compose
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -257,7 +258,7 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
|
||||
|
||||
if interruptEvent == nil {
|
||||
var logID string
|
||||
logID, _ = ctx.Value("log-id").(string)
|
||||
logID, _ = ctx.Value(consts.CtxLogIDKey).(string)
|
||||
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: executeID,
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
@ -113,7 +114,7 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
|
||||
|
||||
if parentNodeID != nil { // root workflow execution has already been created
|
||||
var logID string
|
||||
logID, _ = ctx.Value("log-id").(string)
|
||||
logID, _ = ctx.Value(consts.CtxLogIDKey).(string)
|
||||
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
@ -314,44 +315,42 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
|
||||
}
|
||||
|
||||
// TODO: there maybe time gap here
|
||||
|
||||
if err := repo.SaveInterruptEvents(ctx, event.RootExecuteID, event.InterruptEvents); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save interrupt events: %v", err)
|
||||
}
|
||||
|
||||
if sw != nil && event.SubWorkflowCtx == nil { // only send interrupt event when is root workflow
|
||||
if event.SubWorkflowCtx == nil {
|
||||
firstIE, found, err := repo.GetFirstInterruptEvent(ctx, event.RootExecuteID)
|
||||
if err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to get first interrupt event: %v", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return noTerminate, fmt.Errorf("interrupt event does not exist, wfExeID: %d", event.RootExecuteID)
|
||||
}
|
||||
|
||||
nodeKey := firstIE.NodeKey
|
||||
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Assistant,
|
||||
Type: entity.Answer,
|
||||
Content: firstIE.InterruptData, // TODO: may need to extract from InterruptData the actual info for user
|
||||
NodeID: string(nodeKey),
|
||||
NodeType: firstIE.NodeType,
|
||||
NodeTitle: firstIE.NodeTitle,
|
||||
Last: true,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
Status: entity.WorkflowInterrupted,
|
||||
InterruptEvent: firstIE,
|
||||
},
|
||||
}, nil)
|
||||
event.InterruptEvents = []*entity.InterruptEvent{firstIE}
|
||||
if sw != nil { // only send interrupt event when is root workflow
|
||||
nodeKey := firstIE.NodeKey
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Assistant,
|
||||
Type: entity.Answer,
|
||||
Content: firstIE.InterruptData, // TODO: may need to extract from InterruptData the actual info for user
|
||||
NodeID: string(nodeKey),
|
||||
NodeType: firstIE.NodeType,
|
||||
NodeTitle: firstIE.NodeTitle,
|
||||
Last: true,
|
||||
},
|
||||
}, nil)
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
Status: entity.WorkflowInterrupted,
|
||||
InterruptEvent: firstIE,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return workflowAbort, nil
|
||||
|
||||
@ -37,6 +37,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
wfschema "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
@ -906,6 +907,205 @@ func (i *impl) AsyncResume(ctx context.Context, req *entity.ResumeRequest, confi
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *impl) SyncResume(ctx context.Context, req *entity.ResumeRequest, config workflowModel.ExecuteConfig) (*entity.WorkflowExecution, vo.TerminatePlan, error) {
|
||||
var err error
|
||||
wfExe, found, err := i.repo.GetWorkflowExecution(ctx, req.ExecuteID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, "", fmt.Errorf("workflow execution does not exist, id: %d", req.ExecuteID)
|
||||
}
|
||||
|
||||
if wfExe.RootExecutionID != wfExe.ID {
|
||||
return nil, "", fmt.Errorf("only root workflow can be resumed")
|
||||
}
|
||||
|
||||
if wfExe.Status != entity.WorkflowInterrupted {
|
||||
return nil, "", fmt.Errorf("workflow execution %d is not interrupted, status is %v, cannot resume", req.ExecuteID, wfExe.Status)
|
||||
}
|
||||
|
||||
var from workflowModel.Locator
|
||||
if wfExe.Version == "" {
|
||||
from = workflowModel.FromDraft
|
||||
} else {
|
||||
from = workflowModel.FromSpecificVersion
|
||||
}
|
||||
|
||||
wfEntity, err := i.Get(ctx, &vo.GetPolicy{
|
||||
ID: wfExe.WorkflowID,
|
||||
QType: from,
|
||||
Version: wfExe.Version,
|
||||
CommitID: wfExe.CommitID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var canvas vo.Canvas
|
||||
err = sonic.UnmarshalString(wfEntity.Canvas, &canvas)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
config.From = from
|
||||
config.Version = wfExe.Version
|
||||
config.AppID = wfExe.AppID
|
||||
config.AgentID = wfExe.AgentID
|
||||
config.CommitID = wfExe.CommitID
|
||||
config.WorkflowMode = wfEntity.Mode
|
||||
|
||||
if config.ConnectorID == 0 {
|
||||
config.ConnectorID = wfExe.ConnectorID
|
||||
}
|
||||
|
||||
var (
|
||||
lastEventChan <-chan *execute.Event
|
||||
startTime time.Time
|
||||
out map[string]any
|
||||
wf *compose.Workflow
|
||||
cancelCtx context.Context
|
||||
opts []einoCompose.Option
|
||||
nodeCount int32
|
||||
workflowSC *wfschema.WorkflowSchema
|
||||
)
|
||||
if wfExe.Mode == workflowModel.ExecuteModeNodeDebug {
|
||||
var nodeExes []*entity.NodeExecution
|
||||
nodeExes, err = i.repo.GetNodeExecutionsByWfExeID(ctx, wfExe.ID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if len(nodeExes) == 0 {
|
||||
return nil, "", fmt.Errorf("during node debug resume, no node execution found for workflow execution %d", wfExe.ID)
|
||||
}
|
||||
|
||||
var nodeID string
|
||||
for _, ne := range nodeExes {
|
||||
if ne.ParentNodeID == nil {
|
||||
nodeID = ne.NodeID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
workflowSC, err = adaptor.WorkflowSchemaFromNode(ctx, &canvas, nodeID)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
nodeCount = workflowSC.NodeCount()
|
||||
wf, err = compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID),
|
||||
einoCompose.WithGraphName(fmt.Sprintf("%d", wfExe.WorkflowID)))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create workflow: %w", err)
|
||||
}
|
||||
|
||||
config.Mode = workflowModel.ExecuteModeNodeDebug
|
||||
|
||||
cancelCtx, _, opts, lastEventChan, err = compose.NewWorkflowRunner(
|
||||
wfEntity.GetBasic(), workflowSC, config, compose.WithResumeReq(req)).Prepare(ctx)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
startTime = time.Now()
|
||||
out, err = wf.SyncRun(cancelCtx, nil, opts...)
|
||||
|
||||
} else {
|
||||
workflowSC, err = adaptor.CanvasToWorkflowSchema(ctx, &canvas)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
nodeCount = workflowSC.NodeCount()
|
||||
var wfOpts []compose.WorkflowOption
|
||||
wfOpts = append(wfOpts, compose.WithIDAsName(wfExe.WorkflowID))
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
wfOpts = append(wfOpts, compose.WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
|
||||
}
|
||||
|
||||
wf, err = compose.NewWorkflow(ctx, workflowSC, wfOpts...)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create workflow: %w", err)
|
||||
}
|
||||
|
||||
cancelCtx, _, opts, lastEventChan, err = compose.NewWorkflowRunner(
|
||||
wfEntity.GetBasic(), workflowSC, config, compose.WithResumeReq(req)).Prepare(ctx)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
startTime = time.Now()
|
||||
|
||||
out, err = wf.SyncRun(cancelCtx, nil, opts...)
|
||||
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if _, ok := einoCompose.ExtractInterruptInfo(err); !ok {
|
||||
var wfe vo.WorkflowError
|
||||
if errors.As(err, &wfe) {
|
||||
return nil, "", wfe.AppendDebug(req.ExecuteID, wfEntity.SpaceID, wfEntity.ID)
|
||||
} else {
|
||||
return nil, "", vo.WrapWithDebug(errno.ErrWorkflowExecuteFail, err, req.ExecuteID, wfEntity.SpaceID, wfEntity.ID, errorx.KV("cause", err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lastEvent := <-lastEventChan
|
||||
updateTime := time.Now()
|
||||
|
||||
var outStr string
|
||||
if wf.TerminatePlan() == vo.ReturnVariables {
|
||||
outStr, err = sonic.MarshalString(out)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
} else {
|
||||
outStr = out["output"].(string)
|
||||
}
|
||||
|
||||
var status entity.WorkflowExecuteStatus
|
||||
switch lastEvent.Type {
|
||||
case execute.WorkflowSuccess:
|
||||
status = entity.WorkflowSuccess
|
||||
case execute.WorkflowInterrupt:
|
||||
status = entity.WorkflowInterrupted
|
||||
case execute.WorkflowFailed:
|
||||
status = entity.WorkflowFailed
|
||||
case execute.WorkflowCancel:
|
||||
status = entity.WorkflowCancel
|
||||
}
|
||||
|
||||
var failReason *string
|
||||
if lastEvent.Err != nil {
|
||||
failReason = ptr.Of(lastEvent.Err.Error())
|
||||
}
|
||||
|
||||
return &entity.WorkflowExecution{
|
||||
ID: req.ExecuteID,
|
||||
WorkflowID: wfEntity.ID,
|
||||
Version: wfEntity.GetVersion(),
|
||||
SpaceID: wfEntity.SpaceID,
|
||||
ExecuteConfig: config,
|
||||
CreatedAt: startTime,
|
||||
NodeCount: nodeCount,
|
||||
Status: status,
|
||||
Duration: lastEvent.Duration,
|
||||
Input: ptr.Of(req.ResumeData),
|
||||
Output: ptr.Of(outStr),
|
||||
ErrorCode: ptr.Of("-1"),
|
||||
FailReason: failReason,
|
||||
TokenInfo: &entity.TokenUsage{
|
||||
InputTokens: lastEvent.GetInputTokens(),
|
||||
OutputTokens: lastEvent.GetOutputTokens(),
|
||||
},
|
||||
UpdatedAt: ptr.Of(updateTime),
|
||||
RootExecutionID: req.ExecuteID,
|
||||
InterruptEvents: lastEvent.InterruptEvents,
|
||||
}, wf.TerminatePlan(), nil
|
||||
|
||||
}
|
||||
|
||||
// StreamResume resumes a workflow execution, using the passed in executionID and eventID.
|
||||
// Intermediate results during the resuming run are emitted using the returned StreamReader.
|
||||
// Caller is expected to poll the execution status using the GetExecution method.
|
||||
|
||||
@ -310,7 +310,6 @@ require (
|
||||
github.com/mtibben/percent v0.2.1 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
|
||||
golang.org/x/term v0.32.0 // indirect
|
||||
|
||||
@ -21,7 +21,6 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -48,7 +47,6 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
node := mdParser.Parse(text.NewReader(b))
|
||||
cs := config.ChunkingStrategy
|
||||
ps := config.ParsingStrategy
|
||||
@ -103,118 +101,13 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
||||
return text
|
||||
}
|
||||
|
||||
// validateImageURL 验证图片URL的安全性
|
||||
validateImageURL := func(urlString string) error {
|
||||
parsedURL, err := url.Parse(urlString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 只允许HTTP/HTTPS
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
// 检查域名白名单
|
||||
allowedDomains := []string{
|
||||
"images.unsplash.com",
|
||||
"cdn.example.com",
|
||||
"github.com",
|
||||
"githubusercontent.com",
|
||||
// 可以根据需要添加其他受信任的域名
|
||||
}
|
||||
|
||||
hostname := parsedURL.Hostname()
|
||||
for _, domain := range allowedDomains {
|
||||
if hostname == domain || strings.HasSuffix(hostname, "."+domain) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("domain not allowed: %s", hostname)
|
||||
}
|
||||
|
||||
// isPrivateIPAddress 检查IP地址是否为私有地址
|
||||
isPrivateIPAddress := func(ip net.IP) bool {
|
||||
// 检查私有IP范围
|
||||
privateRanges := []struct {
|
||||
cidr string
|
||||
}{
|
||||
{"10.0.0.0/8"},
|
||||
{"172.16.0.0/12"},
|
||||
{"192.168.0.0/16"},
|
||||
{"127.0.0.0/8"},
|
||||
{"169.254.0.0/16"}, // 链路本地地址
|
||||
{"::1/128"}, // IPv6 loopback
|
||||
{"fc00::/7"}, // IPv6 私有地址
|
||||
}
|
||||
|
||||
for _, r := range privateRanges {
|
||||
_, cidr, _ := net.ParseCIDR(r.cidr)
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isPrivateIP 检查是否为私有IP地址
|
||||
isPrivateIP := func(host string) bool {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
// 可能是域名,需要解析
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return true // 解析失败,拒绝访问
|
||||
}
|
||||
|
||||
// 检查所有解析的IP
|
||||
for _, resolvedIP := range ips {
|
||||
if isPrivateIPAddress(resolvedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return isPrivateIPAddress(ip)
|
||||
}
|
||||
|
||||
downloadImage := func(ctx context.Context, url string) ([]byte, error) {
|
||||
// URL验证
|
||||
if err := validateImageURL(url); err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// 使用安全的HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 禁止访问私有IP
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isPrivateIP(host) {
|
||||
return nil, fmt.Errorf("access to private IP denied: %s", host)
|
||||
}
|
||||
|
||||
return (&net.Dialer{}).DialContext(ctx, network, addr)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
// 添加安全头
|
||||
req.Header.Set("User-Agent", "CozeStudio/1.0")
|
||||
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download image: %w", err)
|
||||
@ -225,11 +118,7 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
||||
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 限制响应大小
|
||||
const maxImageSize = 10 * 1024 * 1024 // 10MB
|
||||
limitedReader := io.LimitReader(resp.Body, maxImageSize)
|
||||
|
||||
data, err := io.ReadAll(limitedReader)
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read image content: %w", err)
|
||||
}
|
||||
|
||||
@ -646,15 +646,15 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques
|
||||
var processedParams []interface{}
|
||||
var err error
|
||||
|
||||
// 禁用原始SQL执行以防止SQL注入攻击
|
||||
// Handle SQLType: if raw, do not process params
|
||||
if req.SQLType == entity2.SQLType_Raw {
|
||||
return nil, fmt.Errorf("raw SQL execution is not allowed for security reasons")
|
||||
}
|
||||
|
||||
// 强制使用参数化查询
|
||||
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
||||
processedSQL = req.SQL
|
||||
processedParams = nil
|
||||
} else {
|
||||
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
|
||||
@ -1011,4 +1011,4 @@ func (m *mysqlService) buildNestedConditions(condition *rdb.ComplexCondition) (s
|
||||
return whereClause.String(), values, nil
|
||||
}
|
||||
return "", values, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -350,9 +350,9 @@ func (mr *MockServiceMockRecorder) GetConvRelatedInfo(ctx, convID any) *gomock.C
|
||||
}
|
||||
|
||||
// GetConversationNameByID mocks base method.
|
||||
func (m *MockService) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
|
||||
func (m *MockService) GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetConversationNameByID", ctx, env, appID, connectorID, conversationID)
|
||||
ret := m.ctrl.Call(m, "GetConversationNameByID", ctx, env, bizID, connectorID, conversationID)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
@ -360,9 +360,9 @@ func (m *MockService) GetConversationNameByID(ctx context.Context, env vo.Env, a
|
||||
}
|
||||
|
||||
// GetConversationNameByID indicates an expected call of GetConversationNameByID.
|
||||
func (mr *MockServiceMockRecorder) GetConversationNameByID(ctx, env, appID, connectorID, conversationID any) *gomock.Call {
|
||||
func (mr *MockServiceMockRecorder) GetConversationNameByID(ctx, env, bizID, connectorID, conversationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConversationNameByID", reflect.TypeOf((*MockService)(nil).GetConversationNameByID), ctx, env, appID, connectorID, conversationID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConversationNameByID", reflect.TypeOf((*MockService)(nil).GetConversationNameByID), ctx, env, bizID, connectorID, conversationID)
|
||||
}
|
||||
|
||||
// GetDynamicConversationByName mocks base method.
|
||||
@ -446,9 +446,9 @@ func (mr *MockServiceMockRecorder) GetNodeExecution(ctx, exeID, nodeID any) *gom
|
||||
}
|
||||
|
||||
// GetOrCreateConversation mocks base method.
|
||||
func (m *MockService) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
func (m *MockService) GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrCreateConversation", ctx, env, appID, connectorID, userID, conversationName)
|
||||
ret := m.ctrl.Call(m, "GetOrCreateConversation", ctx, env, bizID, connectorID, userID, conversationName)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
@ -456,9 +456,9 @@ func (m *MockService) GetOrCreateConversation(ctx context.Context, env vo.Env, a
|
||||
}
|
||||
|
||||
// GetOrCreateConversation indicates an expected call of GetOrCreateConversation.
|
||||
func (mr *MockServiceMockRecorder) GetOrCreateConversation(ctx, env, appID, connectorID, userID, conversationName any) *gomock.Call {
|
||||
func (mr *MockServiceMockRecorder) GetOrCreateConversation(ctx, env, bizID, connectorID, userID, conversationName any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrCreateConversation", reflect.TypeOf((*MockService)(nil).GetOrCreateConversation), ctx, env, appID, connectorID, userID, conversationName)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrCreateConversation", reflect.TypeOf((*MockService)(nil).GetOrCreateConversation), ctx, env, bizID, connectorID, userID, conversationName)
|
||||
}
|
||||
|
||||
// GetTemplateByName mocks base method.
|
||||
@ -774,6 +774,22 @@ func (mr *MockServiceMockRecorder) SyncRelatedWorkflowResources(ctx, appID, rela
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncRelatedWorkflowResources", reflect.TypeOf((*MockService)(nil).SyncRelatedWorkflowResources), ctx, appID, relatedWorkflows, related)
|
||||
}
|
||||
|
||||
// SyncResume mocks base method.
|
||||
func (m *MockService) SyncResume(ctx context.Context, req *entity.ResumeRequest, arg2 workflow.ExecuteConfig) (*entity.WorkflowExecution, vo.TerminatePlan, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncResume", ctx, req, arg2)
|
||||
ret0, _ := ret[0].(*entity.WorkflowExecution)
|
||||
ret1, _ := ret[1].(vo.TerminatePlan)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// SyncResume indicates an expected call of SyncResume.
|
||||
func (mr *MockServiceMockRecorder) SyncResume(ctx, req, arg2 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncResume", reflect.TypeOf((*MockService)(nil).SyncResume), ctx, req, arg2)
|
||||
}
|
||||
|
||||
// UpdateChatFlowRole mocks base method.
|
||||
func (m *MockService) UpdateChatFlowRole(ctx context.Context, workflowID int64, role *vo.ChatFlowRoleUpdate) error {
|
||||
m.ctrl.T.Helper()
|
||||
@ -1329,9 +1345,9 @@ func (mr *MockRepositoryMockRecorder) GetDraftWorkflowsByAppID(ctx, AppID any) *
|
||||
}
|
||||
|
||||
// GetDynamicConversationByID mocks base method.
|
||||
func (m *MockRepository) GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
|
||||
func (m *MockRepository) GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDynamicConversationByID", ctx, env, appID, connectorID, conversationID)
|
||||
ret := m.ctrl.Call(m, "GetDynamicConversationByID", ctx, env, bizID, connectorID, conversationID)
|
||||
ret0, _ := ret[0].(*entity.DynamicConversation)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
@ -1339,9 +1355,9 @@ func (m *MockRepository) GetDynamicConversationByID(ctx context.Context, env vo.
|
||||
}
|
||||
|
||||
// GetDynamicConversationByID indicates an expected call of GetDynamicConversationByID.
|
||||
func (mr *MockRepositoryMockRecorder) GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID any) *gomock.Call {
|
||||
func (mr *MockRepositoryMockRecorder) GetDynamicConversationByID(ctx, env, bizID, connectorID, conversationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDynamicConversationByID", reflect.TypeOf((*MockRepository)(nil).GetDynamicConversationByID), ctx, env, appID, connectorID, conversationID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDynamicConversationByID", reflect.TypeOf((*MockRepository)(nil).GetDynamicConversationByID), ctx, env, bizID, connectorID, conversationID)
|
||||
}
|
||||
|
||||
// GetDynamicConversationByName mocks base method.
|
||||
@ -1565,9 +1581,9 @@ func (mr *MockRepositoryMockRecorder) GetOrCreateStaticConversation(ctx, env, id
|
||||
}
|
||||
|
||||
// GetStaticConversationByID mocks base method.
|
||||
func (m *MockRepository) GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
|
||||
func (m *MockRepository) GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStaticConversationByID", ctx, env, appID, connectorID, conversationID)
|
||||
ret := m.ctrl.Call(m, "GetStaticConversationByID", ctx, env, bizID, connectorID, conversationID)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
@ -1575,9 +1591,9 @@ func (m *MockRepository) GetStaticConversationByID(ctx context.Context, env vo.E
|
||||
}
|
||||
|
||||
// GetStaticConversationByID indicates an expected call of GetStaticConversationByID.
|
||||
func (mr *MockRepositoryMockRecorder) GetStaticConversationByID(ctx, env, appID, connectorID, conversationID any) *gomock.Call {
|
||||
func (mr *MockRepositoryMockRecorder) GetStaticConversationByID(ctx, env, bizID, connectorID, conversationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaticConversationByID", reflect.TypeOf((*MockRepository)(nil).GetStaticConversationByID), ctx, env, appID, connectorID, conversationID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaticConversationByID", reflect.TypeOf((*MockRepository)(nil).GetStaticConversationByID), ctx, env, bizID, connectorID, conversationID)
|
||||
}
|
||||
|
||||
// GetStaticConversationByTemplateID mocks base method.
|
||||
|
||||
@ -19,6 +19,7 @@ package logs
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
@ -192,7 +193,7 @@ func (ll *defaultLogger) logfCtx(ctx context.Context, lv Level, format *string,
|
||||
return
|
||||
}
|
||||
msg := lv.toString()
|
||||
logID := ctx.Value("log-id")
|
||||
logID := ctx.Value(consts.CtxLogIDKey)
|
||||
if logID != nil {
|
||||
msg += fmt.Sprintf("[log-id: %v] ", logID)
|
||||
}
|
||||
|
||||
@ -99,6 +99,10 @@ const (
|
||||
PPStructureAPIURL = "PADDLEOCR_STRUCTURE_API_URL"
|
||||
)
|
||||
|
||||
const (
|
||||
CtxLogIDKey = "log-id"
|
||||
)
|
||||
|
||||
const (
|
||||
ShortcutCommandResourceType = "uri"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user