refactor(workflow): Calculate chat history rounds during schema convertion (#1990)

Co-authored-by: zhuangjie.1125 <zhuangjie.1125@bytedance.com>
This commit is contained in:
lvxinyu-1117
2025-09-10 15:52:23 +08:00
committed by GitHub
parent 4416127d47
commit 4bfce5a8cb
29 changed files with 2492 additions and 384 deletions

View File

@ -22,6 +22,8 @@ import (
"fmt"
"time"
"github.com/coze-dev/coze-studio/backend/types/consts"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
@ -282,6 +284,50 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
return executeID, nil
}
func (i *impl) handleHistory(ctx context.Context, config *workflowModel.ExecuteConfig, input map[string]any, historyRounds int64, shouldFetchConversationByName bool) error {
if historyRounds <= 0 {
return nil
}
if shouldFetchConversationByName {
var cID, sID, bizID int64
var err error
if config.AppID != nil {
bizID = *config.AppID
} else if config.AgentID != nil {
bizID = *config.AgentID
}
for k, v := range input {
if k == vo.ConversationNameKey {
cName, ok := v.(string)
if !ok {
return errors.New("CONVERSATION_NAME must be string")
}
cID, sID, err = i.GetOrCreateConversation(ctx, vo.Draft, bizID, consts.CozeConnectorID, config.Operator, cName)
if err != nil {
return err
}
config.ConversationID = ptr.Of(cID)
config.SectionID = ptr.Of(sID)
}
}
}
messages, scMessages, err := i.prefetchChatHistory(ctx, *config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
return nil
}
func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error) {
var (
err error
@ -308,30 +354,6 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
}
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
if err != nil {
return 0, err
}
}
if historyRounds > 0 {
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
}
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
@ -342,6 +364,17 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds = workflowSC.HistoryRounds()
}
if historyRounds > 0 {
if err = i.handleHistory(ctx, &config, input, historyRounds, true); err != nil {
return 0, err
}
}
wf, err := compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID), einoCompose.WithGraphName(fmt.Sprintf("%d", wfEntity.ID)))
if err != nil {
return 0, fmt.Errorf("failed to create workflow: %w", err)
@ -417,29 +450,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
}
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
if err != nil {
return nil, err
}
}
if historyRounds > 0 {
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
}
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
@ -450,6 +460,17 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
return nil, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds = workflowSC.HistoryRounds()
}
if historyRounds > 0 {
if err = i.handleHistory(ctx, &config, input, historyRounds, false); err != nil {
return nil, err
}
}
var wfOpts []compose.WorkflowOption
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
@ -997,20 +1018,6 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
return nil
}
const maxHistoryRounds int64 = 30
func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
if wfEntity == nil {
return 0, nil
}
maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
if err != nil {
return 0, err
}
return min(maxRounds, maxHistoryRounds), nil
}
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
convID := config.ConversationID
agentID := config.AgentID
@ -1027,11 +1034,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
return nil, nil, nil
}
var resolvedAppID int64
var bizID int64
if appID != nil {
resolvedAppID = *appID
bizID = *appID
} else if agentID != nil {
resolvedAppID = *agentID
bizID = *agentID
} else {
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
return nil, nil, nil
@ -1039,7 +1046,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
ConversationID: *convID,
AppID: resolvedAppID,
BizID: bizID,
UserID: userID,
Rounds: historyRounds + 1,
SectionID: *sectionID,
@ -1048,7 +1055,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
if err != nil {
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
return nil, nil, nil
return nil, nil, err
}
if len(runIds) <= 1 {
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
@ -1061,7 +1068,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
})
if err != nil {
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
return nil, nil, nil
return nil, nil, err
}
return response.Messages, response.SchemaMessages, nil