fix(singleagent): v3/chat support customer variables (#2262)
This commit is contained in:
@ -19,6 +19,7 @@ package coze
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/hertz-contrib/sse"
|
||||
@ -90,6 +91,13 @@ func checkParams(_ context.Context, ar *run.AgentRunRequest) error {
|
||||
func ChatV3(ctx context.Context, c *app.RequestContext) {
|
||||
var err error
|
||||
var req run.ChatV3Request
|
||||
|
||||
// Pre-process parameters field: convert JSON object to string if needed
|
||||
if err = preprocessChatV3Parameters(c); err != nil {
|
||||
invalidParamRequestResponse(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = c.BindAndValidate(&req)
|
||||
if err != nil {
|
||||
invalidParamRequestResponse(c, err.Error())
|
||||
@ -144,3 +152,46 @@ func CancelChatApi(ctx context.Context, c *app.RequestContext) {
|
||||
|
||||
c.JSON(consts.StatusOK, resp)
|
||||
}
|
||||
|
||||
// preprocessChatV3Parameters handles the conversion of parameters field from JSON object to string
|
||||
func preprocessChatV3Parameters(c *app.RequestContext) error {
|
||||
// Get the raw request body
|
||||
body := c.Request.Body()
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the JSON body
|
||||
var requestData map[string]interface{}
|
||||
if err := json.Unmarshal(body, &requestData); err != nil {
|
||||
return nil // If it's not valid JSON, let BindAndValidate handle the error
|
||||
}
|
||||
|
||||
// Check if parameters field exists and is an object
|
||||
if parametersValue, exists := requestData["parameters"]; exists {
|
||||
// If parameters is already a string, no conversion needed
|
||||
if _, isString := parametersValue.(string); isString {
|
||||
return errors.New("parameters field should be an object, not a string")
|
||||
}
|
||||
|
||||
// If parameters is an object, convert it to JSON string
|
||||
if parametersObj, isObject := parametersValue.(map[string]interface{}); isObject {
|
||||
parametersJSON, err := json.Marshal(parametersObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
requestData["parameters"] = string(parametersJSON)
|
||||
|
||||
// Update the request body with the modified data
|
||||
modifiedBody, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Replace the request body
|
||||
c.Request.SetBody(modifiedBody)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by thriftgo (0.4.1). DO NOT EDIT.
|
||||
// Code generated by thriftgo (0.4.2). DO NOT EDIT.
|
||||
|
||||
package run
|
||||
|
||||
@ -22,7 +22,6 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
|
||||
"github.com/apache/thrift/lib/go/thrift"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/base"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
@ -5153,6 +5152,7 @@ type ChatV3Request struct {
|
||||
ConnectorID *int64 `thrift:"ConnectorID,12,optional" form:"connector_id" json:"connector_id,string,omitempty"`
|
||||
// Specify shortcut instructions
|
||||
ShortcutCommand *ShortcutCommandDetail `thrift:"ShortcutCommand,13,optional" form:"shortcut_command" json:"shortcut_command,omitempty"`
|
||||
Parameters *string `thrift:"Parameters,14,optional" form:"parameters" json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
func NewChatV3Request() *ChatV3Request {
|
||||
@ -5251,6 +5251,15 @@ func (p *ChatV3Request) GetShortcutCommand() (v *ShortcutCommandDetail) {
|
||||
return p.ShortcutCommand
|
||||
}
|
||||
|
||||
var ChatV3Request_Parameters_DEFAULT string
|
||||
|
||||
func (p *ChatV3Request) GetParameters() (v string) {
|
||||
if !p.IsSetParameters() {
|
||||
return ChatV3Request_Parameters_DEFAULT
|
||||
}
|
||||
return *p.Parameters
|
||||
}
|
||||
|
||||
var fieldIDToName_ChatV3Request = map[int16]string{
|
||||
1: "BotID",
|
||||
2: "ConversationID",
|
||||
@ -5263,6 +5272,7 @@ var fieldIDToName_ChatV3Request = map[int16]string{
|
||||
11: "ExtraParams",
|
||||
12: "ConnectorID",
|
||||
13: "ShortcutCommand",
|
||||
14: "Parameters",
|
||||
}
|
||||
|
||||
func (p *ChatV3Request) IsSetConversationID() bool {
|
||||
@ -5301,6 +5311,10 @@ func (p *ChatV3Request) IsSetShortcutCommand() bool {
|
||||
return p.ShortcutCommand != nil
|
||||
}
|
||||
|
||||
func (p *ChatV3Request) IsSetParameters() bool {
|
||||
return p.Parameters != nil
|
||||
}
|
||||
|
||||
func (p *ChatV3Request) Read(iprot thrift.TProtocol) (err error) {
|
||||
var fieldTypeId thrift.TType
|
||||
var fieldId int16
|
||||
@ -5411,6 +5425,14 @@ func (p *ChatV3Request) Read(iprot thrift.TProtocol) (err error) {
|
||||
} else if err = iprot.Skip(fieldTypeId); err != nil {
|
||||
goto SkipFieldError
|
||||
}
|
||||
case 14:
|
||||
if fieldTypeId == thrift.STRING {
|
||||
if err = p.ReadField14(iprot); err != nil {
|
||||
goto ReadFieldError
|
||||
}
|
||||
} else if err = iprot.Skip(fieldTypeId); err != nil {
|
||||
goto SkipFieldError
|
||||
}
|
||||
default:
|
||||
if err = iprot.Skip(fieldTypeId); err != nil {
|
||||
goto SkipFieldError
|
||||
@ -5632,6 +5654,17 @@ func (p *ChatV3Request) ReadField13(iprot thrift.TProtocol) error {
|
||||
p.ShortcutCommand = _field
|
||||
return nil
|
||||
}
|
||||
func (p *ChatV3Request) ReadField14(iprot thrift.TProtocol) error {
|
||||
|
||||
var _field *string
|
||||
if v, err := iprot.ReadString(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
_field = &v
|
||||
}
|
||||
p.Parameters = _field
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ChatV3Request) Write(oprot thrift.TProtocol) (err error) {
|
||||
var fieldId int16
|
||||
@ -5683,6 +5716,10 @@ func (p *ChatV3Request) Write(oprot thrift.TProtocol) (err error) {
|
||||
fieldId = 13
|
||||
goto WriteFieldError
|
||||
}
|
||||
if err = p.writeField14(oprot); err != nil {
|
||||
fieldId = 14
|
||||
goto WriteFieldError
|
||||
}
|
||||
}
|
||||
if err = oprot.WriteFieldStop(); err != nil {
|
||||
goto WriteFieldStopError
|
||||
@ -5936,6 +5973,24 @@ WriteFieldBeginError:
|
||||
WriteFieldEndError:
|
||||
return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err)
|
||||
}
|
||||
func (p *ChatV3Request) writeField14(oprot thrift.TProtocol) (err error) {
|
||||
if p.IsSetParameters() {
|
||||
if err = oprot.WriteFieldBegin("Parameters", thrift.STRING, 14); err != nil {
|
||||
goto WriteFieldBeginError
|
||||
}
|
||||
if err := oprot.WriteString(*p.Parameters); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = oprot.WriteFieldEnd(); err != nil {
|
||||
goto WriteFieldEndError
|
||||
}
|
||||
}
|
||||
return nil
|
||||
WriteFieldBeginError:
|
||||
return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err)
|
||||
WriteFieldEndError:
|
||||
return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err)
|
||||
}
|
||||
|
||||
func (p *ChatV3Request) String() string {
|
||||
if p == nil {
|
||||
|
||||
@ -113,6 +113,8 @@ type ExecuteRequest struct {
|
||||
ResumeInfo *InterruptInfo
|
||||
PreCallTools []*agentrun.ToolsRetriever
|
||||
|
||||
CustomVariables map[string]string
|
||||
|
||||
ConversationID int64
|
||||
}
|
||||
|
||||
|
||||
@ -148,6 +148,10 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a
|
||||
return nil, err
|
||||
}
|
||||
displayContent := a.buildDisplayContent(ctx, ar)
|
||||
chatflowParameters, err := parseChatflowParameters(ctx, ar)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arm := &entity.AgentRunMeta{
|
||||
ConversationID: ptr.From(ar.ConversationID),
|
||||
AgentID: ar.BotID,
|
||||
@ -169,9 +173,20 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a
|
||||
CustomVariables: ar.CustomVariables,
|
||||
CozeUID: conversationData.CreatorID,
|
||||
AdditionalMessages: filterMultiAdditionalMessages,
|
||||
ChatflowParameters: chatflowParameters,
|
||||
}
|
||||
return arm, nil
|
||||
}
|
||||
func parseChatflowParameters(ctx context.Context, ar *run.ChatV3Request) (map[string]any, error) {
|
||||
parameters := make(map[string]any)
|
||||
if ar.Parameters != nil {
|
||||
if err := json.Unmarshal([]byte(*ar.Parameters), ¶meters); err != nil {
|
||||
return nil, errors.New("parameters field should be an object, not a string")
|
||||
}
|
||||
return parameters,nil
|
||||
}
|
||||
return parameters,nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) buildTools(ctx context.Context, shortcmd *run.ShortcutCommandDetail) ([]*entity.Tool, error) {
|
||||
var ts []*entity.Tool
|
||||
|
||||
@ -42,6 +42,7 @@ type AgentRuntime struct {
|
||||
SpaceID int64
|
||||
ConnectorID int64
|
||||
PreRetrieveTools []*agentrun.Tool
|
||||
CustomVariables map[string]string
|
||||
|
||||
HistoryMsg []*schema.Message
|
||||
Input *schema.Message
|
||||
|
||||
@ -62,6 +62,7 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, agentRuntim
|
||||
Input: agentRuntime.Input,
|
||||
History: agentRuntime.HistoryMsg,
|
||||
UserID: agentRuntime.UserID,
|
||||
CustomVariables: agentRuntime.CustomVariables,
|
||||
PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever {
|
||||
return &agentrun.ToolsRetriever{
|
||||
PluginID: tool.PluginID,
|
||||
|
||||
@ -43,6 +43,9 @@ type Config struct {
|
||||
ModelFactory chatmodel.Factory
|
||||
CPStore compose.CheckPointStore
|
||||
|
||||
CustomVariables map[string]string
|
||||
|
||||
|
||||
ConversationID int64
|
||||
}
|
||||
|
||||
@ -71,6 +74,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf.CustomVariables != nil {
|
||||
for k,v := range conf.CustomVariables {
|
||||
avs[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
promptVars := &promptVariables{
|
||||
Agent: conf.Agent,
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/dal/model"
|
||||
@ -105,6 +106,8 @@ func (sa *SingleAgentVersionDAO) singleAgentVersionPo2Do(po *model.SingleAgentVe
|
||||
Database: po.DatabaseConfig,
|
||||
ShortcutCommand: po.ShortcutCommand,
|
||||
Version: po.Version,
|
||||
BotMode: bot_common.BotMode(po.BotMode),
|
||||
LayoutInfo: po.LayoutInfo,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -131,5 +134,7 @@ func (sa *SingleAgentVersionDAO) singleAgentVersionDo2Po(do *entity.SingleAgent)
|
||||
VariablesMetaID: do.VariablesMetaID,
|
||||
DatabaseConfig: do.Database,
|
||||
ShortcutCommand: do.ShortcutCommand,
|
||||
BotMode: int32(do.BotMode),
|
||||
LayoutInfo: do.LayoutInfo,
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,6 +112,8 @@ func (s *singleAgentImpl) StreamExecute(ctx context.Context, req *entity.Execute
|
||||
ModelFactory: s.ModelFactory,
|
||||
CPStore: s.CPStore,
|
||||
|
||||
CustomVariables: req.CustomVariables,
|
||||
|
||||
ConversationID: req.ConversationID,
|
||||
}
|
||||
rn, err := agentflow.BuildAgent(ctx, conf)
|
||||
|
||||
@ -125,6 +125,7 @@ type AgentRunMeta struct {
|
||||
Version string `json:"version"`
|
||||
Ext map[string]string `json:"ext"`
|
||||
AdditionalMessages []*AdditionalMessage `json:"additional_messages"`
|
||||
ChatflowParameters map[string]any `json:"chatflow_parameters"`
|
||||
}
|
||||
|
||||
type AdditionalMessage struct {
|
||||
|
||||
@ -78,9 +78,15 @@ func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX)
|
||||
executeConfig.RoundID = &art.RunRecord.ID
|
||||
executeConfig.UserMessage = transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0]
|
||||
executeConfig.MaxHistoryRounds = ptr.Of(getAgentHistoryRounds(art.GetAgentInfo()))
|
||||
wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, map[string]any{
|
||||
chatInput := map[string]any{
|
||||
"USER_INPUT": concatWfInput(art),
|
||||
})
|
||||
}
|
||||
if art.GetRunMeta().ChatflowParameters != nil {
|
||||
for k,v := range art.GetRunMeta().ChatflowParameters {
|
||||
chatInput[k] = v
|
||||
}
|
||||
}
|
||||
wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, chatInput)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@ -52,6 +52,7 @@ func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.I
|
||||
ConversationId: art.GetRunMeta().ConversationID,
|
||||
ConnectorID: art.GetRunMeta().ConnectorID,
|
||||
PreRetrieveTools: art.GetRunMeta().PreRetrieveTools,
|
||||
CustomVariables: art.GetRunMeta().CustomVariables,
|
||||
Input: transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0],
|
||||
HistoryMsg: transMessageToSchemaMessage(ctx, historyPairs(art.GetHistory()), imagex),
|
||||
ResumeInfo: parseResumeInfo(ctx, art.GetHistory()),
|
||||
|
||||
@ -146,7 +146,6 @@ struct ShortcutCommandDetail {
|
||||
2: map<string,string> parameters // Key = parameter name value = value object_string JSON String after object array serialization
|
||||
}
|
||||
|
||||
|
||||
struct ChatV3Request {
|
||||
1: required i64 BotID (api.body = "bot_id",api.js_conv='true'), //agent_id
|
||||
2: optional i64 ConversationID (api.query = "conversation_id", api.js_conv='true'), //conversation_id
|
||||
@ -159,6 +158,9 @@ struct ChatV3Request {
|
||||
11:optional map<string, string> ExtraParams (api.body = "extra_params") // Pass parameters to plugin/workflow etc downstream
|
||||
12:optional i64 ConnectorID (api.body="connector_id", api.js_conv='true') // Manually specify channel id chat. Currently only supports websdk (= 999)
|
||||
13:optional ShortcutCommandDetail ShortcutCommand (api.body="shortcut_command") // Specify shortcut instructions
|
||||
14: optional string Parameters (api.body="parameters")
|
||||
|
||||
|
||||
}
|
||||
|
||||
struct ChatV3MessageDetail {
|
||||
|
||||
Reference in New Issue
Block a user