fix(singleagent): v3/chat support customer variables (#2262)

This commit is contained in:
junwen-lee
2025-09-25 16:11:24 +08:00
committed by GitHub
parent d3b1e8cfd3
commit 0995536884
13 changed files with 155 additions and 5 deletions

View File

@ -19,6 +19,7 @@ package coze
import (
"context"
"encoding/json"
"errors"
"net/http"
"github.com/hertz-contrib/sse"
@ -90,6 +91,13 @@ func checkParams(_ context.Context, ar *run.AgentRunRequest) error {
func ChatV3(ctx context.Context, c *app.RequestContext) {
var err error
var req run.ChatV3Request
// Pre-process parameters field: convert JSON object to string if needed
if err = preprocessChatV3Parameters(c); err != nil {
invalidParamRequestResponse(c, err.Error())
return
}
err = c.BindAndValidate(&req)
if err != nil {
invalidParamRequestResponse(c, err.Error())
@ -144,3 +152,46 @@ func CancelChatApi(ctx context.Context, c *app.RequestContext) {
c.JSON(consts.StatusOK, resp)
}
// preprocessChatV3Parameters handles the conversion of parameters field from JSON object to string
func preprocessChatV3Parameters(c *app.RequestContext) error {
// Get the raw request body
body := c.Request.Body()
if len(body) == 0 {
return nil
}
// Parse the JSON body
var requestData map[string]interface{}
if err := json.Unmarshal(body, &requestData); err != nil {
return nil // If it's not valid JSON, let BindAndValidate handle the error
}
// Check if parameters field exists and is an object
if parametersValue, exists := requestData["parameters"]; exists {
// If parameters is already a string, no conversion needed
if _, isString := parametersValue.(string); isString {
return errors.New("parameters field should be an object, not a string")
}
// If parameters is an object, convert it to JSON string
if parametersObj, isObject := parametersValue.(map[string]interface{}); isObject {
parametersJSON, err := json.Marshal(parametersObj)
if err != nil {
return err
}
requestData["parameters"] = string(parametersJSON)
// Update the request body with the modified data
modifiedBody, err := json.Marshal(requestData)
if err != nil {
return err
}
// Replace the request body
c.Request.SetBody(modifiedBody)
}
}
return nil
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
// Code generated by thriftgo (0.4.1). DO NOT EDIT.
// Code generated by thriftgo (0.4.2). DO NOT EDIT.
package run
@ -22,7 +22,6 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/apache/thrift/lib/go/thrift"
"github.com/coze-dev/coze-studio/backend/api/model/base"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
@ -5153,6 +5152,7 @@ type ChatV3Request struct {
ConnectorID *int64 `thrift:"ConnectorID,12,optional" form:"connector_id" json:"connector_id,string,omitempty"`
// Specify shortcut instructions
ShortcutCommand *ShortcutCommandDetail `thrift:"ShortcutCommand,13,optional" form:"shortcut_command" json:"shortcut_command,omitempty"`
Parameters *string `thrift:"Parameters,14,optional" form:"parameters" json:"parameters,omitempty"`
}
func NewChatV3Request() *ChatV3Request {
@ -5251,6 +5251,15 @@ func (p *ChatV3Request) GetShortcutCommand() (v *ShortcutCommandDetail) {
return p.ShortcutCommand
}
var ChatV3Request_Parameters_DEFAULT string
func (p *ChatV3Request) GetParameters() (v string) {
if !p.IsSetParameters() {
return ChatV3Request_Parameters_DEFAULT
}
return *p.Parameters
}
var fieldIDToName_ChatV3Request = map[int16]string{
1: "BotID",
2: "ConversationID",
@ -5263,6 +5272,7 @@ var fieldIDToName_ChatV3Request = map[int16]string{
11: "ExtraParams",
12: "ConnectorID",
13: "ShortcutCommand",
14: "Parameters",
}
func (p *ChatV3Request) IsSetConversationID() bool {
@ -5301,6 +5311,10 @@ func (p *ChatV3Request) IsSetShortcutCommand() bool {
return p.ShortcutCommand != nil
}
func (p *ChatV3Request) IsSetParameters() bool {
return p.Parameters != nil
}
func (p *ChatV3Request) Read(iprot thrift.TProtocol) (err error) {
var fieldTypeId thrift.TType
var fieldId int16
@ -5411,6 +5425,14 @@ func (p *ChatV3Request) Read(iprot thrift.TProtocol) (err error) {
} else if err = iprot.Skip(fieldTypeId); err != nil {
goto SkipFieldError
}
case 14:
if fieldTypeId == thrift.STRING {
if err = p.ReadField14(iprot); err != nil {
goto ReadFieldError
}
} else if err = iprot.Skip(fieldTypeId); err != nil {
goto SkipFieldError
}
default:
if err = iprot.Skip(fieldTypeId); err != nil {
goto SkipFieldError
@ -5632,6 +5654,17 @@ func (p *ChatV3Request) ReadField13(iprot thrift.TProtocol) error {
p.ShortcutCommand = _field
return nil
}
func (p *ChatV3Request) ReadField14(iprot thrift.TProtocol) error {
var _field *string
if v, err := iprot.ReadString(); err != nil {
return err
} else {
_field = &v
}
p.Parameters = _field
return nil
}
func (p *ChatV3Request) Write(oprot thrift.TProtocol) (err error) {
var fieldId int16
@ -5683,6 +5716,10 @@ func (p *ChatV3Request) Write(oprot thrift.TProtocol) (err error) {
fieldId = 13
goto WriteFieldError
}
if err = p.writeField14(oprot); err != nil {
fieldId = 14
goto WriteFieldError
}
}
if err = oprot.WriteFieldStop(); err != nil {
goto WriteFieldStopError
@ -5936,6 +5973,24 @@ WriteFieldBeginError:
WriteFieldEndError:
return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err)
}
func (p *ChatV3Request) writeField14(oprot thrift.TProtocol) (err error) {
if p.IsSetParameters() {
if err = oprot.WriteFieldBegin("Parameters", thrift.STRING, 14); err != nil {
goto WriteFieldBeginError
}
if err := oprot.WriteString(*p.Parameters); err != nil {
return err
}
if err = oprot.WriteFieldEnd(); err != nil {
goto WriteFieldEndError
}
}
return nil
WriteFieldBeginError:
return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err)
WriteFieldEndError:
return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err)
}
func (p *ChatV3Request) String() string {
if p == nil {

View File

@ -113,6 +113,8 @@ type ExecuteRequest struct {
ResumeInfo *InterruptInfo
PreCallTools []*agentrun.ToolsRetriever
CustomVariables map[string]string
ConversationID int64
}

View File

@ -148,6 +148,10 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a
return nil, err
}
displayContent := a.buildDisplayContent(ctx, ar)
chatflowParameters, err := parseChatflowParameters(ctx, ar)
if err != nil {
return nil, err
}
arm := &entity.AgentRunMeta{
ConversationID: ptr.From(ar.ConversationID),
AgentID: ar.BotID,
@ -169,9 +173,20 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a
CustomVariables: ar.CustomVariables,
CozeUID: conversationData.CreatorID,
AdditionalMessages: filterMultiAdditionalMessages,
ChatflowParameters: chatflowParameters,
}
return arm, nil
}
func parseChatflowParameters(ctx context.Context, ar *run.ChatV3Request) (map[string]any, error) {
parameters := make(map[string]any)
if ar.Parameters != nil {
if err := json.Unmarshal([]byte(*ar.Parameters), &parameters); err != nil {
return nil, errors.New("parameters field should be an object, not a string")
}
return parameters,nil
}
return parameters,nil
}
func (a *OpenapiAgentRunApplication) buildTools(ctx context.Context, shortcmd *run.ShortcutCommandDetail) ([]*entity.Tool, error) {
var ts []*entity.Tool

View File

@ -42,6 +42,7 @@ type AgentRuntime struct {
SpaceID int64
ConnectorID int64
PreRetrieveTools []*agentrun.Tool
CustomVariables map[string]string
HistoryMsg []*schema.Message
Input *schema.Message

View File

@ -62,6 +62,7 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, agentRuntim
Input: agentRuntime.Input,
History: agentRuntime.HistoryMsg,
UserID: agentRuntime.UserID,
CustomVariables: agentRuntime.CustomVariables,
PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever {
return &agentrun.ToolsRetriever{
PluginID: tool.PluginID,

View File

@ -43,6 +43,9 @@ type Config struct {
ModelFactory chatmodel.Factory
CPStore compose.CheckPointStore
CustomVariables map[string]string
ConversationID int64
}
@ -71,6 +74,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
if err != nil {
return nil, err
}
if conf.CustomVariables != nil {
for k,v := range conf.CustomVariables {
avs[k] = v
}
}
promptVars := &promptVariables{
Agent: conf.Agent,

View File

@ -22,6 +22,7 @@ import (
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/dal/model"
@ -105,6 +106,8 @@ func (sa *SingleAgentVersionDAO) singleAgentVersionPo2Do(po *model.SingleAgentVe
Database: po.DatabaseConfig,
ShortcutCommand: po.ShortcutCommand,
Version: po.Version,
BotMode: bot_common.BotMode(po.BotMode),
LayoutInfo: po.LayoutInfo,
},
}
}
@ -131,5 +134,7 @@ func (sa *SingleAgentVersionDAO) singleAgentVersionDo2Po(do *entity.SingleAgent)
VariablesMetaID: do.VariablesMetaID,
DatabaseConfig: do.Database,
ShortcutCommand: do.ShortcutCommand,
BotMode: int32(do.BotMode),
LayoutInfo: do.LayoutInfo,
}
}

View File

@ -112,6 +112,8 @@ func (s *singleAgentImpl) StreamExecute(ctx context.Context, req *entity.Execute
ModelFactory: s.ModelFactory,
CPStore: s.CPStore,
CustomVariables: req.CustomVariables,
ConversationID: req.ConversationID,
}
rn, err := agentflow.BuildAgent(ctx, conf)

View File

@ -125,6 +125,7 @@ type AgentRunMeta struct {
Version string `json:"version"`
Ext map[string]string `json:"ext"`
AdditionalMessages []*AdditionalMessage `json:"additional_messages"`
ChatflowParameters map[string]any `json:"chatflow_parameters"`
}
type AdditionalMessage struct {

View File

@ -78,9 +78,15 @@ func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX)
executeConfig.RoundID = &art.RunRecord.ID
executeConfig.UserMessage = transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0]
executeConfig.MaxHistoryRounds = ptr.Of(getAgentHistoryRounds(art.GetAgentInfo()))
wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, map[string]any{
chatInput := map[string]any{
"USER_INPUT": concatWfInput(art),
})
}
if art.GetRunMeta().ChatflowParameters != nil {
for k,v := range art.GetRunMeta().ChatflowParameters {
chatInput[k] = v
}
}
wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, chatInput)
}
if err != nil {
return err

View File

@ -52,6 +52,7 @@ func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.I
ConversationId: art.GetRunMeta().ConversationID,
ConnectorID: art.GetRunMeta().ConnectorID,
PreRetrieveTools: art.GetRunMeta().PreRetrieveTools,
CustomVariables: art.GetRunMeta().CustomVariables,
Input: transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0],
HistoryMsg: transMessageToSchemaMessage(ctx, historyPairs(art.GetHistory()), imagex),
ResumeInfo: parseResumeInfo(ctx, art.GetHistory()),