refactor(workflow): Calculate chat history rounds during schema convertion (#1990)
Co-authored-by: zhuangjie.1125 <zhuangjie.1125@bytedance.com>
This commit is contained in:
@ -22,6 +22,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
@ -282,6 +284,50 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
|
||||
return executeID, nil
|
||||
}
|
||||
|
||||
func (i *impl) handleHistory(ctx context.Context, config *workflowModel.ExecuteConfig, input map[string]any, historyRounds int64, shouldFetchConversationByName bool) error {
|
||||
if historyRounds <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if shouldFetchConversationByName {
|
||||
var cID, sID, bizID int64
|
||||
var err error
|
||||
if config.AppID != nil {
|
||||
bizID = *config.AppID
|
||||
} else if config.AgentID != nil {
|
||||
bizID = *config.AgentID
|
||||
}
|
||||
for k, v := range input {
|
||||
if k == vo.ConversationNameKey {
|
||||
cName, ok := v.(string)
|
||||
if !ok {
|
||||
return errors.New("CONVERSATION_NAME must be string")
|
||||
}
|
||||
cID, sID, err = i.GetOrCreateConversation(ctx, vo.Draft, bizID, consts.CozeConnectorID, config.Operator, cName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config.ConversationID = ptr.Of(cID)
|
||||
config.SectionID = ptr.Of(sID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, *config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error) {
|
||||
var (
|
||||
err error
|
||||
@ -308,30 +354,6 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
|
||||
historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@ -342,6 +364,17 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds = workflowSC.HistoryRounds()
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
if err = i.handleHistory(ctx, &config, input, historyRounds, true); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
wf, err := compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID), einoCompose.WithGraphName(fmt.Sprintf("%d", wfEntity.ID)))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create workflow: %w", err)
|
||||
@ -417,29 +450,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@ -450,6 +460,17 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
return nil, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds = workflowSC.HistoryRounds()
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
if err = i.handleHistory(ctx, &config, input, historyRounds, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var wfOpts []compose.WorkflowOption
|
||||
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
@ -997,20 +1018,6 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxHistoryRounds int64 = 30
|
||||
|
||||
func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
|
||||
if wfEntity == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return min(maxRounds, maxHistoryRounds), nil
|
||||
}
|
||||
|
||||
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
|
||||
convID := config.ConversationID
|
||||
agentID := config.AgentID
|
||||
@ -1027,11 +1034,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
var resolvedAppID int64
|
||||
var bizID int64
|
||||
if appID != nil {
|
||||
resolvedAppID = *appID
|
||||
bizID = *appID
|
||||
} else if agentID != nil {
|
||||
resolvedAppID = *agentID
|
||||
bizID = *agentID
|
||||
} else {
|
||||
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
|
||||
return nil, nil, nil
|
||||
@ -1039,7 +1046,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
|
||||
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: *convID,
|
||||
AppID: resolvedAppID,
|
||||
BizID: bizID,
|
||||
UserID: userID,
|
||||
Rounds: historyRounds + 1,
|
||||
SectionID: *sectionID,
|
||||
@ -1048,7 +1055,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
|
||||
return nil, nil, nil
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(runIds) <= 1 {
|
||||
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
|
||||
@ -1061,7 +1068,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
|
||||
return nil, nil, nil
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return response.Messages, response.SchemaMessages, nil
|
||||
|
||||
Reference in New Issue
Block a user