diff --git a/backend/domain/conversation/agentrun/service/agent_run_impl.go b/backend/domain/conversation/agentrun/service/agent_run_impl.go index 277c06711..cf82c3405 100644 --- a/backend/domain/conversation/agentrun/service/agent_run_impl.go +++ b/backend/domain/conversation/agentrun/service/agent_run_impl.go @@ -73,6 +73,37 @@ type runtimeDependence struct { usage *agentrun.Usage } +func (rd *runtimeDependence) SetRunID(runID int64) { + rd.runID = runID +} +func (rd *runtimeDependence) GetRunID() int64 { + return rd.runID +} +func (rd *runtimeDependence) SetRunMeta(arm *entity.AgentRunMeta) { + rd.runMeta = arm +} +func (rd *runtimeDependence) GetRunMeta() *entity.AgentRunMeta { + return rd.runMeta +} +func (rd *runtimeDependence) SetAgentInfo(agentInfo *singleagent.SingleAgent) { + rd.agentInfo = agentInfo +} +func (rd *runtimeDependence) GetAgentInfo() *singleagent.SingleAgent { + return rd.agentInfo +} +func (rd *runtimeDependence) SetQuestionMsgID(msgID int64) { + rd.questionMsgID = msgID +} +func (rd *runtimeDependence) GetQuestionMsgID() int64 { + return rd.questionMsgID +} +func (rd *runtimeDependence) SetStartTime(t time.Time) { + rd.startTime = t +} +func (rd *runtimeDependence) GetStartTime() time.Time { + return rd.startTime +} + type Components struct { RunRecordRepo repository.RunRecordRepo ImagexSVC imagex.ImageX @@ -116,7 +147,7 @@ func (c *runImpl) run(ctx context.Context, sw *schema.StreamWriter[*entity.Agent return } - rtDependence.agentInfo = agentInfo + rtDependence.SetAgentInfo(agentInfo) history, err := c.handlerHistory(ctx, rtDependence) if err != nil { @@ -128,7 +159,7 @@ func (c *runImpl) run(ctx context.Context, sw *schema.StreamWriter[*entity.Agent if err != nil { return } - rtDependence.runID = runRecord.ID + rtDependence.SetRunID(runRecord.ID) defer func() { srRecord := c.buildSendRunRecord(ctx, runRecord, entity.RunStatusCompleted) if err != nil { @@ -147,9 +178,9 @@ func (c *runImpl) run(ctx context.Context, sw *schema.StreamWriter[*entity.Agent return } - rtDependence.questionMsgID = input.ID + rtDependence.SetQuestionMsgID(input.ID) - if rtDependence.agentInfo.BotMode == bot_common.BotMode_WorkflowMode { + if rtDependence.GetAgentInfo().BotMode == bot_common.BotMode_WorkflowMode { err = c.handlerWfAsAgentStreamExecute(ctx, sw, history, rtDependence) } else { err = c.handlerAgentStreamExecute(ctx, sw, history, input, rtDependence) @@ -189,7 +220,7 @@ func (c *runImpl) handlerWfAsAgentStreamExecute(ctx context.Context, sw *schema. if resumeInfo != nil { wfStreamer, err = crossworkflow.DefaultSVC().StreamResume(ctx, &crossworkflow.ResumeRequest{ ResumeData: concatWfInput(rtDependence), - EventID: 0, + EventID: resumeInfo.ChatflowInterrupt.InterruptEvent.ID, ExecuteID: resumeInfo.ChatflowInterrupt.ExecuteID, }, executeConfig) } else { @@ -520,13 +551,19 @@ func (c *runImpl) pullWfStream(ctx context.Context, events *schema.StreamReader[ st, re := events.Recv() if re != nil { if errors.Is(re, io.EOF) { + // update usage + + finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence) + if finishErr != nil { + logs.CtxErrorf(ctx, "handlerFinalAnswerFinish error: %v", finishErr) + return + } return } logs.CtxErrorf(ctx, "pullWfStream Recv error: %v", re) c.handlerErr(ctx, re, sw) return } - if st == nil { continue } @@ -536,7 +573,6 @@ func (c *runImpl) pullWfStream(ctx context.Context, events *schema.StreamReader[ OutputTokens: st.StateMessage.Usage.OutputTokens, TotalCount: st.StateMessage.Usage.InputTokens + st.StateMessage.Usage.OutputTokens, } - logs.CtxInfof(ctx, "pullWfStream usage:%v,err:%v", conv.DebugJsonToStr(usage), re) } if st.StateMessage != nil && st.StateMessage.InterruptEvent != nil { // interrupt @@ -625,13 +661,6 @@ func (c *runImpl) handlerWfInterruptMsg(ctx context.Context, sw *schema.StreamWr if err != nil { return } - - finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence) - if finishErr != nil { - logs.CtxErrorf(ctx, "handlerFinalAnswerFinish error: %v", finishErr) - return - } - } func (c *runImpl) handlerWfInterruptEvent(_ context.Context, interruptEventData *crossworkflow.InterruptEvent) (string, message.ContentType, error) { diff --git a/backend/domain/conversation/message/internal/dal/message.go b/backend/domain/conversation/message/internal/dal/message.go index 6e644231d..6621f50e9 100644 --- a/backend/domain/conversation/message/internal/dal/message.go +++ b/backend/domain/conversation/message/internal/dal/message.go @@ -145,6 +145,10 @@ func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Messa if err != nil { return 0, err } + if do.RowsAffected == 0 { + return 0, errorx.New(errno.ErrRecordNotFound) + } + return do.RowsAffected, nil } diff --git a/backend/types/errno/conversation.go b/backend/types/errno/conversation.go index 1d8bdbf3c..3ea2dfb64 100644 --- a/backend/types/errno/conversation.go +++ b/backend/types/errno/conversation.go @@ -35,9 +35,17 @@ const ( ErrConversationMessageNotFound = 103200001 ErrAgentRun = 103200002 + + ErrRecordNotFound = 103200003 ) func init() { + code.Register( + ErrRecordNotFound, + "record not found or nothing to update", + code.WithAffectStability(false), + ) + code.Register( ErrAgentRun, "Interal Server Error",