feat: Handle section_id in chatflow interfaces and nodes
This commit is contained in:
@ -497,6 +497,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
appID, agentID *int64
|
||||
resolveAppID int64
|
||||
conversationID int64
|
||||
sectionID int64
|
||||
version string
|
||||
locator vo.Locator
|
||||
apiKeyInfo = ctxutil.GetApiAuthFromCtx(ctx)
|
||||
@ -541,16 +542,22 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
|
||||
if req.IsSetConversationID() {
|
||||
conversationID = mustParseInt64(req.GetConversationID())
|
||||
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sectionID = cInfo.SectionID
|
||||
} else {
|
||||
conversationName, ok := parameters["CONVERSATION_NAME"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("conversation name is requried")
|
||||
}
|
||||
cID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName)
|
||||
cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conversationID = cID
|
||||
sectionID = sID
|
||||
}
|
||||
|
||||
roundID, err := w.IDGenerator.GenID(ctx)
|
||||
@ -558,7 +565,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
return nil, vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
|
||||
userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, message.MessageTypeQuestion, lastUserMessage)
|
||||
userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -594,6 +601,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
ConversationID: ptr.Of(conversationID),
|
||||
RoundID: ptr.Of(roundID),
|
||||
InitRoundID: ptr.Of(roundID),
|
||||
SectionID: ptr.Of(sectionID),
|
||||
})
|
||||
if err != nil {
|
||||
uErr := unbinding()
|
||||
@ -608,12 +616,13 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
conversationID: conversationID,
|
||||
roundID: roundID,
|
||||
workflowID: mustParseInt64(req.GetWorkflowID()),
|
||||
sectionID: sectionID,
|
||||
unbinding: unbinding,
|
||||
})), nil
|
||||
|
||||
}
|
||||
|
||||
historyMessages, err := w.makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, messages[:len(req.GetAdditionalMessages())-1])
|
||||
historyMessages, err := w.makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, sectionID, messages[:len(req.GetAdditionalMessages())-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -658,6 +667,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
ConversationID: ptr.Of(conversationID),
|
||||
RoundID: ptr.Of(roundID),
|
||||
InitRoundID: ptr.Of(roundID),
|
||||
SectionID: ptr.Of(sectionID),
|
||||
UserMessage: userSchemaMessage,
|
||||
Cancellable: isDebug,
|
||||
}
|
||||
@ -677,6 +687,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
|
||||
conversationID: conversationID,
|
||||
roundID: roundID,
|
||||
workflowID: mustParseInt64(req.GetWorkflowID()),
|
||||
sectionID: sectionID,
|
||||
unbinding: unbinding,
|
||||
})), nil
|
||||
|
||||
@ -718,7 +729,7 @@ func (w *ApplicationService) makeChatFlowUserInput(ctx context.Context, message
|
||||
}
|
||||
|
||||
}
|
||||
func (w *ApplicationService) makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID int64, userID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
|
||||
func (w *ApplicationService) makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, userID, sectionID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
|
||||
|
||||
var (
|
||||
rID int64
|
||||
@ -739,7 +750,7 @@ func (w *ApplicationService) makeChatFlowHistoryMessages(ctx context.Context, ap
|
||||
return nil, fmt.Errorf("invalid role type %v", msg.Role)
|
||||
}
|
||||
|
||||
m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
|
||||
m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -776,7 +787,7 @@ func (w *ApplicationService) OpenAPICreateConversation(ctx context.Context, req
|
||||
if !req.GetGetOrCreate() {
|
||||
cID, err = GetWorkflowDomainSVC().UpdateConversation(ctx, env, appID, req.GetConnectorId(), userID, req.GetConversationMame())
|
||||
} else {
|
||||
cID, err = GetWorkflowDomainSVC().GetOrCreateConversation(ctx, env, appID, req.GetConnectorId(), userID, req.GetConversationMame())
|
||||
cID, _, err = GetWorkflowDomainSVC().GetOrCreateConversation(ctx, env, appID, req.GetConnectorId(), userID, req.GetConversationMame())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -795,7 +806,7 @@ func (w *ApplicationService) OpenAPICreateConversation(ctx context.Context, req
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toConversationMessage(_ context.Context, appID int64, cid int64, userID int64, roundID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
|
||||
func toConversationMessage(_ context.Context, appID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
|
||||
type content struct {
|
||||
Type string `json:"type"`
|
||||
FileID *string `json:"file_id"`
|
||||
@ -811,6 +822,7 @@ func toConversationMessage(_ context.Context, appID int64, cid int64, userID int
|
||||
ContentType: message.ContentTypeText,
|
||||
MessageType: messageType,
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
SectionID: sectionID,
|
||||
}, nil
|
||||
|
||||
} else if msg.ContentType == "object_string" {
|
||||
@ -829,6 +841,7 @@ func toConversationMessage(_ context.Context, appID int64, cid int64, userID int
|
||||
Content: msg.Content,
|
||||
ContentType: message.ContentTypeMix,
|
||||
MultiContent: make([]*message.InputMetaData, 0, len(contents)),
|
||||
SectionID: sectionID,
|
||||
}
|
||||
|
||||
for _, ct := range contents {
|
||||
@ -935,6 +948,7 @@ type convertToChatFlowInfo struct {
|
||||
conversationID int64
|
||||
roundID int64
|
||||
workflowID int64
|
||||
sectionID int64
|
||||
unbinding func() error
|
||||
}
|
||||
|
||||
@ -960,6 +974,7 @@ func convertToChatFlowRunResponseList(ctx context.Context, info convertToChatFlo
|
||||
conversationID = info.conversationID
|
||||
roundID = info.roundID
|
||||
workflowID = info.workflowID
|
||||
sectionID = info.sectionID
|
||||
unbinding = info.unbinding
|
||||
|
||||
spaceID int64
|
||||
@ -975,6 +990,7 @@ func convertToChatFlowRunResponseList(ctx context.Context, info convertToChatFlo
|
||||
entityMessage := &message.Message{
|
||||
AgentID: appID,
|
||||
RunID: roundID,
|
||||
SectionID: sectionID,
|
||||
Content: msg,
|
||||
ConversationID: conversationID,
|
||||
ContentType: contentType,
|
||||
|
||||
@ -41,7 +41,7 @@ func NewConversationRepository() *ConversationRepository {
|
||||
return &ConversationRepository{}
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) CreateConversation(ctx context.Context, req *conversation.CreateConversationRequest) (int64, error) {
|
||||
func (c *ConversationRepository) CreateConversation(ctx context.Context, req *conversation.CreateConversationRequest) (int64, int64, error) {
|
||||
ret, err := crossconversation.DefaultSVC().Create(ctx, &entity.CreateMeta{
|
||||
AgentID: req.AppID,
|
||||
UserID: req.UserID,
|
||||
@ -49,10 +49,14 @@ func (c *ConversationRepository) CreateConversation(ctx context.Context, req *co
|
||||
Scene: common.Scene_SceneWorkflow,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return ret.ID, nil
|
||||
return ret.ID, ret.SectionID, nil
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
|
||||
return crossconversation.DefaultSVC().GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) CreateMessage(ctx context.Context, req *conversation.CreateMessageRequest) (int64, error) {
|
||||
@ -64,8 +68,9 @@ func (c *ConversationRepository) CreateMessage(ctx context.Context, req *convers
|
||||
UserID: strconv.FormatInt(req.UserID, 10),
|
||||
AgentID: req.AppID,
|
||||
RunID: req.RunID,
|
||||
SectionID: req.SectionID,
|
||||
}
|
||||
if msg.Role == "user" {
|
||||
if msg.Role == schema.User {
|
||||
msg.MessageType = message.MessageTypeQuestion
|
||||
} else {
|
||||
msg.MessageType = message.MessageTypeAnswer
|
||||
@ -88,13 +93,12 @@ func (c *ConversationRepository) MessageList(ctx context.Context, req *conversat
|
||||
}
|
||||
if req.BeforeID != nil {
|
||||
lm.Cursor, _ = strconv.ParseInt(*req.BeforeID, 10, 64)
|
||||
lm.Direction = msgentity.ScrollPageDirectionPrev
|
||||
lm.Direction = msgentity.ScrollPageDirectionNext
|
||||
}
|
||||
if req.AfterID != nil {
|
||||
lm.Cursor, _ = strconv.ParseInt(*req.AfterID, 10, 64)
|
||||
lm.Direction = msgentity.ScrollPageDirectionNext
|
||||
lm.Direction = msgentity.ScrollPageDirectionPrev
|
||||
}
|
||||
lm.Direction = msgentity.ScrollPageDirectionNext
|
||||
lr, err := crossmessage.DefaultSVC().List(ctx, lm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -119,14 +123,14 @@ func (c *ConversationRepository) MessageList(ctx context.Context, req *conversat
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *ConversationRepository) ClearConversationHistory(ctx context.Context, req *conversation.ClearConversationHistoryReq) error {
|
||||
_, err := crossconversation.DefaultSVC().NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
|
||||
func (c *ConversationRepository) ClearConversationHistory(ctx context.Context, req *conversation.ClearConversationHistoryReq) (int64, error) {
|
||||
resp, err := crossconversation.DefaultSVC().NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
|
||||
ID: req.ConversationID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
return nil
|
||||
return resp.SectionID, nil
|
||||
|
||||
}
|
||||
|
||||
@ -153,6 +157,7 @@ func (c *ConversationRepository) GetLatestRunIDs(ctx context.Context, req *conve
|
||||
ConversationID: req.ConversationID,
|
||||
AgentID: req.AppID,
|
||||
Limit: int32(req.Rounds),
|
||||
SectionID: req.SectionID,
|
||||
}
|
||||
|
||||
if req.InitRunID != nil {
|
||||
|
||||
@ -67,7 +67,7 @@ type ConversationService interface {
|
||||
ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error)
|
||||
ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error
|
||||
InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error
|
||||
GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error)
|
||||
GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error)
|
||||
UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error)
|
||||
}
|
||||
|
||||
@ -112,7 +112,7 @@ type ToolFromWorkflow interface {
|
||||
GetWorkflow() *entity.Workflow
|
||||
}
|
||||
|
||||
type ConversationIDGenerator func(ctx context.Context, appID int64, userID, connectorID int64) (int64, error)
|
||||
type ConversationIDGenerator func(ctx context.Context, appID int64, userID, connectorID int64) (int64, int64, error)
|
||||
|
||||
type ConversationRepository interface {
|
||||
CreateDraftConversationTemplate(ctx context.Context, template *vo.CreateConversationTemplateMeta) (int64, error)
|
||||
@ -122,8 +122,8 @@ type ConversationRepository interface {
|
||||
DeleteDynamicConversation(ctx context.Context, env vo.Env, id int64) (int64, error)
|
||||
ListConversationTemplate(ctx context.Context, env vo.Env, policy *vo.ListConversationTemplatePolicy) ([]*entity.ConversationTemplate, error)
|
||||
MGetStaticConversation(ctx context.Context, env vo.Env, userID, connectorID int64, templateIDs []int64) ([]*entity.StaticConversation, error)
|
||||
GetOrCreateStaticConversation(ctx context.Context, env vo.Env, idGen ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, bool, error)
|
||||
GetOrCreateDynamicConversation(ctx context.Context, env vo.Env, idGen ConversationIDGenerator, meta *vo.CreateDynamicConversation) (int64, bool, error)
|
||||
GetOrCreateStaticConversation(ctx context.Context, env vo.Env, idGen ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, int64, bool, error)
|
||||
GetOrCreateDynamicConversation(ctx context.Context, env vo.Env, idGen ConversationIDGenerator, meta *vo.CreateDynamicConversation) (int64, int64, bool, error)
|
||||
GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error)
|
||||
GetStaticConversationByTemplateID(ctx context.Context, env vo.Env, userID, connectorID, templateID int64) (*entity.StaticConversation, bool, error)
|
||||
ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error)
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
)
|
||||
|
||||
type CreateConversationRequest struct {
|
||||
@ -37,6 +38,7 @@ type CreateMessageRequest struct {
|
||||
UserID int64
|
||||
AppID int64
|
||||
RunID int64
|
||||
SectionID int64
|
||||
}
|
||||
|
||||
type MessageListRequest struct {
|
||||
@ -96,6 +98,7 @@ type GetLatestRunIDsRequest struct {
|
||||
UserID int64
|
||||
AppID int64
|
||||
Rounds int64
|
||||
SectionID int64
|
||||
InitRunID *int64
|
||||
}
|
||||
type ClearConversationHistoryReq struct {
|
||||
@ -124,12 +127,13 @@ type GetMessagesByRunIDsResponse struct {
|
||||
|
||||
//go:generate mockgen -destination conversationmock/conversation_mock.go --package conversationmock -source conversation.go
|
||||
type ConversationManager interface {
|
||||
CreateConversation(ctx context.Context, req *CreateConversationRequest) (int64, error)
|
||||
CreateConversation(ctx context.Context, req *CreateConversationRequest) (int64, int64, error)
|
||||
CreateMessage(ctx context.Context, req *CreateMessageRequest) (int64, error)
|
||||
MessageList(ctx context.Context, req *MessageListRequest) (*MessageListResponse, error)
|
||||
GetLatestRunIDs(ctx context.Context, req *GetLatestRunIDsRequest) ([]int64, error)
|
||||
GetMessagesByRunIDs(ctx context.Context, req *GetMessagesByRunIDsRequest) (*GetMessagesByRunIDsResponse, error)
|
||||
ClearConversationHistory(ctx context.Context, req *ClearConversationHistoryReq) error
|
||||
ClearConversationHistory(ctx context.Context, req *ClearConversationHistoryReq) (int64, error)
|
||||
DeleteMessage(ctx context.Context, req *DeleteMessageRequest) error
|
||||
EditMessage(ctx context.Context, req *EditMessageRequest) error
|
||||
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
|
||||
}
|
||||
|
||||
@ -44,6 +44,7 @@ type ExecuteConfig struct {
|
||||
ConversationID *int64 // if workflow is chat flow, conversation id is required
|
||||
UserMessage *schema.Message
|
||||
ConversationHistory []*conversation.Message
|
||||
SectionID *int64
|
||||
}
|
||||
|
||||
type ExecuteMode string
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
@ -123,12 +124,15 @@ func (c *ClearConversationHistory) Invoke(ctx context.Context, in map[string]any
|
||||
}, nil
|
||||
}
|
||||
|
||||
err = c.Manager.ClearConversationHistory(ctx, &conversation.ClearConversationHistoryReq{
|
||||
sectionID, err := c.Manager.ClearConversationHistory(ctx, &conversation.ClearConversationHistoryReq{
|
||||
ConversationID: conversationID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrConversationNodesNotAvailable, err)
|
||||
}
|
||||
if execCtx.ExeCfg.SectionID != nil {
|
||||
atomic.StoreInt64(execCtx.ExeCfg.SectionID, sectionID)
|
||||
}
|
||||
return map[string]any{
|
||||
"isSuccess": true,
|
||||
}, nil
|
||||
|
||||
@ -75,7 +75,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
|
||||
version = execCtx.ExeCfg.Version
|
||||
connectorID = execCtx.ExeCfg.ConnectorID
|
||||
userID = execCtx.ExeCfg.Operator
|
||||
conversationIDGenerator = workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (int64, error) {
|
||||
conversationIDGenerator = workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (int64, int64, error) {
|
||||
return c.Manager.CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
AppID: appID,
|
||||
UserID: userID,
|
||||
@ -106,7 +106,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
|
||||
}
|
||||
|
||||
if existed {
|
||||
cID, existed, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
cID, _, existed, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
AppID: ptr.From(appID),
|
||||
TemplateID: template.TemplateID,
|
||||
UserID: userID,
|
||||
@ -122,7 +122,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
cID, existed, err := workflow.GetRepository().GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
cID, _, existed, err := workflow.GetRepository().GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
AppID: ptr.From(appID),
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
|
||||
@ -76,7 +76,7 @@ func (c *CreateMessage) getConversationIDByName(ctx context.Context, env vo.Env,
|
||||
return 0, vo.WrapError(errno.ErrConversationNodeInvalidOperation, err)
|
||||
}
|
||||
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (int64, error) {
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (int64, int64, error) {
|
||||
return c.Creator.CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
AppID: appID,
|
||||
UserID: userID,
|
||||
@ -86,7 +86,7 @@ func (c *CreateMessage) getConversationIDByName(ctx context.Context, env vo.Env,
|
||||
|
||||
var conversationID int64
|
||||
if isExist {
|
||||
cID, _, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
cID, _, _, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
AppID: ptr.From(appID),
|
||||
TemplateID: template.TemplateID,
|
||||
UserID: userID,
|
||||
@ -229,6 +229,21 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
}
|
||||
}
|
||||
|
||||
var sectionID int64
|
||||
if isCurrentConversation {
|
||||
if execCtx.ExeCfg.SectionID != nil {
|
||||
sectionID = *execCtx.ExeCfg.SectionID
|
||||
} else {
|
||||
return nil, vo.WrapError(errno.ErrInvalidParameter, errors.New("section id is required"))
|
||||
}
|
||||
} else {
|
||||
cInfo, err := c.Creator.GetByID(ctx, conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sectionID = cInfo.SectionID
|
||||
}
|
||||
|
||||
mID, err := c.Creator.CreateMessage(ctx, &conversation.CreateMessageRequest{
|
||||
ConversationID: conversationID,
|
||||
Role: role,
|
||||
@ -237,6 +252,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
UserID: userID,
|
||||
AppID: resolvedAppID,
|
||||
RunID: runID,
|
||||
SectionID: sectionID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create message: %w", err)
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gorm"
|
||||
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
@ -416,17 +417,17 @@ func (r *RepositoryImpl) listOnlineDynamicConversation(ctx context.Context, poli
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetOrCreateStaticConversation(ctx context.Context, env vo.Env, idGen workflow.ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, bool, error) {
|
||||
func (r *RepositoryImpl) GetOrCreateStaticConversation(ctx context.Context, env vo.Env, idGen workflow.ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, int64, bool, error) {
|
||||
if env == vo.Draft {
|
||||
return r.getOrCreateDraftStaticConversation(ctx, idGen, meta)
|
||||
} else if env == vo.Online {
|
||||
return r.getOrCreateOnlineStaticConversation(ctx, idGen, meta)
|
||||
} else {
|
||||
return 0, false, fmt.Errorf("unknown env %v", env)
|
||||
return 0, 0, false, fmt.Errorf("unknown env %v", env)
|
||||
}
|
||||
|
||||
}
|
||||
func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env vo.Env, idGen workflow.ConversationIDGenerator, meta *vo.CreateDynamicConversation) (int64, bool, error) {
|
||||
func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env vo.Env, idGen workflow.ConversationIDGenerator, meta *vo.CreateDynamicConversation) (int64, int64, bool, error) {
|
||||
if env == vo.Draft {
|
||||
|
||||
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
|
||||
@ -437,21 +438,25 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
appDynamicConversationDraft.Name.Eq(meta.Name),
|
||||
).First()
|
||||
if err == nil {
|
||||
return ret.ConversationID, true, nil
|
||||
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, ret.ConversationID)
|
||||
if err != nil {
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
return ret.ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
cID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
cID, sID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
return 0, 0, false, err
|
||||
}
|
||||
|
||||
id, err := r.GenID(ctx)
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
|
||||
err = r.query.AppDynamicConversationDraft.WithContext(ctx).Create(&model.AppDynamicConversationDraft{
|
||||
@ -463,10 +468,10 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
ConversationID: cID,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
return cID, false, nil
|
||||
return cID, sID, false, nil
|
||||
|
||||
} else if env == vo.Online {
|
||||
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
|
||||
@ -477,19 +482,23 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
appDynamicConversationOnline.Name.Eq(meta.Name),
|
||||
).First()
|
||||
if err == nil {
|
||||
return ret.ConversationID, true, nil
|
||||
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, ret.ConversationID)
|
||||
if err != nil {
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
return ret.ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
cID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
cID, sID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
return 0, 0, false, err
|
||||
}
|
||||
id, err := r.GenID(ctx)
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
|
||||
err = r.query.AppDynamicConversationOnline.WithContext(ctx).Create(&model.AppDynamicConversationOnline{
|
||||
@ -501,13 +510,13 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
|
||||
ConversationID: cID,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
return cID, false, nil
|
||||
return cID, sID, false, nil
|
||||
|
||||
} else {
|
||||
return 0, false, fmt.Errorf("unknown env %v", env)
|
||||
return 0, 0, false, fmt.Errorf("unknown env %v", env)
|
||||
}
|
||||
|
||||
}
|
||||
@ -554,24 +563,28 @@ func (r *RepositoryImpl) GetStaticConversationByTemplateID(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context, idGen workflow.ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, bool, error) {
|
||||
func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context, idGen workflow.ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, int64, bool, error) {
|
||||
cs, err := r.mGetDraftStaticConversation(ctx, meta.UserID, meta.ConnectorID, []int64{meta.TemplateID})
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
if len(cs) > 0 {
|
||||
return cs[0].ConversationID, true, nil
|
||||
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, cs[0].ConversationID)
|
||||
if err != nil {
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
return cs[0].ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
|
||||
conversationID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
conversationID, sectionID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
return 0, 0, false, err
|
||||
}
|
||||
|
||||
id, err := r.GenID(ctx)
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
object := &model.AppStaticConversationDraft{
|
||||
ID: id,
|
||||
@ -582,30 +595,34 @@ func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context,
|
||||
}
|
||||
err = r.query.AppStaticConversationDraft.WithContext(ctx).Create(object)
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
return conversationID, false, nil
|
||||
return conversationID, sectionID, false, nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context, idGen workflow.ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, bool, error) {
|
||||
func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context, idGen workflow.ConversationIDGenerator, meta *vo.CreateStaticConversation) (int64, int64, bool, error) {
|
||||
cs, err := r.mGetOnlineStaticConversation(ctx, meta.UserID, meta.ConnectorID, []int64{meta.TemplateID})
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
if len(cs) > 0 {
|
||||
return cs[0].ConversationID, true, nil
|
||||
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, cs[0].ConversationID)
|
||||
if err != nil {
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
return cs[0].ConversationID, cInfo.SectionID, true, nil
|
||||
}
|
||||
|
||||
conversationID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
conversationID, sectionID, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
return 0, 0, false, err
|
||||
}
|
||||
|
||||
id, err := r.GenID(ctx)
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
object := &model.AppStaticConversationOnline{
|
||||
ID: id,
|
||||
@ -616,10 +633,10 @@ func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context
|
||||
}
|
||||
err = r.query.AppStaticConversationOnline.WithContext(ctx).Create(object)
|
||||
if err != nil {
|
||||
return 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
|
||||
return conversationID, false, nil
|
||||
return conversationID, sectionID, false, nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) BatchCreateOnlineConversationTemplate(ctx context.Context, templates []*entity.ConversationTemplate, version string) error {
|
||||
|
||||
@ -337,16 +337,16 @@ func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo
|
||||
return c.repo.DeleteDynamicConversation(ctx, env, templateID)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error) {
|
||||
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
|
||||
AppID: ptr.Of(appID),
|
||||
Name: ptr.Of(conversationName),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (int64, error) {
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (int64, int64, error) {
|
||||
return conversation.GetConversationManager().CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
AppID: appID,
|
||||
UserID: userID,
|
||||
@ -355,28 +355,28 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
|
||||
})
|
||||
|
||||
if existed {
|
||||
conversationID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
AppID: appID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
TemplateID: t.TemplateID,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
return conversationID, nil
|
||||
return conversationID, sectionID, nil
|
||||
}
|
||||
|
||||
conversationID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
AppID: appID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
Name: conversationName,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
return conversationID, nil
|
||||
return conversationID, sectionID, nil
|
||||
|
||||
}
|
||||
|
||||
@ -391,7 +391,7 @@ func (c *conversationImpl) UpdateConversation(ctx context.Context, env vo.Env, a
|
||||
}
|
||||
|
||||
if existed {
|
||||
newConversationID, err := conversation.GetConversationManager().CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
newConversationID, _, err := conversation.GetConversationManager().CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
AppID: appID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
@ -412,7 +412,7 @@ func (c *conversationImpl) UpdateConversation(ctx context.Context, env vo.Env, a
|
||||
return 0, fmt.Errorf("conversation name %v not found", conversationName)
|
||||
}
|
||||
|
||||
newConversationID, err := conversation.GetConversationManager().CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
newConversationID, _, err := conversation.GetConversationManager().CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
AppID: appID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
|
||||
@ -1002,6 +1002,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config vo.ExecuteConfig,
|
||||
agentID := config.AgentID
|
||||
appID := config.AppID
|
||||
userID := config.Operator
|
||||
sectionID := config.SectionID
|
||||
if sectionID == nil {
|
||||
logs.CtxWarnf(ctx, "SectionID is nil, skipping chat history")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if convID == nil || *convID == 0 {
|
||||
logs.CtxWarnf(ctx, "ConversationID is 0 or nil, skipping chat history")
|
||||
@ -1023,6 +1028,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config vo.ExecuteConfig,
|
||||
AppID: resolvedAppID,
|
||||
UserID: userID,
|
||||
Rounds: historyRounds,
|
||||
SectionID: *sectionID,
|
||||
}
|
||||
|
||||
manager := conversation.GetConversationManager()
|
||||
|
||||
Reference in New Issue
Block a user