Compare commits

...

4 Commits

2 changed files with 93 additions and 37 deletions

View File

@@ -186,6 +186,11 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
	logs.CtxInfof(ctx, "info-OnEndWithStreamOutput, info=%v, output=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(output))
	switch info.Component {
	case compose.ComponentOfGraph, components.ComponentOfChatModel:
		if info.Name == keyOfReActAgent {
			r.processToolsReturnDirectlyStreamWithLazyInit(ctx, output)
			return ctx
		}
		if info.Name != keyOfReActAgentChatModel && info.Name != keyOfLLM {
			output.Close()
			return ctx
@@ -326,6 +331,56 @@ func convToolsNodeCallbackInput(input callbacks.CallbackInput) *schema.Message {
		return nil
	}
}

func convToolsNodeCallbackOutputMessage(output callbacks.CallbackOutput) *schema.Message {
	switch t := output.(type) {
	case *schema.Message:
		return t
	default:
		return nil
	}
}

func (r *replyChunkCallback) processToolsReturnDirectlyStreamWithLazyInit(_ context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) {
	var streamInitialized bool
	var sr *schema.StreamReader[*schema.Message]
	var sw *schema.StreamWriter[*schema.Message]
	for {
		cbOut, err := output.Recv()
		if errors.Is(err, io.EOF) {
			if sw != nil {
				sw.Close()
			}
			break
		}
		if err != nil {
			// propagate the receive error to the downstream reader before closing
			if sw != nil {
				sw.Send(nil, err)
				sw.Close()
			}
			break
		}
		msg := convToolsNodeCallbackOutputMessage(cbOut)
		if msg == nil {
			break
		}
		// stop forwarding as soon as a chunk is not a tool message
		if msg.Role != schema.Tool {
			break
		}
		// lazily create the answer stream on the first tool chunk, so an event
		// is emitted only when there is actually something to forward
		if !streamInitialized {
			sr, sw = schema.Pipe[*schema.Message](5)
			r.sw.Send(&entity.AgentEvent{
				EventType:       singleagent.EventTypeOfChatModelAnswer,
				ChatModelAnswer: sr,
			}, nil)
			streamInitialized = true
		}
		sw.Send(msg, nil)
	}
}

func convToolsNodeCallbackOutput(output callbacks.CallbackOutput) []*schema.Message {
	switch t := output.(type) {

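processToolsReturnDirectlyStreamWithLazyInit above creates its downstream stream lazily: schema.Pipe is only called once the first tool message arrives, so no ChatModelAnswer event is emitted when a return-directly tool produces nothing. Below is a self-contained sketch of that pattern against eino's schema package; forwardToolChunks and emit are illustrative stand-ins for the callback and r.sw.Send, not names from this change.

package main

import (
	"fmt"

	"github.com/cloudwego/eino/schema"
)

// forwardToolChunks mirrors the lazy-init pattern above: the downstream pipe is
// created only when the first tool message arrives, so the consumer never
// receives an empty stream. emit is a hypothetical stand-in for r.sw.Send.
func forwardToolChunks(msgs []*schema.Message, emit func(*schema.StreamReader[*schema.Message])) {
	var sw *schema.StreamWriter[*schema.Message]
	for _, msg := range msgs {
		if msg == nil || msg.Role != schema.Tool {
			break
		}
		if sw == nil {
			var sr *schema.StreamReader[*schema.Message]
			sr, sw = schema.Pipe[*schema.Message](5) // buffered pipe, created lazily
			emit(sr)
		}
		sw.Send(msg, nil)
	}
	if sw != nil {
		sw.Close() // lets the reader observe io.EOF
	}
}

func main() {
	msgs := []*schema.Message{
		{Role: schema.Tool, Content: "tool chunk 1"},
		{Role: schema.Tool, Content: "tool chunk 2"},
	}

	var reader *schema.StreamReader[*schema.Message]
	// Capacity 5 comfortably holds the two demo chunks, so Send never blocks
	// in this synchronous demo.
	forwardToolChunks(msgs, func(sr *schema.StreamReader[*schema.Message]) { reader = sr })
	if reader == nil {
		return // no tool output, so no stream was created
	}
	defer reader.Close()

	for {
		m, err := reader.Recv()
		if err != nil { // io.EOF once the writer is closed
			break
		}
		fmt.Println(m.Content)
	}
}

Because the pipe is buffered (capacity 5, as in the diff), the producer can run a few chunks ahead of whoever consumes the reader.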
View File

@@ -496,48 +496,49 @@ func (w *WorkflowHandler) OnEndWithStreamOutput(ctx context.Context, info *callb
		return ctx
	}
	// consumes the stream asynchronously because the Exit node has already processed this stream synchronously.
	safego.Go(ctx, func() {
		defer output.Close()
		fullOutput := make(map[string]any)
		for {
			chunk, e := output.Recv()
			if e != nil {
				if e == io.EOF {
					break
				}
				if _, ok := schema.GetSourceName(e); ok {
					continue
				}
				logs.Errorf("workflow OnEndWithStreamOutput failed to receive stream output: %v", e)
				_ = w.OnError(ctx, info, e)
				return
			}
			fullOutput, e = nodes.ConcatTwoMaps(fullOutput, chunk.(map[string]any))
			if e != nil {
				logs.Errorf("failed to concat two maps: %v", e)
				return
			}
		}
		c := GetExeCtx(ctx)
		e := &Event{
			Type:     WorkflowSuccess,
			Context:  c,
			Duration: time.Since(time.UnixMilli(c.StartTime)),
			Output:   fullOutput,
		}
		if c.TokenCollector != nil {
			usage := c.TokenCollector.wait()
			e.Token = &TokenInfo{
				InputToken:  int64(usage.PromptTokens),
				OutputToken: int64(usage.CompletionTokens),
				TotalToken:  int64(usage.TotalTokens),
			}
		}
		w.ch <- e
	})
	return ctx
}
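In the second file, the stream drain that previously ran inline in OnEndWithStreamOutput now runs inside safego.Go, and the early exits return from the closure instead of returning ctx, so the callback no longer blocks until the workflow output has been fully received. Below is a minimal sketch of that shape; safeGo, the chunks channel, and the events channel are simplified stand-ins (assumptions) for safego.Go, the schema.StreamReader, and w.ch.

package main

import (
	"context"
	"fmt"
	"log"
)

// safeGo is a stand-in for the safego.Go helper used in the diff (assumed
// behavior): it runs fn on its own goroutine and recovers from panics so a
// failing stream consumer cannot take down the process.
func safeGo(_ context.Context, fn func()) {
	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("recovered from panic in goroutine: %v", r)
			}
		}()
		fn()
	}()
}

func main() {
	// A toy stream of output chunks; the real code reads from a schema.StreamReader.
	chunks := make(chan map[string]any, 2)
	chunks <- map[string]any{"answer": "hello"}
	chunks <- map[string]any{"tokens": 42}
	close(chunks)

	events := make(chan map[string]any, 1)

	// As in the diff, the handler hands the drain-and-aggregate work to a
	// goroutine and returns immediately instead of blocking the callback.
	safeGo(context.Background(), func() {
		fullOutput := make(map[string]any)
		for chunk := range chunks {
			for k, v := range chunk { // simplified stand-in for nodes.ConcatTwoMaps
				fullOutput[k] = v
			}
		}
		events <- fullOutput // simplified stand-in for w.ch <- e
	})

	fmt.Println(<-events) // map[answer:hello tokens:42]
}

The trade-off is that the aggregated WorkflowSuccess event is published after the callback has already returned; per the comment in the diff, this is acceptable because the Exit node has already processed the same stream synchronously.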