From 099553688476fb29a65d75838c33a3be8b95e804 Mon Sep 17 00:00:00 2001 From: junwen-lee Date: Thu, 25 Sep 2025 16:11:24 +0800 Subject: [PATCH] fix(singleagent): v3/chat support customer variables (#2262) --- backend/api/handler/coze/agent_run_service.go | 51 ++++++++++++++++ backend/api/model/conversation/run/run.go | 59 ++++++++++++++++++- .../crossdomain/singleagent/single_agent.go | 2 + .../conversation/openapi_agent_run.go | 15 +++++ .../contract/agent/single_agent.go | 1 + .../impl/singleagent/single_agent.go | 1 + .../internal/agentflow/agent_flow_builder.go | 8 +++ .../internal/dal/single_agent_version.go | 5 ++ .../singleagent/service/single_agent_impl.go | 2 + .../agentrun/entity/run_record.go | 1 + .../agentrun/internal/chatflow_run.go | 10 +++- .../agentrun/internal/singleagent_run.go | 1 + idl/conversation/run.thrift | 4 +- 13 files changed, 155 insertions(+), 5 deletions(-) diff --git a/backend/api/handler/coze/agent_run_service.go b/backend/api/handler/coze/agent_run_service.go index dac4a8e73..aedcc6668 100644 --- a/backend/api/handler/coze/agent_run_service.go +++ b/backend/api/handler/coze/agent_run_service.go @@ -19,6 +19,7 @@ package coze import ( "context" "encoding/json" + "errors" "net/http" "github.com/hertz-contrib/sse" @@ -90,6 +91,13 @@ func checkParams(_ context.Context, ar *run.AgentRunRequest) error { func ChatV3(ctx context.Context, c *app.RequestContext) { var err error var req run.ChatV3Request + + // Pre-process parameters field: convert JSON object to string if needed + if err = preprocessChatV3Parameters(c); err != nil { + invalidParamRequestResponse(c, err.Error()) + return + } + err = c.BindAndValidate(&req) if err != nil { invalidParamRequestResponse(c, err.Error()) @@ -144,3 +152,46 @@ func CancelChatApi(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, resp) } + +// preprocessChatV3Parameters handles the conversion of parameters field from JSON object to string +func preprocessChatV3Parameters(c *app.RequestContext) error { + // Get the raw request body + body := c.Request.Body() + if len(body) == 0 { + return nil + } + + // Parse the JSON body + var requestData map[string]interface{} + if err := json.Unmarshal(body, &requestData); err != nil { + return nil // If it's not valid JSON, let BindAndValidate handle the error + } + + // Check if parameters field exists and is an object + if parametersValue, exists := requestData["parameters"]; exists { + // If parameters is already a string, no conversion needed + if _, isString := parametersValue.(string); isString { + return errors.New("parameters field should be an object, not a string") + } + + // If parameters is an object, convert it to JSON string + if parametersObj, isObject := parametersValue.(map[string]interface{}); isObject { + parametersJSON, err := json.Marshal(parametersObj) + if err != nil { + return err + } + requestData["parameters"] = string(parametersJSON) + + // Update the request body with the modified data + modifiedBody, err := json.Marshal(requestData) + if err != nil { + return err + } + + // Replace the request body + c.Request.SetBody(modifiedBody) + } + } + + return nil +} diff --git a/backend/api/model/conversation/run/run.go b/backend/api/model/conversation/run/run.go index b58e6a9ab..9a52759b1 100644 --- a/backend/api/model/conversation/run/run.go +++ b/backend/api/model/conversation/run/run.go @@ -14,7 +14,7 @@ * limitations under the License. */ -// Code generated by thriftgo (0.4.1). DO NOT EDIT. +// Code generated by thriftgo (0.4.2). DO NOT EDIT. package run @@ -22,7 +22,6 @@ import ( "database/sql" "database/sql/driver" "fmt" - "github.com/apache/thrift/lib/go/thrift" "github.com/coze-dev/coze-studio/backend/api/model/base" "github.com/coze-dev/coze-studio/backend/api/model/conversation/common" @@ -5153,6 +5152,7 @@ type ChatV3Request struct { ConnectorID *int64 `thrift:"ConnectorID,12,optional" form:"connector_id" json:"connector_id,string,omitempty"` // Specify shortcut instructions ShortcutCommand *ShortcutCommandDetail `thrift:"ShortcutCommand,13,optional" form:"shortcut_command" json:"shortcut_command,omitempty"` + Parameters *string `thrift:"Parameters,14,optional" form:"parameters" json:"parameters,omitempty"` } func NewChatV3Request() *ChatV3Request { @@ -5251,6 +5251,15 @@ func (p *ChatV3Request) GetShortcutCommand() (v *ShortcutCommandDetail) { return p.ShortcutCommand } +var ChatV3Request_Parameters_DEFAULT string + +func (p *ChatV3Request) GetParameters() (v string) { + if !p.IsSetParameters() { + return ChatV3Request_Parameters_DEFAULT + } + return *p.Parameters +} + var fieldIDToName_ChatV3Request = map[int16]string{ 1: "BotID", 2: "ConversationID", @@ -5263,6 +5272,7 @@ var fieldIDToName_ChatV3Request = map[int16]string{ 11: "ExtraParams", 12: "ConnectorID", 13: "ShortcutCommand", + 14: "Parameters", } func (p *ChatV3Request) IsSetConversationID() bool { @@ -5301,6 +5311,10 @@ func (p *ChatV3Request) IsSetShortcutCommand() bool { return p.ShortcutCommand != nil } +func (p *ChatV3Request) IsSetParameters() bool { + return p.Parameters != nil +} + func (p *ChatV3Request) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -5411,6 +5425,14 @@ func (p *ChatV3Request) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 14: + if fieldTypeId == thrift.STRING { + if err = p.ReadField14(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -5632,6 +5654,17 @@ func (p *ChatV3Request) ReadField13(iprot thrift.TProtocol) error { p.ShortcutCommand = _field return nil } +func (p *ChatV3Request) ReadField14(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Parameters = _field + return nil +} func (p *ChatV3Request) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -5683,6 +5716,10 @@ func (p *ChatV3Request) Write(oprot thrift.TProtocol) (err error) { fieldId = 13 goto WriteFieldError } + if err = p.writeField14(oprot); err != nil { + fieldId = 14 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -5936,6 +5973,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) } +func (p *ChatV3Request) writeField14(oprot thrift.TProtocol) (err error) { + if p.IsSetParameters() { + if err = oprot.WriteFieldBegin("Parameters", thrift.STRING, 14); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Parameters); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err) +} func (p *ChatV3Request) String() string { if p == nil { diff --git a/backend/api/model/crossdomain/singleagent/single_agent.go b/backend/api/model/crossdomain/singleagent/single_agent.go index ad156776a..e7391368f 100644 --- a/backend/api/model/crossdomain/singleagent/single_agent.go +++ b/backend/api/model/crossdomain/singleagent/single_agent.go @@ -113,6 +113,8 @@ type ExecuteRequest struct { ResumeInfo *InterruptInfo PreCallTools []*agentrun.ToolsRetriever + CustomVariables map[string]string + ConversationID int64 } diff --git a/backend/application/conversation/openapi_agent_run.go b/backend/application/conversation/openapi_agent_run.go index ad73612c4..3e4bc230a 100644 --- a/backend/application/conversation/openapi_agent_run.go +++ b/backend/application/conversation/openapi_agent_run.go @@ -148,6 +148,10 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a return nil, err } displayContent := a.buildDisplayContent(ctx, ar) + chatflowParameters, err := parseChatflowParameters(ctx, ar) + if err != nil { + return nil, err + } arm := &entity.AgentRunMeta{ ConversationID: ptr.From(ar.ConversationID), AgentID: ar.BotID, @@ -169,9 +173,20 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a CustomVariables: ar.CustomVariables, CozeUID: conversationData.CreatorID, AdditionalMessages: filterMultiAdditionalMessages, + ChatflowParameters: chatflowParameters, } return arm, nil } +func parseChatflowParameters(ctx context.Context, ar *run.ChatV3Request) (map[string]any, error) { + parameters := make(map[string]any) + if ar.Parameters != nil { + if err := json.Unmarshal([]byte(*ar.Parameters), ¶meters); err != nil { + return nil, errors.New("parameters field should be an object, not a string") + } + return parameters,nil + } + return parameters,nil +} func (a *OpenapiAgentRunApplication) buildTools(ctx context.Context, shortcmd *run.ShortcutCommandDetail) ([]*entity.Tool, error) { var ts []*entity.Tool diff --git a/backend/crossdomain/contract/agent/single_agent.go b/backend/crossdomain/contract/agent/single_agent.go index d1396bdfe..275591eb4 100644 --- a/backend/crossdomain/contract/agent/single_agent.go +++ b/backend/crossdomain/contract/agent/single_agent.go @@ -42,6 +42,7 @@ type AgentRuntime struct { SpaceID int64 ConnectorID int64 PreRetrieveTools []*agentrun.Tool + CustomVariables map[string]string HistoryMsg []*schema.Message Input *schema.Message diff --git a/backend/crossdomain/impl/singleagent/single_agent.go b/backend/crossdomain/impl/singleagent/single_agent.go index 6b749413c..a0580e82f 100644 --- a/backend/crossdomain/impl/singleagent/single_agent.go +++ b/backend/crossdomain/impl/singleagent/single_agent.go @@ -62,6 +62,7 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, agentRuntim Input: agentRuntime.Input, History: agentRuntime.HistoryMsg, UserID: agentRuntime.UserID, + CustomVariables: agentRuntime.CustomVariables, PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever { return &agentrun.ToolsRetriever{ PluginID: tool.PluginID, diff --git a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go index 889f20dcd..4eb80edae 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go +++ b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go @@ -43,6 +43,9 @@ type Config struct { ModelFactory chatmodel.Factory CPStore compose.CheckPointStore + CustomVariables map[string]string + + ConversationID int64 } @@ -71,6 +74,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) { if err != nil { return nil, err } + if conf.CustomVariables != nil { + for k,v := range conf.CustomVariables { + avs[k] = v + } + } promptVars := &promptVariables{ Agent: conf.Agent, diff --git a/backend/domain/agent/singleagent/internal/dal/single_agent_version.go b/backend/domain/agent/singleagent/internal/dal/single_agent_version.go index 9f1d9fd67..3812e49be 100644 --- a/backend/domain/agent/singleagent/internal/dal/single_agent_version.go +++ b/backend/domain/agent/singleagent/internal/dal/single_agent_version.go @@ -22,6 +22,7 @@ import ( "gorm.io/gorm" + "github.com/coze-dev/coze-studio/backend/api/model/app/bot_common" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/dal/model" @@ -105,6 +106,8 @@ func (sa *SingleAgentVersionDAO) singleAgentVersionPo2Do(po *model.SingleAgentVe Database: po.DatabaseConfig, ShortcutCommand: po.ShortcutCommand, Version: po.Version, + BotMode: bot_common.BotMode(po.BotMode), + LayoutInfo: po.LayoutInfo, }, } } @@ -131,5 +134,7 @@ func (sa *SingleAgentVersionDAO) singleAgentVersionDo2Po(do *entity.SingleAgent) VariablesMetaID: do.VariablesMetaID, DatabaseConfig: do.Database, ShortcutCommand: do.ShortcutCommand, + BotMode: int32(do.BotMode), + LayoutInfo: do.LayoutInfo, } } diff --git a/backend/domain/agent/singleagent/service/single_agent_impl.go b/backend/domain/agent/singleagent/service/single_agent_impl.go index 8e1da737c..f37da908b 100644 --- a/backend/domain/agent/singleagent/service/single_agent_impl.go +++ b/backend/domain/agent/singleagent/service/single_agent_impl.go @@ -112,6 +112,8 @@ func (s *singleAgentImpl) StreamExecute(ctx context.Context, req *entity.Execute ModelFactory: s.ModelFactory, CPStore: s.CPStore, + CustomVariables: req.CustomVariables, + ConversationID: req.ConversationID, } rn, err := agentflow.BuildAgent(ctx, conf) diff --git a/backend/domain/conversation/agentrun/entity/run_record.go b/backend/domain/conversation/agentrun/entity/run_record.go index 077e0531b..870570ddb 100644 --- a/backend/domain/conversation/agentrun/entity/run_record.go +++ b/backend/domain/conversation/agentrun/entity/run_record.go @@ -125,6 +125,7 @@ type AgentRunMeta struct { Version string `json:"version"` Ext map[string]string `json:"ext"` AdditionalMessages []*AdditionalMessage `json:"additional_messages"` + ChatflowParameters map[string]any `json:"chatflow_parameters"` } type AdditionalMessage struct { diff --git a/backend/domain/conversation/agentrun/internal/chatflow_run.go b/backend/domain/conversation/agentrun/internal/chatflow_run.go index fbbf7a882..ba5bd4ffc 100644 --- a/backend/domain/conversation/agentrun/internal/chatflow_run.go +++ b/backend/domain/conversation/agentrun/internal/chatflow_run.go @@ -78,9 +78,15 @@ func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX) executeConfig.RoundID = &art.RunRecord.ID executeConfig.UserMessage = transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0] executeConfig.MaxHistoryRounds = ptr.Of(getAgentHistoryRounds(art.GetAgentInfo())) - wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, map[string]any{ + chatInput := map[string]any{ "USER_INPUT": concatWfInput(art), - }) + } + if art.GetRunMeta().ChatflowParameters != nil { + for k,v := range art.GetRunMeta().ChatflowParameters { + chatInput[k] = v + } + } + wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, chatInput) } if err != nil { return err diff --git a/backend/domain/conversation/agentrun/internal/singleagent_run.go b/backend/domain/conversation/agentrun/internal/singleagent_run.go index dbcac221d..5c931b0f4 100644 --- a/backend/domain/conversation/agentrun/internal/singleagent_run.go +++ b/backend/domain/conversation/agentrun/internal/singleagent_run.go @@ -52,6 +52,7 @@ func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.I ConversationId: art.GetRunMeta().ConversationID, ConnectorID: art.GetRunMeta().ConnectorID, PreRetrieveTools: art.GetRunMeta().PreRetrieveTools, + CustomVariables: art.GetRunMeta().CustomVariables, Input: transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0], HistoryMsg: transMessageToSchemaMessage(ctx, historyPairs(art.GetHistory()), imagex), ResumeInfo: parseResumeInfo(ctx, art.GetHistory()), diff --git a/idl/conversation/run.thrift b/idl/conversation/run.thrift index 6384985b9..ed907f695 100644 --- a/idl/conversation/run.thrift +++ b/idl/conversation/run.thrift @@ -146,7 +146,6 @@ struct ShortcutCommandDetail { 2: map parameters // Key = parameter name value = value object_string JSON String after object array serialization } - struct ChatV3Request { 1: required i64 BotID (api.body = "bot_id",api.js_conv='true'), //agent_id 2: optional i64 ConversationID (api.query = "conversation_id", api.js_conv='true'), //conversation_id @@ -159,6 +158,9 @@ struct ChatV3Request { 11:optional map ExtraParams (api.body = "extra_params") // Pass parameters to plugin/workflow etc downstream 12:optional i64 ConnectorID (api.body="connector_id", api.js_conv='true') // Manually specify channel id chat. Currently only supports websdk (= 999) 13:optional ShortcutCommandDetail ShortcutCommand (api.body="shortcut_command") // Specify shortcut instructions + 14: optional string Parameters (api.body="parameters") + + } struct ChatV3MessageDetail {