760 lines
23 KiB
Go
760 lines
23 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package workflow
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
|
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
|
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/maps"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
|
"github.com/coze-dev/coze-studio/backend/types/consts"
|
|
"github.com/coze-dev/coze-studio/backend/types/errno"
|
|
)
|
|
|
|
func (w *ApplicationService) CreateApplicationConversationDef(ctx context.Context, req *workflow.CreateProjectConversationDefRequest) (resp *workflow.CreateProjectConversationDefResponse, err error) {
|
|
defer func() {
|
|
if panicErr := recover(); panicErr != nil {
|
|
err = safego.NewPanicErr(panicErr, debug.Stack())
|
|
}
|
|
|
|
if err != nil {
|
|
err = vo.WrapIfNeeded(errno.ErrConversationOfAppOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
|
}
|
|
}()
|
|
|
|
var (
|
|
spaceID = mustParseInt64(req.GetSpaceID())
|
|
appID = mustParseInt64(req.GetProjectID())
|
|
userID = ctxutil.MustGetUIDFromCtx(ctx)
|
|
)
|
|
|
|
if err := checkUserSpace(ctx, userID, spaceID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
uniqueID, err := GetWorkflowDomainSVC().CreateDraftConversationTemplate(ctx, &vo.CreateConversationTemplateMeta{
|
|
AppID: appID,
|
|
SpaceID: spaceID,
|
|
Name: req.GetConversationName(),
|
|
UserID: userID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &workflow.CreateProjectConversationDefResponse{
|
|
UniqueID: strconv.FormatInt(uniqueID, 10),
|
|
SpaceID: req.GetSpaceID(),
|
|
}, err
|
|
}
|
|
|
|
func (w *ApplicationService) UpdateApplicationConversationDef(ctx context.Context, req *workflow.UpdateProjectConversationDefRequest) (resp *workflow.UpdateProjectConversationDefResponse, err error) {
|
|
defer func() {
|
|
if panicErr := recover(); panicErr != nil {
|
|
err = safego.NewPanicErr(panicErr, debug.Stack())
|
|
}
|
|
|
|
if err != nil {
|
|
err = vo.WrapIfNeeded(errno.ErrConversationOfAppOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
|
}
|
|
}()
|
|
var (
|
|
spaceID = mustParseInt64(req.GetSpaceID())
|
|
templateID = mustParseInt64(req.GetUniqueID())
|
|
appID = mustParseInt64(req.GetProjectID())
|
|
userID = ctxutil.MustGetUIDFromCtx(ctx)
|
|
)
|
|
|
|
if err := checkUserSpace(ctx, userID, spaceID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = GetWorkflowDomainSVC().UpdateDraftConversationTemplateName(ctx, appID, userID, templateID, req.GetConversationName())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &workflow.UpdateProjectConversationDefResponse{}, err
|
|
}
|
|
|
|
func (w *ApplicationService) DeleteApplicationConversationDef(ctx context.Context, req *workflow.DeleteProjectConversationDefRequest) (resp *workflow.DeleteProjectConversationDefResponse, err error) {
|
|
defer func() {
|
|
if panicErr := recover(); panicErr != nil {
|
|
err = safego.NewPanicErr(panicErr, debug.Stack())
|
|
}
|
|
|
|
if err != nil {
|
|
err = vo.WrapIfNeeded(errno.ErrConversationOfAppOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
|
}
|
|
}()
|
|
var (
|
|
appID = mustParseInt64(req.GetProjectID())
|
|
templateID = mustParseInt64(req.GetUniqueID())
|
|
)
|
|
if err := checkUserSpace(ctx, ctxutil.MustGetUIDFromCtx(ctx), mustParseInt64(req.GetSpaceID())); err != nil {
|
|
return nil, err
|
|
}
|
|
if req.GetCheckOnly() {
|
|
wfs, err := GetWorkflowDomainSVC().CheckWorkflowsToReplace(ctx, appID, templateID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp = &workflow.DeleteProjectConversationDefResponse{NeedReplace: make([]*workflow.Workflow, 0)}
|
|
for _, wf := range wfs {
|
|
resp.NeedReplace = append(resp.NeedReplace, &workflow.Workflow{
|
|
Name: wf.Name,
|
|
URL: wf.IconURL,
|
|
WorkflowID: strconv.FormatInt(wf.ID, 10),
|
|
})
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
wfID2ConversationName, err := maps.TransformKeyWithErrorCheck(req.GetReplace(), func(k1 string) (int64, error) {
|
|
return strconv.ParseInt(k1, 10, 64)
|
|
})
|
|
|
|
rowsAffected, err := GetWorkflowDomainSVC().DeleteDraftConversationTemplate(ctx, templateID, wfID2ConversationName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if rowsAffected > 0 {
|
|
return &workflow.DeleteProjectConversationDefResponse{
|
|
Success: true,
|
|
}, err
|
|
}
|
|
|
|
rowsAffected, err = GetWorkflowDomainSVC().DeleteDynamicConversation(ctx, vo.Draft, templateID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return nil, fmt.Errorf("delete conversation failed")
|
|
}
|
|
|
|
return &workflow.DeleteProjectConversationDefResponse{
|
|
Success: true,
|
|
}, nil
|
|
|
|
}
|
|
|
|
func (w *ApplicationService) ListApplicationConversationDef(ctx context.Context, req *workflow.ListProjectConversationRequest) (resp *workflow.ListProjectConversationResponse, err error) {
|
|
defer func() {
|
|
if panicErr := recover(); panicErr != nil {
|
|
err = safego.NewPanicErr(panicErr, debug.Stack())
|
|
}
|
|
|
|
if err != nil {
|
|
err = vo.WrapIfNeeded(errno.ErrConversationOfAppOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
|
}
|
|
}()
|
|
var connectorID int64
|
|
if len(req.GetConnectorID()) != 0 {
|
|
connectorID = mustParseInt64(req.GetConnectorID())
|
|
} else {
|
|
connectorID = consts.CozeConnectorID
|
|
}
|
|
var (
|
|
page = mustParseInt64(ternary.IFElse(req.GetCursor() == "", "0", req.GetCursor()))
|
|
size = req.GetLimit()
|
|
userID = ctxutil.MustGetUIDFromCtx(ctx)
|
|
spaceID = mustParseInt64(req.GetSpaceID())
|
|
appID = mustParseInt64(req.GetProjectID())
|
|
version = req.ProjectVersion
|
|
listConversationMeta = vo.ListConversationMeta{
|
|
APPID: appID,
|
|
UserID: userID,
|
|
ConnectorID: connectorID,
|
|
}
|
|
)
|
|
|
|
if err := checkUserSpace(ctx, userID, spaceID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
env := ternary.IFElse(req.GetCreateEnv() == workflow.CreateEnv_Draft, vo.Draft, vo.Online)
|
|
if req.GetCreateMethod() == workflow.CreateMethod_ManualCreate {
|
|
templates, err := GetWorkflowDomainSVC().ListConversationTemplate(ctx, env, &vo.ListConversationTemplatePolicy{
|
|
AppID: appID,
|
|
Page: &vo.Page{
|
|
Page: int32(page),
|
|
Size: int32(size),
|
|
},
|
|
NameLike: ternary.IFElse(len(req.GetNameLike()) == 0, nil, ptr.Of(req.GetNameLike())),
|
|
Version: version,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stsConversations, err := GetWorkflowDomainSVC().MGetStaticConversation(ctx, env, userID, connectorID, slices.Transform(templates, func(a *entity.ConversationTemplate) int64 {
|
|
return a.TemplateID
|
|
}))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
stsConversationMap := slices.ToMap(stsConversations, func(e *entity.StaticConversation) (int64, *entity.StaticConversation) {
|
|
return e.TemplateID, e
|
|
})
|
|
|
|
resp = &workflow.ListProjectConversationResponse{Data: make([]*workflow.ProjectConversation, 0)}
|
|
for _, tmpl := range templates {
|
|
conversationID := ""
|
|
if c, ok := stsConversationMap[tmpl.TemplateID]; ok {
|
|
conversationID = strconv.FormatInt(c.ConversationID, 10)
|
|
}
|
|
resp.Data = append(resp.Data, &workflow.ProjectConversation{
|
|
UniqueID: strconv.FormatInt(tmpl.TemplateID, 10),
|
|
ConversationName: tmpl.Name,
|
|
ConversationID: conversationID,
|
|
})
|
|
}
|
|
}
|
|
|
|
if req.GetCreateMethod() == workflow.CreateMethod_NodeCreate {
|
|
dyConversations, err := GetWorkflowDomainSVC().ListDynamicConversation(ctx, env, &vo.ListConversationPolicy{
|
|
ListConversationMeta: listConversationMeta,
|
|
Page: &vo.Page{
|
|
Page: int32(page),
|
|
Size: int32(size),
|
|
},
|
|
NameLike: ternary.IFElse(len(req.GetNameLike()) == 0, nil, ptr.Of(req.GetNameLike())),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp = &workflow.ListProjectConversationResponse{Data: make([]*workflow.ProjectConversation, 0, len(dyConversations))}
|
|
resp.Data = append(resp.Data, slices.Transform(dyConversations, func(a *entity.DynamicConversation) *workflow.ProjectConversation {
|
|
return &workflow.ProjectConversation{
|
|
UniqueID: strconv.FormatInt(a.ID, 10),
|
|
ConversationName: a.Name,
|
|
ConversationID: strconv.FormatInt(a.ConversationID, 10),
|
|
}
|
|
})...)
|
|
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workflow.ChatFlowRunRequest) (
|
|
_ *schema.StreamReader[[]*workflow.ChatFlowRunResponse], err error) {
|
|
defer func() {
|
|
if panicErr := recover(); panicErr != nil {
|
|
err = safego.NewPanicErr(panicErr, debug.Stack())
|
|
}
|
|
|
|
if err != nil {
|
|
err = vo.WrapIfNeeded(errno.ErrChatFlowRoleOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
|
}
|
|
}()
|
|
|
|
if len(req.GetAdditionalMessages()) == 0 {
|
|
return nil, fmt.Errorf("additional_messages is requird")
|
|
}
|
|
|
|
messages := req.GetAdditionalMessages()
|
|
|
|
lastUserMessage := messages[len(req.GetAdditionalMessages())-1]
|
|
if lastUserMessage.Role != "user" {
|
|
return nil, errors.New("the role of the last day message must be user")
|
|
}
|
|
|
|
var parameters = make(map[string]any)
|
|
if len(req.GetParameters()) > 0 {
|
|
err := sonic.UnmarshalString(req.GetParameters(), parameters)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
var (
|
|
isDebug = req.GetExecuteMode() == "DEBUG"
|
|
appID, agentID *int64
|
|
resolveAppID int64
|
|
connectorID, userID, conversationID int64
|
|
version string
|
|
locator vo.Locator
|
|
)
|
|
if req.IsSetAppID() {
|
|
appID = ptr.Of(mustParseInt64(req.GetAppID()))
|
|
resolveAppID = mustParseInt64(req.GetAppID())
|
|
}
|
|
if req.IsSetBotID() {
|
|
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
|
|
resolveAppID = mustParseInt64(req.GetBotID())
|
|
}
|
|
|
|
if appID != nil && agentID != nil {
|
|
return nil, errors.New("project_id and bot_id cannot be set at the same time")
|
|
}
|
|
|
|
if isDebug {
|
|
userID = ctxutil.MustGetUIDFromCtx(ctx)
|
|
connectorID = mustParseInt64(req.GetConnectorID())
|
|
locator = vo.FromDraft
|
|
|
|
} else {
|
|
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
|
|
userID = apiKeyInfo.UserID
|
|
connectorID = apiKeyInfo.ConnectorID
|
|
meta, err := GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
|
|
ID: mustParseInt64(req.GetWorkflowID()),
|
|
MetaOnly: true,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if meta.LatestPublishedVersion == nil {
|
|
return nil, vo.NewError(errno.ErrWorkflowNotPublished)
|
|
}
|
|
if req.IsSetVersion() {
|
|
version = req.GetVersion()
|
|
locator = vo.FromSpecificVersion
|
|
} else {
|
|
version = meta.GetLatestVersion()
|
|
locator = vo.FromLatestVersion
|
|
}
|
|
}
|
|
|
|
if req.IsSetConversationID() {
|
|
conversationID = mustParseInt64(req.GetConversationID())
|
|
} else {
|
|
conversationName, ok := parameters["CONVERSATION_NAME"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("conversation name is requried")
|
|
}
|
|
cID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conversationID = cID
|
|
}
|
|
|
|
roundID, err := w.IDGenerator.GenID(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
historyMessages, err := w.makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, messages[:len(req.GetAdditionalMessages())-1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
messageClient := crossmessage.DefaultSVC()
|
|
if len(historyMessages) > 0 {
|
|
g := taskgroup.NewTaskGroup(ctx, len(historyMessages))
|
|
for _, hm := range historyMessages {
|
|
hMsg := hm
|
|
g.Go(func() error {
|
|
_, err := messageClient.Create(ctx, hMsg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
err = g.Wait()
|
|
if err != nil {
|
|
logs.CtxWarnf(ctx, "create history message failed, err=%v", err)
|
|
}
|
|
}
|
|
|
|
userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, message.MessageTypeQuestion, lastUserMessage)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = messageClient.Create(ctx, userMessage)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
exeCfg := vo.ExecuteConfig{
|
|
ID: mustParseInt64(req.GetWorkflowID()),
|
|
From: locator,
|
|
Version: version,
|
|
Operator: userID,
|
|
Mode: ternary.IFElse(isDebug, vo.ExecuteModeDebug, vo.ExecuteModeRelease),
|
|
AppID: appID,
|
|
AgentID: agentID,
|
|
ConnectorID: connectorID,
|
|
ConnectorUID: strconv.FormatInt(userID, 10),
|
|
TaskType: vo.TaskTypeForeground,
|
|
SyncPattern: vo.SyncPatternStream,
|
|
InputFailFast: true,
|
|
BizType: vo.BizTypeWorkflow,
|
|
ConversationID: ptr.Of(conversationID),
|
|
RoundID: ptr.Of(roundID),
|
|
EnterMessage: lastUserMessage,
|
|
Cancellable: isDebug == true,
|
|
}
|
|
|
|
parameters["USER_INPUT"], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sr, err := GetWorkflowDomainSVC().StreamExecute(ctx, exeCfg, parameters)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return schema.StreamReaderWithConvert(sr, convertToChatFlowRunResponseList(ctx, resolveAppID, conversationID, roundID, mustParseInt64(req.GetWorkflowID()))), nil
|
|
|
|
}
|
|
|
|
func convertToChatFlowRunResponseList(ctx context.Context, appID int64, conversationID, roundID int64, workflowID int64) func(msg *entity.Message) (responses []*workflow.ChatFlowRunResponse, err error) {
|
|
var (
|
|
spaceID int64
|
|
executeID int64
|
|
|
|
hasFirstMessage = false
|
|
messageOutput string
|
|
messageID int64
|
|
outputCount int32
|
|
inputCount int32
|
|
)
|
|
var getOrUpdateMessage = func(msg string, role schema.RoleType) error {
|
|
entityMessage := &message.Message{
|
|
AgentID: appID,
|
|
RunID: roundID,
|
|
Content: msg,
|
|
ConversationID: conversationID,
|
|
ContentType: message.ContentTypeText,
|
|
Role: role,
|
|
MessageType: message.MessageTypeAnswer,
|
|
}
|
|
if hasFirstMessage {
|
|
entityMessage.ID = messageID
|
|
_, err := crossmessage.DefaultSVC().Edit(ctx, entityMessage)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
m, err := crossmessage.DefaultSVC().Create(ctx, entityMessage)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
messageID = m.ID
|
|
hasFirstMessage = true
|
|
}
|
|
return nil
|
|
|
|
}
|
|
|
|
return func(msg *entity.Message) (responses []*workflow.ChatFlowRunResponse, err error) {
|
|
if msg.StateMessage != nil {
|
|
if executeID > 0 && executeID != msg.StateMessage.ExecuteID {
|
|
return nil, schema.ErrNoValue
|
|
}
|
|
switch msg.StateMessage.Status {
|
|
case entity.WorkflowSuccess:
|
|
chatDoneEvent := &vo.ChatFlowDetail{
|
|
ID: strconv.FormatInt(roundID, 10),
|
|
ConversationID: strconv.FormatInt(conversationID, 10),
|
|
BotID: strconv.FormatInt(appID, 10),
|
|
Status: vo.Completed,
|
|
ExecuteID: strconv.FormatInt(executeID, 10),
|
|
Usage: &vo.Usage{
|
|
InputTokens: ptr.Of(inputCount),
|
|
OutputTokens: ptr.Of(outputCount),
|
|
TokenCount: ptr.Of(outputCount + inputCount),
|
|
},
|
|
}
|
|
data, err := sonic.MarshalString(chatDoneEvent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
doneData, err := sonic.MarshalString(map[string]interface{}{
|
|
"debug_url": fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return []*workflow.ChatFlowRunResponse{
|
|
{
|
|
Event: string(vo.ChatFlowCompleted),
|
|
Data: data,
|
|
},
|
|
{
|
|
Event: string(vo.ChatFlowDone),
|
|
Data: doneData,
|
|
},
|
|
}, err
|
|
case entity.WorkflowFailed:
|
|
var wfe vo.WorkflowError
|
|
if !errors.As(msg.StateMessage.LastError, &wfe) {
|
|
panic("stream run last error is not a WorkflowError")
|
|
}
|
|
chatFailedEvent := &vo.ErrorDetail{
|
|
Code: strconv.Itoa(int(wfe.Code())),
|
|
Msg: wfe.Msg(),
|
|
DebugUrl: wfe.DebugURL(),
|
|
}
|
|
data, err := sonic.MarshalString(chatFailedEvent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return []*workflow.ChatFlowRunResponse{
|
|
{
|
|
Event: string(vo.ChatFlowError),
|
|
Data: data,
|
|
},
|
|
}, err
|
|
|
|
case entity.WorkflowCancel:
|
|
// do nothing
|
|
case entity.WorkflowInterrupted:
|
|
// interrupted
|
|
fmt.Println("workflow interrupted")
|
|
case entity.WorkflowRunning:
|
|
executeID = msg.StateMessage.ExecuteID
|
|
spaceID = msg.StateMessage.SpaceID
|
|
|
|
responses = make([]*workflow.ChatFlowRunResponse, 0)
|
|
chatEvent := &vo.ChatFlowDetail{
|
|
ID: strconv.FormatInt(roundID, 10),
|
|
ConversationID: strconv.FormatInt(conversationID, 10),
|
|
Status: vo.Created,
|
|
ExecuteID: strconv.FormatInt(executeID, 10),
|
|
}
|
|
data, err := sonic.MarshalString(chatEvent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
responses = append(responses, &workflow.ChatFlowRunResponse{
|
|
Event: string(vo.ChatFlowCreated),
|
|
Data: data,
|
|
})
|
|
|
|
chatEvent.Status = vo.InProgress
|
|
data, err = sonic.MarshalString(chatEvent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
responses = append(responses, &workflow.ChatFlowRunResponse{
|
|
Event: string(vo.ChatFlowInProgress),
|
|
Data: data,
|
|
})
|
|
return responses, nil
|
|
|
|
default:
|
|
return nil, schema.ErrNoValue
|
|
}
|
|
}
|
|
if msg.DataMessage != nil {
|
|
if msg.Type != entity.Answer {
|
|
return nil, schema.ErrNoValue
|
|
}
|
|
// stream run will skip all messages from workflow tools
|
|
if executeID > 0 && executeID != msg.DataMessage.ExecuteID {
|
|
return nil, schema.ErrNoValue
|
|
}
|
|
|
|
messageOutput += msg.Content
|
|
dataMessage := msg.DataMessage
|
|
if dataMessage.Usage != nil {
|
|
inputCount += int32(msg.DataMessage.Usage.InputTokens)
|
|
outputCount += int32(msg.DataMessage.Usage.OutputTokens)
|
|
}
|
|
|
|
err = getOrUpdateMessage(messageOutput, dataMessage.Role)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
messageEvent := &vo.MessageDetail{
|
|
ID: strconv.FormatInt(messageID, 10),
|
|
ChatID: strconv.FormatInt(roundID, 10),
|
|
ConversationID: strconv.FormatInt(conversationID, 10),
|
|
BotID: strconv.FormatInt(appID, 10),
|
|
Role: string(dataMessage.Role),
|
|
Type: string(dataMessage.Type),
|
|
Content: msg.Content,
|
|
ContentType: string(message.ContentTypeText),
|
|
}
|
|
|
|
data, err := sonic.MarshalString(messageEvent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return []*workflow.ChatFlowRunResponse{
|
|
{
|
|
Event: ternary.IFElse(msg.Last, string(vo.ChatFlowMessageCompleted), string(vo.ChatFlowMessageDelta)),
|
|
Data: data,
|
|
},
|
|
}, nil
|
|
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
func (w *ApplicationService) makeChatFlowUserInput(ctx context.Context, message *workflow.EnterMessage) (string, error) {
|
|
type content struct {
|
|
Type string `json:"type,omitempty"`
|
|
FileID *string `json:"file_id"`
|
|
Text *string `json:"text"`
|
|
}
|
|
if message.ContentType == "text" {
|
|
return message.Content, nil
|
|
} else if message.ContentType == "object_string" {
|
|
contents := make([]content, 0)
|
|
err := sonic.UnmarshalString(message.Content, &contents)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
texts := make([]string, 0)
|
|
urls := make([]string, 0)
|
|
for _, ct := range contents {
|
|
if ct.Text != nil && len(*ct.Text) > 0 {
|
|
texts = append(texts, *ct.Text)
|
|
}
|
|
if ct.FileID != nil && len(*ct.FileID) > 0 {
|
|
u, err := w.ImageX.GetResourceURL(ctx, *ct.FileID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
urls = append(urls, u.URL)
|
|
}
|
|
}
|
|
|
|
return strings.Join(texts, ",") + strings.Join(urls, ","), nil
|
|
|
|
} else {
|
|
return "", fmt.Errorf("invalid message ccontent type %v", message.ContentType)
|
|
}
|
|
|
|
}
|
|
func (w *ApplicationService) makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID int64, userID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
|
|
|
|
var (
|
|
rID int64
|
|
err error
|
|
userRole = "user"
|
|
assistantRole = "assistant"
|
|
)
|
|
|
|
historyMessages := make([]*message.Message, 0, len(messages))
|
|
|
|
for _, msg := range messages {
|
|
if msg.Role == userRole {
|
|
rID, err = w.IDGenerator.GenID(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else if msg.Role == assistantRole && rID == 0 {
|
|
continue
|
|
} else {
|
|
return nil, fmt.Errorf("invalid role type %v", msg.Role)
|
|
}
|
|
messageType := ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer)
|
|
|
|
m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, messageType, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
historyMessages = append(historyMessages, m)
|
|
|
|
}
|
|
return historyMessages, nil
|
|
}
|
|
|
|
func toConversationMessage(_ context.Context, appID int64, cid int64, userID int64, roundID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
|
|
type content struct {
|
|
Type string `json:"type"`
|
|
FileID *string `json:"file_id"`
|
|
Text *string `json:"text"`
|
|
}
|
|
if msg.ContentType == "text" {
|
|
return &message.Message{
|
|
Role: schema.User,
|
|
ConversationID: cid,
|
|
AgentID: appID,
|
|
RunID: roundID,
|
|
Content: msg.Content,
|
|
ContentType: message.ContentTypeText,
|
|
MessageType: messageType,
|
|
UserID: strconv.FormatInt(userID, 10),
|
|
}, nil
|
|
|
|
} else if msg.ContentType == "object_string" {
|
|
contents := make([]*content, 0)
|
|
err := sonic.UnmarshalString(msg.Content, &contents)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
m := &message.Message{
|
|
Role: schema.User,
|
|
MessageType: messageType,
|
|
ConversationID: cid,
|
|
UserID: strconv.FormatInt(userID, 10),
|
|
RunID: roundID,
|
|
Content: msg.Content,
|
|
ContentType: message.ContentTypeMix,
|
|
MultiContent: make([]*message.InputMetaData, 0, len(contents)),
|
|
}
|
|
|
|
for _, ct := range contents {
|
|
if ct.Text != nil {
|
|
m.MultiContent = append(m.MultiContent, &message.InputMetaData{
|
|
Type: message.InputTypeText,
|
|
Text: *ct.Text,
|
|
})
|
|
} else if ct.FileID != nil {
|
|
m.MultiContent = append(m.MultiContent, &message.InputMetaData{
|
|
Type: message.InputType(ct.Type),
|
|
FileData: []*message.FileData{
|
|
{Url: *ct.FileID},
|
|
},
|
|
})
|
|
} else {
|
|
return nil, fmt.Errorf("invalid input type %v", ct.Type)
|
|
}
|
|
}
|
|
return m, nil
|
|
} else {
|
|
return nil, fmt.Errorf("invalid message content type %v", msg.ContentType)
|
|
}
|
|
}
|