fix: update eino to fix the issue when field mappings are resolved re… (#1952)
This commit is contained in:
16
.github/workflows/ci@backend.yml
vendored
16
.github/workflows/ci@backend.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: coverage.out
|
||||
BREAKDOWN_FILE: main.breakdown
|
||||
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
@ -52,7 +52,7 @@ jobs:
|
||||
mysql version: '8.4.5'
|
||||
mysql database: 'opencoze'
|
||||
mysql root password: 'root'
|
||||
|
||||
|
||||
- name: Verify MySQL Startup
|
||||
run: |
|
||||
echo "Waiting for MySQL to be ready..."
|
||||
@ -70,8 +70,12 @@ jobs:
|
||||
run: sudo apt-get update && sudo apt-get install -y mysql-client
|
||||
|
||||
- name: Initialize Database
|
||||
run: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql
|
||||
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 10
|
||||
max_attempts: 20
|
||||
command: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql
|
||||
|
||||
- name: Run Go Test
|
||||
run: |
|
||||
modules=`find . -name "go.mod" -exec dirname {} \;`
|
||||
@ -82,7 +86,7 @@ jobs:
|
||||
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
|
||||
go work sync
|
||||
go test -race -v -coverprofile=${{ env.COVERAGE_FILE }} -gcflags="all=-l -N" -coverpkg=$coverpkg $list
|
||||
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
@ -118,4 +122,4 @@ jobs:
|
||||
if [[ ! -f "go.work" ]];then go work init;fi
|
||||
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
|
||||
go work sync
|
||||
go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list
|
||||
go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list
|
||||
|
||||
@ -43,14 +43,15 @@ import (
|
||||
"github.com/cloudwego/hertz/pkg/common/ut"
|
||||
"github.com/cloudwego/hertz/pkg/protocol"
|
||||
"github.com/cloudwego/hertz/pkg/protocol/sse"
|
||||
message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
|
||||
message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
@ -755,6 +756,7 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro
|
||||
var nodeType string
|
||||
var token *workflow.TokenAndCost
|
||||
var reason string
|
||||
var count int
|
||||
for {
|
||||
if nodeEvent != nil {
|
||||
if options.previousInterruptEventID != "" {
|
||||
@ -770,6 +772,10 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro
|
||||
break
|
||||
}
|
||||
|
||||
if count > 1000 {
|
||||
r.t.Fatal("get process for too long")
|
||||
}
|
||||
|
||||
getProcessResp := getProcess(r.t, r.h, id, exeID)
|
||||
if len(getProcessResp.Data.NodeResults) == 1 {
|
||||
output = getProcessResp.Data.NodeResults[0].Output
|
||||
@ -803,6 +809,8 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro
|
||||
eventID = nodeEvent.ID
|
||||
}
|
||||
r.t.Logf("getProcess output= %s, status= %v, eventID= %s, nodeType= %s", output, workflowStatus, eventID, nodeType)
|
||||
|
||||
count++
|
||||
}
|
||||
|
||||
return &exeResult{
|
||||
@ -1624,6 +1632,7 @@ func TestNestedSubWorkflowWithInterrupt(t *testing.T) {
|
||||
post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
|
||||
WorkflowID: topID,
|
||||
})
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
midID := r.load("subworkflow/middle_workflow.json", withID(7494849202016272435))
|
||||
@ -1841,6 +1850,7 @@ func TestPublishWorkflow(t *testing.T) {
|
||||
WorkflowID: id,
|
||||
}
|
||||
_ = post[workflow.DeleteWorkflowResponse](r, deleteReq)
|
||||
time.Sleep(time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
@ -1964,6 +1974,7 @@ func TestSimpleInvokableToolWithReturnVariables(t *testing.T) {
|
||||
post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
|
||||
WorkflowID: id,
|
||||
})
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
@ -2628,6 +2639,7 @@ func TestListWorkflowAsToolData(t *testing.T) {
|
||||
WorkflowID: id,
|
||||
}
|
||||
_ = post[workflow.DeleteWorkflowResponse](r, deleteReq)
|
||||
time.Sleep(time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
@ -2662,6 +2674,7 @@ func TestWorkflowDetailAndDetailInfo(t *testing.T) {
|
||||
WorkflowID: id,
|
||||
}
|
||||
_ = post[workflow.DeleteWorkflowResponse](r, deleteReq)
|
||||
time.Sleep(time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
@ -4542,6 +4555,8 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
||||
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
@ -85,7 +85,6 @@ type ToolResponseInfo struct {
|
||||
FunctionInfo
|
||||
CallID string
|
||||
Response string
|
||||
Complete bool
|
||||
}
|
||||
|
||||
type ToolType = workflow.PluginType
|
||||
|
||||
@ -38,7 +38,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, []einoCompose.Option, error) {
|
||||
func (r *WorkflowRunner) designateOptions(ctx context.Context) ([]einoCompose.Option, error) {
|
||||
var (
|
||||
wb = r.basic
|
||||
exeCfg = r.config
|
||||
@ -83,13 +83,13 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
ns,
|
||||
string(key))
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
return nil, err
|
||||
}
|
||||
opts = append(opts, subOpts...)
|
||||
} else if ns.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts = append(opts, llmNodeOpts...)
|
||||
@ -103,7 +103,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
ns,
|
||||
string(key))
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
return nil, err
|
||||
}
|
||||
for _, subO := range subOpts {
|
||||
opts = append(opts, WrapOpt(subO, parent.Key))
|
||||
@ -111,7 +111,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
} else if ns.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
return nil, err
|
||||
}
|
||||
for _, subO := range llmNodeOpts {
|
||||
opts = append(opts, WrapOpt(subO, parent.Key))
|
||||
@ -124,7 +124,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
|
||||
}
|
||||
|
||||
return ctx, opts, nil
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func nodeCallbackOption(key vo.NodeKey, name string, eventChan chan *execute.Event, resumeEvent *entity.InterruptEvent,
|
||||
|
||||
@ -92,6 +92,7 @@ func init() {
|
||||
_ = compose.RegisterSerializableType[*schema.Message]("schema_message")
|
||||
_ = compose.RegisterSerializableType[*crossmessage.WfMessage]("history_messages")
|
||||
_ = compose.RegisterSerializableType[*crossmessage.Content]("content")
|
||||
_ = compose.RegisterSerializableType[*model.PromptTokenDetails]("prompt_token_details")
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -167,7 +167,7 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
|
||||
}()
|
||||
}
|
||||
|
||||
ctx, composeOpts, err := r.designateOptions(ctx)
|
||||
composeOpts, err := r.designateOptions(ctx)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, err
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@ -84,8 +85,10 @@ func resumeOnce(rInfo *entity.ResumeRequest, callID string, allIEs map[string]*e
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, argumentsInJSON string, opts ...tool.Option) (
|
||||
cancelCtx context.Context, executeID int64, input map[string]any, callOpts []einoCompose.Option, err error) {
|
||||
func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest,
|
||||
argumentsInJSON string, opts ...tool.Option) (
|
||||
cancelCtx context.Context, executeID int64, input map[string]any,
|
||||
lastEventChan <-chan *execute.Event, callOpts []einoCompose.Option, err error) {
|
||||
cfg := execute.GetExecuteConfig(opts...)
|
||||
|
||||
var runOpts []WorkflowRunnerOption
|
||||
@ -126,11 +129,12 @@ func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx)
|
||||
cancelCtx, executeID, callOpts, lastEventChan, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (
|
||||
contentStr string, err error) {
|
||||
rInfo, allIEs := execute.GetResumeRequest(opts...)
|
||||
var (
|
||||
previouslyInterrupted bool
|
||||
@ -145,6 +149,18 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
||||
}
|
||||
}
|
||||
|
||||
ctx = callbacks.OnStart(ctx, &tool.CallbackInput{
|
||||
ArgumentsInJSON: argumentsInJSON,
|
||||
Extra: map[string]any{
|
||||
execute.ToolCallIDKey: callID,
|
||||
},
|
||||
})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
|
||||
logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
|
||||
return "", einoCompose.InterruptAndRerun
|
||||
@ -152,7 +168,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
||||
|
||||
defer resumeOnce(rInfo, callID, allIEs)
|
||||
|
||||
cancelCtx, executeID, in, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...)
|
||||
cancelCtx, executeID, in, _, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -179,7 +195,19 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
||||
}
|
||||
|
||||
if i.terminatePlan == vo.ReturnVariables {
|
||||
return sonic.MarshalString(out)
|
||||
contentStr, err = sonic.MarshalString(out)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_ = callbacks.OnEnd(ctx, &tool.CallbackOutput{
|
||||
Response: contentStr,
|
||||
Extra: map[string]any{
|
||||
execute.ToolCallIDKey: callID,
|
||||
},
|
||||
})
|
||||
|
||||
return contentStr, nil
|
||||
}
|
||||
|
||||
content, ok := out[answerKey]
|
||||
@ -187,7 +215,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
||||
return "", fmt.Errorf("no answer found when terminate plan is use answer content. out: %v", out)
|
||||
}
|
||||
|
||||
contentStr, ok := content.(string)
|
||||
contentStr, ok = content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("answer content is not string. content: %v", content)
|
||||
}
|
||||
@ -196,6 +224,13 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
||||
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
|
||||
}
|
||||
|
||||
_ = callbacks.OnEnd(ctx, &tool.CallbackOutput{
|
||||
Response: contentStr,
|
||||
Extra: map[string]any{
|
||||
execute.ToolCallIDKey: callID,
|
||||
},
|
||||
})
|
||||
|
||||
return contentStr, nil
|
||||
}
|
||||
|
||||
@ -207,6 +242,10 @@ func (i *invokableWorkflow) GetWorkflow() *entity.Workflow {
|
||||
return i.wfEntity
|
||||
}
|
||||
|
||||
func (i *invokableWorkflow) IsCallbacksEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type streamableWorkflow struct {
|
||||
workflowTool
|
||||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
|
||||
@ -235,12 +274,14 @@ func (s *streamableWorkflow) Info(_ context.Context) (*schema.ToolInfo, error) {
|
||||
return s.info, nil
|
||||
}
|
||||
|
||||
func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) {
|
||||
func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (
|
||||
out *schema.StreamReader[string], err error) {
|
||||
rInfo, allIEs := execute.GetResumeRequest(opts...)
|
||||
var (
|
||||
previouslyInterrupted bool
|
||||
callID = einoCompose.GetToolCallID(ctx)
|
||||
previousExecuteID int64
|
||||
toolFinishChan = make(chan struct{})
|
||||
)
|
||||
for interruptedCallID := range allIEs {
|
||||
if callID == interruptedCallID {
|
||||
@ -250,6 +291,20 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
|
||||
}
|
||||
}
|
||||
|
||||
ctx = callbacks.OnStart(ctx, &tool.CallbackInput{
|
||||
ArgumentsInJSON: argumentsInJSON,
|
||||
Extra: map[string]any{
|
||||
execute.ToolCallIDKey: callID,
|
||||
execute.ToolFinishChanKey: toolFinishChan,
|
||||
},
|
||||
})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
close(toolFinishChan)
|
||||
}
|
||||
}()
|
||||
|
||||
if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
|
||||
logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
|
||||
return nil, einoCompose.InterruptAndRerun
|
||||
@ -257,7 +312,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
|
||||
|
||||
defer resumeOnce(rInfo, callID, allIEs)
|
||||
|
||||
cancelCtx, executeID, in, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...)
|
||||
cancelCtx, executeID, in, lastEventChan, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -283,22 +338,35 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return schema.StreamReaderWithConvert(outStream, func(in map[string]any) (string, error) {
|
||||
content, ok := in["output"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no output found when stream plan is use output content. out: %v", in)
|
||||
go func() {
|
||||
for range lastEventChan {
|
||||
}
|
||||
close(toolFinishChan)
|
||||
}()
|
||||
|
||||
contentStr, ok := content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("output content is not string. content: %v", content)
|
||||
}
|
||||
_, callbackStream := callbacks.OnEndWithStreamOutput(ctx, schema.StreamReaderWithConvert(outStream,
|
||||
func(in map[string]any) (*tool.CallbackOutput, error) {
|
||||
content, ok := in["output"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no output found when stream plan is use output content. out: %v", in)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(contentStr, nodes.KeyIsFinished) {
|
||||
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
|
||||
}
|
||||
contentStr, ok := content.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("output content is not string. content: %v", content)
|
||||
}
|
||||
|
||||
return contentStr, nil
|
||||
if strings.HasSuffix(contentStr, nodes.KeyIsFinished) {
|
||||
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
|
||||
}
|
||||
|
||||
return &tool.CallbackOutput{
|
||||
Response: contentStr,
|
||||
}, nil
|
||||
}))
|
||||
|
||||
return schema.StreamReaderWithConvert(callbackStream, func(in *tool.CallbackOutput) (string, error) {
|
||||
return in.Response, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
@ -309,3 +377,7 @@ func (s *streamableWorkflow) TerminatePlan() vo.TerminatePlan {
|
||||
func (s *streamableWorkflow) GetWorkflow() *entity.Workflow {
|
||||
return s.wfEntity
|
||||
}
|
||||
|
||||
func (s *streamableWorkflow) IsCallbacksEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -370,10 +370,6 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
|
||||
interruptEvent.EventType, interruptEvent.NodeKey)
|
||||
}
|
||||
|
||||
if c.TokenCollector != nil { // wait until all streaming chunks are collected
|
||||
_ = c.TokenCollector.wait()
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
w.ch <- &Event{
|
||||
@ -1271,6 +1267,11 @@ func (n *NodeHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
|
||||
return ctx
|
||||
}
|
||||
|
||||
const (
|
||||
ToolCallIDKey = "call_id"
|
||||
ToolFinishChanKey = "tool_finish_chan"
|
||||
)
|
||||
|
||||
func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
|
||||
input *tool.CallbackInput,
|
||||
) context.Context {
|
||||
@ -1286,13 +1287,35 @@ func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
callID string
|
||||
toolFinishChan chan struct{}
|
||||
)
|
||||
if input.Extra != nil {
|
||||
callIDAny, ok := input.Extra[ToolCallIDKey]
|
||||
if ok {
|
||||
callID = callIDAny.(string)
|
||||
}
|
||||
toolFinishChanAny, ok := input.Extra[ToolFinishChanKey]
|
||||
if ok {
|
||||
toolFinishChan = toolFinishChanAny.(chan struct{})
|
||||
}
|
||||
}
|
||||
|
||||
if len(callID) == 0 {
|
||||
callID = compose.GetToolCallID(ctx)
|
||||
}
|
||||
|
||||
t.ch <- &Event{
|
||||
Type: FunctionCall,
|
||||
Context: GetExeCtx(ctx),
|
||||
functionCall: &entity.FunctionCallInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
Arguments: args,
|
||||
functionCall: &FunctionCallInfo{
|
||||
FunctionCallInfo: &entity.FunctionCallInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: callID,
|
||||
Arguments: args,
|
||||
},
|
||||
toolFinishChan: toolFinishChan,
|
||||
},
|
||||
}
|
||||
|
||||
@ -1306,14 +1329,25 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo,
|
||||
return ctx
|
||||
}
|
||||
|
||||
var callID string
|
||||
if output.Extra != nil {
|
||||
callIDAny, ok := output.Extra[ToolCallIDKey]
|
||||
if ok {
|
||||
callID = callIDAny.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if len(callID) == 0 {
|
||||
callID = compose.GetToolCallID(ctx)
|
||||
}
|
||||
|
||||
t.ch <- &Event{
|
||||
Type: ToolResponse,
|
||||
Context: GetExeCtx(ctx),
|
||||
toolResponse: &entity.ToolResponseInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
CallID: callID,
|
||||
Response: output.Response,
|
||||
Complete: true,
|
||||
},
|
||||
}
|
||||
|
||||
@ -1352,7 +1386,6 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
|
||||
toolResponse: &entity.ToolResponseInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: callID,
|
||||
Complete: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -1374,7 +1407,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
|
||||
Context: c,
|
||||
toolResponse: &entity.ToolResponseInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
CallID: callID,
|
||||
Response: chunk.Response,
|
||||
},
|
||||
}
|
||||
@ -1398,9 +1431,11 @@ func (t *ToolHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err
|
||||
t.ch <- &Event{
|
||||
Type: ToolError,
|
||||
Context: GetExeCtx(ctx),
|
||||
functionCall: &entity.FunctionCallInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
functionCall: &FunctionCallInfo{
|
||||
FunctionCallInfo: &entity.FunctionCallInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
},
|
||||
},
|
||||
Err: err,
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
type TokenCollector struct {
|
||||
@ -90,6 +91,33 @@ func (t *TokenCollector) finishStreamCounting() {
|
||||
}
|
||||
}
|
||||
|
||||
type tokenCollector struct {
|
||||
Key string
|
||||
Usage *model.TokenUsage
|
||||
Parent *TokenCollector
|
||||
}
|
||||
|
||||
func (t *TokenCollector) MarshalJSON() ([]byte, error) {
|
||||
t.wait()
|
||||
return sonic.Marshal(&tokenCollector{
|
||||
Key: t.Key,
|
||||
Usage: t.Usage,
|
||||
Parent: t.Parent,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *TokenCollector) UnmarshalJSON(bytes []byte) error {
|
||||
tc := &tokenCollector{}
|
||||
if err := sonic.Unmarshal(bytes, tc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Key = tc.Key
|
||||
t.Usage = tc.Usage
|
||||
t.Parent = tc.Parent
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTokenCollector(ctx context.Context) *TokenCollector {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
|
||||
@ -64,7 +64,7 @@ type Event struct {
|
||||
|
||||
InterruptEvents []*entity.InterruptEvent
|
||||
|
||||
functionCall *entity.FunctionCallInfo
|
||||
functionCall *FunctionCallInfo
|
||||
toolResponse *entity.ToolResponseInfo
|
||||
|
||||
outputExtractor func(o map[string]any) string
|
||||
@ -75,6 +75,11 @@ type Event struct {
|
||||
nodeCount int32
|
||||
}
|
||||
|
||||
type FunctionCallInfo struct {
|
||||
*entity.FunctionCallInfo
|
||||
toolFinishChan chan struct{}
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
InputToken int64
|
||||
OutputToken int64
|
||||
@ -104,17 +109,3 @@ func (e *Event) GetResumedEventID() int64 {
|
||||
}
|
||||
return e.Context.RootCtx.ResumeEvent.ID
|
||||
}
|
||||
|
||||
func (e *Event) GetFunctionCallInfo() (*entity.FunctionCallInfo, bool) {
|
||||
if e.functionCall == nil {
|
||||
return nil, false
|
||||
}
|
||||
return e.functionCall, true
|
||||
}
|
||||
|
||||
func (e *Event) GetToolResponse() (*entity.ToolResponseInfo, bool) {
|
||||
if e.toolResponse == nil {
|
||||
return nil, false
|
||||
}
|
||||
return e.toolResponse, true
|
||||
}
|
||||
|
||||
@ -672,7 +672,7 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Assistant,
|
||||
Type: entity.FunctionCall,
|
||||
FunctionCall: event.functionCall,
|
||||
FunctionCall: event.functionCall.FunctionCallInfo,
|
||||
},
|
||||
}, nil)
|
||||
case ToolResponse:
|
||||
@ -704,8 +704,6 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
|
||||
},
|
||||
}, nil)
|
||||
case ToolError:
|
||||
// TODO: optimize this log
|
||||
logs.CtxErrorf(ctx, "received tool error event: %v", event)
|
||||
default:
|
||||
panic("unimplemented event type: " + event.Type)
|
||||
}
|
||||
@ -715,8 +713,9 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
|
||||
|
||||
type fcCacheKey struct{}
|
||||
type fcInfo struct {
|
||||
input *entity.FunctionCallInfo
|
||||
output *entity.ToolResponseInfo
|
||||
input *entity.FunctionCallInfo
|
||||
output *entity.ToolResponseInfo
|
||||
toolFinishChan chan struct{}
|
||||
}
|
||||
|
||||
func HandleExecuteEvent(ctx context.Context,
|
||||
@ -772,7 +771,8 @@ func HandleExecuteEvent(ctx context.Context,
|
||||
lastNodeIsDone = true
|
||||
if wfSuccessEvent != nil {
|
||||
if err = setRootWorkflowSuccess(ctx, wfSuccessEvent, repo, sw); err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to set root workflow success: %v", err)
|
||||
logs.CtxErrorf(ctx, "failed to set root workflow success for workflow %d: %v",
|
||||
wfSuccessEvent.RootWorkflowBasic.ID, err)
|
||||
}
|
||||
return wfSuccessEvent
|
||||
}
|
||||
@ -786,10 +786,12 @@ func HandleExecuteEvent(ctx context.Context,
|
||||
// Add cancellation check timer
|
||||
cancelTicker := time.NewTicker(cancelCheckInterval)
|
||||
defer func() {
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] cancellable finish, returned event type: %v, workflow id: %d",
|
||||
event.Type, event.Context.RootWorkflowBasic.ID)
|
||||
cancelTicker.Stop() // Clean up timer
|
||||
waitUntilToolFinish(ctx)
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] cancellable wait until tool finished done, workflow id: %d",
|
||||
event.Context.RootWorkflowBasic.ID)
|
||||
cancelTicker.Stop() // Clean up timer
|
||||
if timeoutFn != nil {
|
||||
timeoutFn()
|
||||
}
|
||||
@ -825,6 +827,9 @@ func HandleExecuteEvent(ctx context.Context,
|
||||
defer func() {
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
|
||||
event.Type, event.Context.RootWorkflowBasic.ID)
|
||||
waitUntilToolFinish(ctx)
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] wait until tool finished done, workflow id: %d",
|
||||
event.Context.RootWorkflowBasic.ID)
|
||||
if timeoutFn != nil {
|
||||
timeoutFn()
|
||||
}
|
||||
@ -859,29 +864,26 @@ func cacheFunctionCall(ctx context.Context, event *Event) {
|
||||
c[event.NodeKey] = make(map[string]*fcInfo)
|
||||
}
|
||||
c[event.NodeKey][event.functionCall.CallID] = &fcInfo{
|
||||
input: event.functionCall,
|
||||
input: event.functionCall.FunctionCallInfo,
|
||||
toolFinishChan: event.functionCall.toolFinishChan,
|
||||
}
|
||||
}
|
||||
|
||||
func cacheToolResponse(ctx context.Context, event *Event) {
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if _, ok := c[event.NodeKey]; !ok {
|
||||
c[event.NodeKey] = make(map[string]*fcInfo)
|
||||
}
|
||||
|
||||
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
|
||||
}
|
||||
|
||||
func cacheToolStreamingResponse(ctx context.Context, event *Event) {
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if _, ok := c[event.NodeKey]; !ok {
|
||||
c[event.NodeKey] = make(map[string]*fcInfo)
|
||||
}
|
||||
if c[event.NodeKey][event.toolResponse.CallID].output == nil {
|
||||
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
|
||||
} else {
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
|
||||
}
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete
|
||||
|
||||
logs.CtxInfof(ctx, "receive tool response: %s, callID: %s",
|
||||
event.toolResponse.Response, event.toolResponse.CallID)
|
||||
}
|
||||
|
||||
func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
@ -890,29 +892,17 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
}
|
||||
|
||||
func waitUntilToolFinish(ctx context.Context) {
|
||||
var cnt int
|
||||
outer:
|
||||
for {
|
||||
if cnt > 1000 {
|
||||
return
|
||||
}
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if len(c) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if len(c) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range c {
|
||||
for _, info := range m {
|
||||
if info.output == nil {
|
||||
cnt++
|
||||
continue outer
|
||||
}
|
||||
|
||||
if !info.output.Complete {
|
||||
cnt++
|
||||
continue outer
|
||||
}
|
||||
for _, m := range c {
|
||||
for _, info := range m {
|
||||
if info.toolFinishChan != nil {
|
||||
<-info.toolFinishChan
|
||||
logs.CtxInfof(ctx, "tool finished, callID: %s, pluginID: %v", info.output.CallID,
|
||||
info.input.PluginID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,11 +26,10 @@ import (
|
||||
)
|
||||
|
||||
type workflowToolOption struct {
|
||||
resumeReq *entity.ResumeRequest
|
||||
streamContainer *StreamContainer
|
||||
exeCfg workflowModel.ExecuteConfig
|
||||
allInterruptEvents map[string]*entity.ToolInterruptEvent
|
||||
parentTokenCollector *TokenCollector
|
||||
resumeReq *entity.ResumeRequest
|
||||
streamContainer *StreamContainer
|
||||
exeCfg workflowModel.ExecuteConfig
|
||||
allInterruptEvents map[string]*entity.ToolInterruptEvent
|
||||
}
|
||||
|
||||
func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptEvent) tool.Option {
|
||||
|
||||
@ -17,7 +17,7 @@ require (
|
||||
github.com/bytedance/gopkg v0.1.3
|
||||
github.com/bytedance/mockey v1.2.14
|
||||
github.com/bytedance/sonic v1.14.0
|
||||
github.com/cloudwego/eino v0.3.55
|
||||
github.com/cloudwego/eino v0.4.8
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0
|
||||
github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e
|
||||
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8
|
||||
@ -284,3 +284,10 @@ require (
|
||||
sigs.k8s.io/yaml v1.3.0 // indirect
|
||||
stathat.com/c/consistent v1.0.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/eino-contrib/jsonschema v1.0.0 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
)
|
||||
|
||||
@ -131,6 +131,8 @@ github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAm
|
||||
github.com/aws/smithy-go v1.22.4 h1:uqXzVZNuNexwc/xrh6Tb56u89WDlJY6HS+KC0S4QSjw=
|
||||
github.com/aws/smithy-go v1.22.4/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||
@ -146,6 +148,8 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8=
|
||||
github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
|
||||
github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM=
|
||||
@ -190,8 +194,8 @@ github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5P
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
|
||||
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/eino v0.3.55 h1:lMZrGtEh0k3qykQTLNXSXuAa98OtF2tS43GMHyvN7nA=
|
||||
github.com/cloudwego/eino v0.3.55/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY=
|
||||
github.com/cloudwego/eino v0.4.8 h1:wptTU24tQad1mFCHw0+4zSzH+p8dLEBk6HtggPlcvP0=
|
||||
github.com/cloudwego/eino v0.4.8/go.mod h1:1TDlOmwGSsbCJaWB92w9YLZi2FL0WRZoRcD4eMvqikg=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0 h1:AuJsMdaTXc+dGUDQp82MifLYK8oiJf4gLQPUETmKISM=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0/go.mod h1:0FZG/KRBl3hGWkNsm55UaXyVa6PDVIy5u+QvboAB+cY=
|
||||
github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e h1:46D2fFDbUysA7kUD5x/wK3huneMEvTQfuWcHqI3M6iQ=
|
||||
@ -288,6 +292,8 @@ github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3/go.mod h1
|
||||
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
|
||||
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
|
||||
github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M=
|
||||
github.com/eino-contrib/jsonschema v1.0.0 h1:dXxbhGNZuI3+xNi8x3JT8AGyoXz6Pff6mRvmpjVl5Ww=
|
||||
github.com/eino-contrib/jsonschema v1.0.0/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4=
|
||||
github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
|
||||
@ -1056,6 +1062,8 @@ github.com/volcengine/volc-sdk-golang v1.0.211 h1:FgwD+1phyy+un4Qk2YqooYtp6XpvND
|
||||
github.com/volcengine/volc-sdk-golang v1.0.211/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.20 h1:+ifZdF7IIIagqF8yVNfk9CmNUl5wgRfU/8orlH+JQhA=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.20/go.mod h1:EyKoi6t6eZxoPNGr2GdFCZti2Skd7MO3eUzx7TtSvNo=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
|
||||
@ -72,7 +72,14 @@ func (q *UTChatModel) Generate(ctx context.Context, in []*schema.Message, _ ...m
|
||||
}
|
||||
|
||||
if msg.ResponseMeta != nil {
|
||||
callbackOut.TokenUsage = (*model.TokenUsage)(msg.ResponseMeta.Usage)
|
||||
callbackOut.TokenUsage = &model.TokenUsage{
|
||||
PromptTokens: msg.ResponseMeta.Usage.PromptTokens,
|
||||
PromptTokenDetails: model.PromptTokenDetails{
|
||||
CachedTokens: msg.ResponseMeta.Usage.PromptTokenDetails.CachedTokens,
|
||||
},
|
||||
CompletionTokens: msg.ResponseMeta.Usage.CompletionTokens,
|
||||
TotalTokens: msg.ResponseMeta.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
_ = callbacks.OnEnd(ctx, callbackOut)
|
||||
@ -112,7 +119,14 @@ func (q *UTChatModel) Stream(ctx context.Context, in []*schema.Message, _ ...mod
|
||||
}
|
||||
|
||||
if t.ResponseMeta != nil {
|
||||
callbackOut.TokenUsage = (*model.TokenUsage)(t.ResponseMeta.Usage)
|
||||
callbackOut.TokenUsage = &model.TokenUsage{
|
||||
PromptTokens: t.ResponseMeta.Usage.PromptTokens,
|
||||
PromptTokenDetails: model.PromptTokenDetails{
|
||||
CachedTokens: t.ResponseMeta.Usage.PromptTokenDetails.CachedTokens,
|
||||
},
|
||||
CompletionTokens: t.ResponseMeta.Usage.CompletionTokens,
|
||||
TotalTokens: t.ResponseMeta.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return callbackOut, nil
|
||||
|
||||
Reference in New Issue
Block a user