feat(plugin): abstract tool invocation to support http, custom plugins (#2227)
This commit is contained in:
@ -19,7 +19,6 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -31,6 +30,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func AccessLogMW() app.HandlerFunc {
|
||||
|
||||
@ -21,7 +21,9 @@ import "github.com/getkin/kin-openapi/openapi3"
|
||||
type PluginType string
|
||||
|
||||
const (
|
||||
PluginTypeOfCloud PluginType = "openapi"
|
||||
PluginTypeOfCloud PluginType = "openapi"
|
||||
PluginTypeOfMCP PluginType = "coze-studio-mcp"
|
||||
PluginTypeOfCustom PluginType = "coze-studio-custom"
|
||||
)
|
||||
|
||||
type AuthzType string
|
||||
|
||||
@ -89,6 +89,8 @@ var apiAssistTypes = map[common.AssistParameterType]APIFileAssistType{
|
||||
common.AssistParameterType_TXT: AssistTypeTXT,
|
||||
}
|
||||
|
||||
// TODO(fanlv): move to other package
|
||||
|
||||
func ToAPIAssistType(typ common.AssistParameterType) (APIFileAssistType, bool) {
|
||||
_typ, ok := apiAssistTypes[typ]
|
||||
return _typ, ok
|
||||
|
||||
@ -446,3 +446,28 @@ func disabledParam(schemaVal *openapi3.Schema) bool {
|
||||
|
||||
return globalDisable || localDisable
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) GetReqBodySchema() (string, *openapi3.SchemaRef) {
|
||||
if op.RequestBody == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var contentTypeArray = []string{
|
||||
MediaTypeJson,
|
||||
MediaTypeProblemJson,
|
||||
MediaTypeFormURLEncoded,
|
||||
MediaTypeXYaml,
|
||||
MediaTypeYaml,
|
||||
}
|
||||
|
||||
for _, ct := range contentTypeArray {
|
||||
mType := op.RequestBody.Value.Content[ct]
|
||||
if mType == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return ct, mType.Schema
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@ -25,7 +25,6 @@ type ExecuteToolOption struct {
|
||||
Operation *Openapi3Operation
|
||||
InvalidRespProcessStrategy InvalidResponseProcessStrategy
|
||||
|
||||
AgentID int64
|
||||
ConversationID int64
|
||||
}
|
||||
|
||||
@ -69,9 +68,8 @@ func WithAutoGenRespSchema() ExecuteToolOpt {
|
||||
}
|
||||
}
|
||||
|
||||
func WithPluginHTTPHeader(agentID, conversationID int64) ExecuteToolOpt {
|
||||
func WithPluginHTTPHeader(conversationID int64) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.AgentID = agentID
|
||||
o.ConversationID = conversationID
|
||||
}
|
||||
}
|
||||
|
||||
@ -111,7 +111,7 @@ func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *
|
||||
}
|
||||
|
||||
if conversationData.CreatorID != userID {
|
||||
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg","user not match"))
|
||||
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "user not match"))
|
||||
}
|
||||
|
||||
return conversationData, nil
|
||||
|
||||
@ -570,8 +570,6 @@ func TestOpenapiAgentRun_ParseAdditionalMessages_InvalidRole(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "additional message role only support user and assistant")
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestOpenapiAgentRun_ParseAdditionalMessages_InvalidType(t *testing.T) {
|
||||
app, _, _, _, mockConversation, mockSingleAgent, _ := setupMocks(t)
|
||||
ctx := createTestContext()
|
||||
|
||||
@ -74,7 +74,6 @@ func newPluginTools(ctx context.Context, conf *toolConfig) ([]tool.InvokableTool
|
||||
projectInfo: projectInfo,
|
||||
toolInfo: ti,
|
||||
|
||||
agentID: conf.agentIdentity.AgentID,
|
||||
conversationID: conf.conversationID,
|
||||
})
|
||||
}
|
||||
@ -88,7 +87,6 @@ type pluginInvokableTool struct {
|
||||
toolInfo *pluginEntity.ToolInfo
|
||||
projectInfo *plugin.ProjectInfo
|
||||
|
||||
agentID int64
|
||||
conversationID int64
|
||||
}
|
||||
|
||||
@ -132,7 +130,7 @@ func (p *pluginInvokableTool) InvokableRun(ctx context.Context, argumentsInJSON
|
||||
plugin.WithInvalidRespProcessStrategy(plugin.InvalidResponseProcessStrategyOfReturnDefault),
|
||||
plugin.WithToolVersion(p.toolInfo.GetVersion()),
|
||||
plugin.WithProjectInfo(p.projectInfo),
|
||||
plugin.WithPluginHTTPHeader(p.agentID, p.conversationID),
|
||||
plugin.WithPluginHTTPHeader(p.conversationID),
|
||||
}
|
||||
|
||||
resp, err := crossplugin.DefaultSVC().ExecuteTool(ctx, req, opts...)
|
||||
|
||||
@ -503,9 +503,8 @@ func TestBatchCreate(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
t.Run("success_single_message", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
@ -140,6 +140,7 @@ type ProjectInfo = model.ProjectInfo
|
||||
|
||||
type PluginManifest = model.PluginManifest
|
||||
|
||||
// TODO API.DESC 来给不同 default 值
|
||||
func NewDefaultPluginManifest() *PluginManifest {
|
||||
return &model.PluginManifest{
|
||||
SchemaVersion: "v1",
|
||||
|
||||
@ -21,47 +21,42 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"io"
|
||||
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/encoder"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service/tool"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (p *pluginServiceImpl) ExecuteTool(ctx context.Context, req *ExecuteToolRequest, opts ...entity.ExecuteToolOpt) (resp *ExecuteToolResponse, err error) {
|
||||
execOpt := &model.ExecuteToolOption{}
|
||||
for _, opt := range opts {
|
||||
opt(execOpt)
|
||||
opt := &model.ExecuteToolOption{}
|
||||
for _, fn := range opts {
|
||||
fn(opt)
|
||||
}
|
||||
|
||||
executor, err := p.buildToolExecutor(ctx, req, execOpt)
|
||||
executor, err := p.buildToolExecutor(ctx, req, opt)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "buildToolExecutor failed")
|
||||
}
|
||||
|
||||
result, err := executor.execute(ctx, req.ArgumentsInJson)
|
||||
authInfo := executor.plugin.GetAuthInfo()
|
||||
accessToken, authURL, err := p.acquireAccessTokenIfNeed(ctx, req, authInfo, executor.tool.Operation)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "acquireAccessToken failed")
|
||||
}
|
||||
|
||||
result, err := executor.execute(ctx, req.ArgumentsInJson, accessToken, authURL)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "execute tool failed")
|
||||
}
|
||||
@ -77,7 +72,7 @@ func (p *pluginServiceImpl) ExecuteTool(ctx context.Context, req *ExecuteToolReq
|
||||
}
|
||||
|
||||
var respSchema openapi3.Responses
|
||||
if execOpt.AutoGenRespSchema {
|
||||
if opt.AutoGenRespSchema {
|
||||
respSchema, err = p.genToolResponseSchema(ctx, result.RawResp)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "genToolResponseSchema failed")
|
||||
@ -95,9 +90,49 @@ func (p *pluginServiceImpl) ExecuteTool(ctx context.Context, req *ExecuteToolReq
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) buildToolExecutor(ctx context.Context, req *ExecuteToolRequest,
|
||||
execOpt *model.ExecuteToolOption) (impl *toolExecutor, err error) {
|
||||
func (p *pluginServiceImpl) acquireAccessTokenIfNeed(ctx context.Context, req *ExecuteToolRequest, authInfo *model.AuthV2,
|
||||
schema *model.Openapi3Operation) (accessToken string, authURL string, err error) {
|
||||
if authInfo.Type == model.AuthzTypeOfNone {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
authMode := model.ToolAuthModeOfRequired
|
||||
if tmp, ok := schema.Extensions[model.APISchemaExtendAuthMode].(string); ok {
|
||||
authMode = model.ToolAuthMode(tmp)
|
||||
}
|
||||
|
||||
if authMode == model.ToolAuthModeOfDisabled {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
if authInfo.SubType == model.AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
authorizationCode := &entity.AuthorizationCodeInfo{
|
||||
Meta: &entity.AuthorizationCodeMeta{
|
||||
UserID: req.UserID,
|
||||
PluginID: req.PluginID,
|
||||
IsDraft: req.ExecScene == model.ExecSceneOfToolDebug,
|
||||
},
|
||||
Config: authInfo.AuthOfOAuthAuthorizationCode,
|
||||
}
|
||||
|
||||
accessToken, err = p.GetAccessToken(ctx, &entity.OAuthInfo{
|
||||
OAuthMode: authInfo.SubType,
|
||||
AuthorizationCode: authorizationCode,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
authURL, err = genAuthURL(authorizationCode)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, authURL, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) buildToolExecutor(ctx context.Context, req *ExecuteToolRequest, opt *model.ExecuteToolOption) (impl *toolExecutor, err error) {
|
||||
if req.UserID == "" {
|
||||
return nil, errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KV(errno.PluginMsgKey, "userID is required"))
|
||||
}
|
||||
@ -108,13 +143,13 @@ func (p *pluginServiceImpl) buildToolExecutor(ctx context.Context, req *ExecuteT
|
||||
)
|
||||
switch req.ExecScene {
|
||||
case model.ExecSceneOfOnlineAgent:
|
||||
pl, tl, err = p.getOnlineAgentPluginInfo(ctx, req, execOpt)
|
||||
pl, tl, err = p.getOnlineAgentPluginInfo(ctx, req, opt)
|
||||
case model.ExecSceneOfDraftAgent:
|
||||
pl, tl, err = p.getDraftAgentPluginInfo(ctx, req, execOpt)
|
||||
pl, tl, err = p.getDraftAgentPluginInfo(ctx, req, opt)
|
||||
case model.ExecSceneOfToolDebug:
|
||||
pl, tl, err = p.getToolDebugPluginInfo(ctx, req, execOpt)
|
||||
pl, tl, err = p.getToolDebugPluginInfo(ctx, req, opt)
|
||||
case model.ExecSceneOfWorkflow:
|
||||
pl, tl, err = p.getWorkflowPluginInfo(ctx, req, execOpt)
|
||||
pl, tl, err = p.getWorkflowPluginInfo(ctx, req, opt)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid execute scene '%s'", req.ExecScene)
|
||||
}
|
||||
@ -125,17 +160,16 @@ func (p *pluginServiceImpl) buildToolExecutor(ctx context.Context, req *ExecuteT
|
||||
impl = &toolExecutor{
|
||||
execScene: req.ExecScene,
|
||||
userID: req.UserID,
|
||||
agentID: execOpt.AgentID,
|
||||
conversationID: execOpt.ConversationID,
|
||||
conversationID: opt.ConversationID,
|
||||
plugin: pl,
|
||||
tool: tl,
|
||||
projectInfo: execOpt.ProjectInfo,
|
||||
invalidRespProcessStrategy: execOpt.InvalidRespProcessStrategy,
|
||||
svc: p,
|
||||
projectInfo: opt.ProjectInfo,
|
||||
invalidRespProcessStrategy: opt.InvalidRespProcessStrategy,
|
||||
oss: p.oss,
|
||||
}
|
||||
|
||||
if execOpt.Operation != nil {
|
||||
impl.tool.Operation = execOpt.Operation
|
||||
if opt.Operation != nil {
|
||||
impl.tool.Operation = opt.Operation
|
||||
}
|
||||
|
||||
return impl, nil
|
||||
@ -463,7 +497,6 @@ type ExecuteResponse struct {
|
||||
type toolExecutor struct {
|
||||
execScene model.ExecuteScene
|
||||
userID string
|
||||
agentID int64
|
||||
conversationID int64
|
||||
|
||||
plugin *entity.PluginInfo
|
||||
@ -472,84 +505,66 @@ type toolExecutor struct {
|
||||
projectInfo *entity.ProjectInfo
|
||||
invalidRespProcessStrategy model.InvalidResponseProcessStrategy
|
||||
|
||||
svc *pluginServiceImpl
|
||||
oss storage.Storage
|
||||
}
|
||||
|
||||
func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (resp *ExecuteResponse, err error) {
|
||||
func newToolInvocation(t *toolExecutor) tool.Invocation {
|
||||
switch t.plugin.Manifest.API.Type {
|
||||
case model.PluginTypeOfCloud:
|
||||
return tool.NewHttpCallImpl(t.conversationID)
|
||||
case model.PluginTypeOfMCP:
|
||||
return tool.NewMcpCallImpl()
|
||||
case model.PluginTypeOfCustom:
|
||||
return tool.NewCustomCallImpl()
|
||||
default: // default to http call
|
||||
return tool.NewHttpCallImpl(t.conversationID)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *toolExecutor) execute(ctx context.Context, argumentsInJson, accessToken, authURL string) (resp *ExecuteResponse, err error) {
|
||||
if argumentsInJson == "" {
|
||||
return nil, errorx.New(errno.ErrPluginExecuteToolFailed,
|
||||
errorx.KV(errno.PluginMsgKey, "argumentsInJson is required"))
|
||||
}
|
||||
|
||||
invocation, err := tool.NewInvocationArgs(ctx, &tool.InvocationArgsBuilder{
|
||||
ArgsInJson: argumentsInJson,
|
||||
ProjectInfo: t.projectInfo,
|
||||
UserID: t.userID,
|
||||
AccessToken: accessToken,
|
||||
AuthURL: authURL,
|
||||
Plugin: t.plugin,
|
||||
Tool: t.tool,
|
||||
PluginManifest: t.plugin.Manifest,
|
||||
ServerURL: t.plugin.GetServerURL(),
|
||||
AuthInfo: &tool.AuthInfo{
|
||||
OAuth: &tool.OAuthInfo{
|
||||
AccessToken: accessToken,
|
||||
AuthURL: authURL,
|
||||
},
|
||||
MetaInfo: t.plugin.GetAuthInfo(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.execScene != model.ExecSceneOfToolDebug { // debug
|
||||
// only assemble file uri to url in debug scene
|
||||
err = invocation.AssembleFileURIToURL(ctx, t.oss)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
toolInvocation := newToolInvocation(t)
|
||||
requestStr, rawResp, err := toolInvocation.Do(ctx, invocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
const defaultResp = "{}"
|
||||
|
||||
if argumentsInJson == "" {
|
||||
return nil, errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KV(errno.PluginMsgKey, "argumentsInJson is required"))
|
||||
}
|
||||
|
||||
args, err := t.preprocessArgumentsInJson(ctx, argumentsInJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := t.buildHTTPRequest(ctx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errMsg, err := t.injectAuthInfo(ctx, httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if errMsg != "" {
|
||||
event := &model.ToolInterruptEvent{
|
||||
Event: model.InterruptEventTypeOfToolNeedOAuth,
|
||||
ToolNeedOAuth: &model.ToolNeedOAuthInterruptEvent{
|
||||
Message: errMsg,
|
||||
},
|
||||
}
|
||||
return nil, einoCompose.NewInterruptAndRerunErr(event)
|
||||
}
|
||||
|
||||
var reqBodyBytes []byte
|
||||
if httpReq.GetBody != nil {
|
||||
reqBody, err := httpReq.GetBody()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reqBody.Close()
|
||||
|
||||
reqBodyBytes, err = io.ReadAll(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
requestStr, err := genRequestString(httpReq, reqBodyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restyReq := t.svc.httpCli.NewRequest()
|
||||
restyReq.Header = httpReq.Header
|
||||
restyReq.Method = httpReq.Method
|
||||
restyReq.URL = httpReq.URL.String()
|
||||
if reqBodyBytes != nil {
|
||||
restyReq.SetBody(reqBodyBytes)
|
||||
}
|
||||
restyReq.SetContext(ctx)
|
||||
|
||||
logs.CtxDebugf(ctx, "[execute] url=%s, header=%s, method=%s, body=%s",
|
||||
restyReq.URL, restyReq.Header, restyReq.Method, restyReq.Body)
|
||||
|
||||
httpResp, err := restyReq.Send()
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KVf(errno.PluginMsgKey, "http request failed, err=%s", err))
|
||||
}
|
||||
|
||||
logs.CtxDebugf(ctx, "[execute] status=%s, response=%s", httpResp.Status(), httpResp.String())
|
||||
|
||||
if httpResp.StatusCode() != http.StatusOK {
|
||||
return nil, errorx.New(errno.ErrPluginExecuteToolFailed,
|
||||
errorx.KVf(errno.PluginMsgKey, "http request failed, status=%s\nresp=%s", httpResp.Status(), httpResp.String()))
|
||||
}
|
||||
|
||||
rawResp := string(httpResp.Body())
|
||||
if rawResp == "" {
|
||||
return &ExecuteResponse{
|
||||
Request: requestStr,
|
||||
@ -573,454 +588,6 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res
|
||||
}, nil
|
||||
}
|
||||
|
||||
func genRequestString(req *http.Request, body []byte) (string, error) {
|
||||
type Request struct {
|
||||
Path string `json:"path"`
|
||||
Header map[string]string `json:"header"`
|
||||
Query map[string]string `json:"query"`
|
||||
Body *[]byte `json:"body"`
|
||||
}
|
||||
|
||||
req_ := &Request{
|
||||
Path: req.URL.Path,
|
||||
Header: map[string]string{},
|
||||
Query: map[string]string{},
|
||||
}
|
||||
|
||||
if len(req.Header) > 0 {
|
||||
for k, v := range req.Header {
|
||||
req_.Header[k] = v[0]
|
||||
}
|
||||
}
|
||||
if len(req.URL.Query()) > 0 {
|
||||
for k, v := range req.URL.Query() {
|
||||
req_.Query[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
requestStr, err := sonic.MarshalString(req_)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err)
|
||||
}
|
||||
|
||||
if len(body) > 0 {
|
||||
requestStr, err = sjson.SetRaw(requestStr, "body", string(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return requestStr, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) preprocessArgumentsInJson(ctx context.Context, argumentsInJson string) (args map[string]any, err error) {
|
||||
args, err = t.prepareArguments(ctx, argumentsInJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramRefs := t.tool.Operation.Parameters
|
||||
for _, paramRef := range paramRefs {
|
||||
paramVal := paramRef.Value
|
||||
if paramVal.In == openapi3.ParameterInCookie {
|
||||
continue
|
||||
}
|
||||
|
||||
scVal := paramVal.Schema.Value
|
||||
typ := scVal.Type
|
||||
|
||||
if typ == openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("the type of parameter '%s' in '%s' cannot be 'object'", paramVal.In, paramVal.Name)
|
||||
}
|
||||
|
||||
argValue, ok := args[paramVal.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if arr, ok := argValue.([]any); ok {
|
||||
for i, e := range arr {
|
||||
e, err = t.convertURItoURL(ctx, e, scVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr[i] = e
|
||||
}
|
||||
} else {
|
||||
argValue, err = t.convertURItoURL(ctx, argValue, scVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
args[paramVal.Name] = argValue
|
||||
}
|
||||
|
||||
_, bodySchema := t.getReqBodySchema(t.tool.Operation)
|
||||
if bodySchema == nil || bodySchema.Value == nil {
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// Body restricted to object type
|
||||
if bodySchema.Value.Type != openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("[preprocessArgumentsInJson] requset body is not object, type=%s",
|
||||
bodySchema.Value.Type)
|
||||
}
|
||||
|
||||
if len(bodySchema.Value.Properties) == 0 {
|
||||
return args, nil
|
||||
}
|
||||
|
||||
for paramName, prop := range bodySchema.Value.Properties {
|
||||
argValue, ok := args[paramName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if arr, ok := argValue.([]any); ok {
|
||||
for i, e := range arr {
|
||||
e, err = t.convertURItoURL(ctx, e, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr[i] = e
|
||||
}
|
||||
} else {
|
||||
argValue, err = t.convertURItoURL(ctx, argValue, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
args[paramName] = argValue
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]any) (httpReq *http.Request, err error) {
|
||||
tool := t.tool
|
||||
rawURL := t.plugin.GetServerURL() + tool.GetSubURL()
|
||||
|
||||
locArgs, err := t.getLocationArguments(ctx, argMaps, tool.Operation.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commonParams := t.plugin.Manifest.CommonParams
|
||||
|
||||
reqURL, err := locArgs.buildHTTPRequestURL(ctx, rawURL, commonParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyArgs := map[string]any{}
|
||||
for k, v := range argMaps {
|
||||
if _, ok := locArgs.header[k]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := locArgs.path[k]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := locArgs.query[k]; ok {
|
||||
continue
|
||||
}
|
||||
bodyArgs[k] = v
|
||||
}
|
||||
|
||||
commonBody := commonParams[model.ParamInBody]
|
||||
bodyBytes, contentType, err := t.buildRequestBody(ctx, tool.Operation, bodyArgs, commonBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commonHeader := commonParams[model.ParamInHeader]
|
||||
header, err := locArgs.buildHTTPRequestHeader(ctx, commonHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logId, _ := ctx.Value(consts.CtxLogIDKey).(string)
|
||||
header.Set("X-Tt-Logid", logId)
|
||||
header.Set("X-Aiplugin-Connector-Identifier", t.userID)
|
||||
header.Set("X-AIPlugin-Bot-ID", conv.Int64ToStr(t.agentID))
|
||||
header.Set("X-AIPlugin-Conversation-ID", conv.Int64ToStr(t.conversationID))
|
||||
|
||||
httpReq.Header = header
|
||||
|
||||
if len(bodyBytes) > 0 {
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) prepareArguments(_ context.Context, argumentsInJson string) (map[string]any, error) {
|
||||
args := map[string]any{}
|
||||
|
||||
decoder := sonic.ConfigDefault.NewDecoder(bytes.NewBufferString(argumentsInJson))
|
||||
decoder.UseNumber()
|
||||
|
||||
// Suppose the output of the large model is of type object
|
||||
input := map[string]any{}
|
||||
err := decoder.Decode(&input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[prepareArguments] unmarshal into map failed, input=%s, err=%v",
|
||||
argumentsInJson, err)
|
||||
}
|
||||
|
||||
for k, v := range input {
|
||||
args[k] = v
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) getLocationArguments(ctx context.Context, args map[string]any, paramRefs []*openapi3.ParameterRef) (*locationArguments, error) {
|
||||
headerArgs := map[string]valueWithSchema{}
|
||||
pathArgs := map[string]valueWithSchema{}
|
||||
queryArgs := map[string]valueWithSchema{}
|
||||
|
||||
for _, paramRef := range paramRefs {
|
||||
paramVal := paramRef.Value
|
||||
if paramVal.In == openapi3.ParameterInCookie {
|
||||
continue
|
||||
}
|
||||
|
||||
scVal := paramVal.Schema.Value
|
||||
typ := scVal.Type
|
||||
if typ == openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("the type of '%s' parameter '%s' cannot be 'object'", paramVal.In, paramVal.Name)
|
||||
}
|
||||
|
||||
argValue, ok := args[paramVal.Name]
|
||||
if !ok {
|
||||
var err error
|
||||
argValue, err = t.getDefaultValue(ctx, scVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if argValue == nil {
|
||||
if !paramVal.Required {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("the '%s' parameter '%s' is required", paramVal.In, paramVal.Name)
|
||||
}
|
||||
}
|
||||
|
||||
v := valueWithSchema{
|
||||
argValue: argValue,
|
||||
paramSchema: paramVal,
|
||||
}
|
||||
|
||||
switch paramVal.In {
|
||||
case openapi3.ParameterInQuery:
|
||||
queryArgs[paramVal.Name] = v
|
||||
case openapi3.ParameterInHeader:
|
||||
headerArgs[paramVal.Name] = v
|
||||
case openapi3.ParameterInPath:
|
||||
pathArgs[paramVal.Name] = v
|
||||
}
|
||||
}
|
||||
|
||||
locArgs := &locationArguments{
|
||||
header: headerArgs,
|
||||
path: pathArgs,
|
||||
query: queryArgs,
|
||||
}
|
||||
|
||||
return locArgs, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) convertURItoURL(ctx context.Context, arg any, scVal *openapi3.Schema) (newArg any, err error) {
|
||||
if t.execScene != model.ExecSceneOfToolDebug {
|
||||
return arg, nil
|
||||
}
|
||||
if scVal.Type != openapi3.TypeString {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
at := scVal.Extensions[model.APISchemaExtendAssistType]
|
||||
if at == nil {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
_at, ok := at.(string)
|
||||
if !ok {
|
||||
return arg, nil
|
||||
}
|
||||
if !model.IsValidAPIAssistType(model.APIFileAssistType(_at)) {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
uri, ok := arg.(string)
|
||||
if !ok {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://") {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
newArg, err = t.svc.oss.GetObjectUrl(ctx, uri)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetObjectUrl failed, uri=%s", uri)
|
||||
}
|
||||
|
||||
return newArg, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) getDefaultValue(ctx context.Context, scVal *openapi3.Schema) (any, error) {
|
||||
vn, exist := scVal.Extensions[model.APISchemaExtendVariableRef]
|
||||
if !exist {
|
||||
return scVal.Default, nil
|
||||
}
|
||||
|
||||
vnStr, ok := vn.(string)
|
||||
if !ok {
|
||||
logs.CtxErrorf(ctx, "invalid variable_ref type '%T'", vn)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
variableVal, err := t.getVariableValue(ctx, vnStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return variableVal, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) getVariableValue(ctx context.Context, keyword string) (any, error) {
|
||||
info := t.projectInfo
|
||||
if info == nil {
|
||||
return nil, fmt.Errorf("project info is nil")
|
||||
}
|
||||
|
||||
meta := &variables.UserVariableMeta{
|
||||
BizType: project_memory.VariableConnector_Bot,
|
||||
BizID: strconv.FormatInt(info.ProjectID, 10),
|
||||
Version: ptr.FromOrDefault(info.ProjectVersion, ""),
|
||||
ConnectorUID: t.userID,
|
||||
ConnectorID: info.ConnectorID,
|
||||
}
|
||||
vals, err := crossvariables.DefaultSVC().GetVariableInstance(ctx, meta, []string{keyword})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(vals) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return vals[0].Value, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) injectAuthInfo(_ context.Context, httpReq *http.Request) (errMsg string, error error) {
|
||||
authInfo := t.plugin.GetAuthInfo()
|
||||
if authInfo.Type == model.AuthzTypeOfNone {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if authInfo.Type == model.AuthzTypeOfService {
|
||||
return t.injectServiceAPIToken(httpReq.Context(), httpReq, authInfo)
|
||||
}
|
||||
|
||||
if authInfo.Type == model.AuthzTypeOfOAuth {
|
||||
return t.injectOAuthAccessToken(httpReq.Context(), httpReq, authInfo)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) injectServiceAPIToken(ctx context.Context, httpReq *http.Request, authInfo *model.AuthV2) (errMsg string, err error) {
|
||||
if authInfo.SubType == model.AuthzSubTypeOfServiceAPIToken {
|
||||
authOfAPIToken := authInfo.AuthOfAPIToken
|
||||
if authOfAPIToken == nil {
|
||||
return "", fmt.Errorf("auth of api token is nil")
|
||||
}
|
||||
|
||||
loc := strings.ToLower(string(authOfAPIToken.Location))
|
||||
if loc == openapi3.ParameterInQuery {
|
||||
query := httpReq.URL.Query()
|
||||
if query.Get(authOfAPIToken.Key) == "" {
|
||||
query.Set(authOfAPIToken.Key, authOfAPIToken.ServiceToken)
|
||||
httpReq.URL.RawQuery = query.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
if loc == openapi3.ParameterInHeader {
|
||||
if httpReq.Header.Get(authOfAPIToken.Key) == "" {
|
||||
httpReq.Header.Set(authOfAPIToken.Key, authOfAPIToken.ServiceToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) injectOAuthAccessToken(ctx context.Context, httpReq *http.Request, authInfo *model.AuthV2) (errMsg string, err error) {
|
||||
authMode := model.ToolAuthModeOfRequired
|
||||
if tmp, ok := t.tool.Operation.Extensions[model.APISchemaExtendAuthMode].(string); ok {
|
||||
authMode = model.ToolAuthMode(tmp)
|
||||
}
|
||||
|
||||
if authMode == model.ToolAuthModeOfDisabled {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
|
||||
if authInfo.SubType == model.AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
i := &entity.AuthorizationCodeInfo{
|
||||
Meta: &entity.AuthorizationCodeMeta{
|
||||
UserID: t.userID,
|
||||
PluginID: t.plugin.ID,
|
||||
IsDraft: t.execScene == model.ExecSceneOfToolDebug,
|
||||
},
|
||||
Config: authInfo.AuthOfOAuthAuthorizationCode,
|
||||
}
|
||||
|
||||
accessToken, err = t.svc.GetAccessToken(ctx, &entity.OAuthInfo{
|
||||
OAuthMode: authInfo.SubType,
|
||||
AuthorizationCode: i,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if accessToken == "" && authMode != model.ToolAuthModeOfSupported {
|
||||
errMsg = authCodeInvalidTokenErrMsg[i18n.GetLocale(ctx)]
|
||||
if errMsg == "" {
|
||||
errMsg = authCodeInvalidTokenErrMsg[i18n.LocaleEN]
|
||||
}
|
||||
authURL, err := genAuthURL(i)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
errMsg = fmt.Sprintf(errMsg, t.plugin.Manifest.NameForHuman, authURL)
|
||||
|
||||
return errMsg, nil
|
||||
}
|
||||
}
|
||||
|
||||
if accessToken != "" {
|
||||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var authCodeInvalidTokenErrMsg = map[i18n.Locale]string{
|
||||
i18n.LocaleZH: "%s 插件需要授权使用。授权后即代表你同意与扣子中你所选择的 AI 模型分享数据。请[点击这里](%s)进行授权。",
|
||||
i18n.LocaleEN: "The '%s' plugin requires authorization. By authorizing, you agree to share data with the AI model you selected in Coze. Please [click here](%s) to authorize.",
|
||||
}
|
||||
|
||||
func (t *toolExecutor) processResponse(ctx context.Context, rawResp string) (trimmedResp string, err error) {
|
||||
responses := t.tool.Operation.Responses
|
||||
if len(responses) == 0 {
|
||||
@ -1333,219 +900,3 @@ func (t *toolExecutor) disabledParam(schemaVal *openapi3.Schema) bool {
|
||||
}
|
||||
return globalDisable || localDisable
|
||||
}
|
||||
|
||||
type locationArguments struct {
|
||||
header map[string]valueWithSchema
|
||||
path map[string]valueWithSchema
|
||||
query map[string]valueWithSchema
|
||||
}
|
||||
|
||||
type valueWithSchema struct {
|
||||
argValue any
|
||||
paramSchema *openapi3.Parameter
|
||||
}
|
||||
|
||||
func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string,
|
||||
commonParams map[model.HTTPParamLocation][]*common.CommonParamSchema) (reqURL *url.URL, err error) {
|
||||
|
||||
if len(l.path) > 0 {
|
||||
for k, v := range l.path {
|
||||
vStr, err := encoder.EncodeParameter(v.paramSchema, v.argValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawURL = strings.ReplaceAll(rawURL, "{"+k+"}", vStr)
|
||||
}
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
if len(l.query) > 0 {
|
||||
for k, val := range l.query {
|
||||
switch v := val.argValue.(type) {
|
||||
case []any:
|
||||
for _, _v := range v {
|
||||
query.Add(k, encoder.MustString(_v))
|
||||
}
|
||||
default:
|
||||
query.Add(k, encoder.MustString(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
commonQuery := commonParams[model.ParamInQuery]
|
||||
for _, v := range commonQuery {
|
||||
if _, ok := l.query[v.Name]; ok {
|
||||
continue
|
||||
}
|
||||
query.Add(v.Name, v.Value)
|
||||
}
|
||||
|
||||
encodeQuery := query.Encode()
|
||||
|
||||
reqURL, err = url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(reqURL.RawQuery) > 0 && len(encodeQuery) > 0 {
|
||||
reqURL.RawQuery += "&" + encodeQuery
|
||||
} else if len(encodeQuery) > 0 {
|
||||
reqURL.RawQuery = encodeQuery
|
||||
}
|
||||
|
||||
return reqURL, nil
|
||||
}
|
||||
|
||||
func (l *locationArguments) buildHTTPRequestHeader(_ context.Context, commonHeaders []*common.CommonParamSchema) (http.Header, error) {
|
||||
header := http.Header{}
|
||||
if len(l.header) > 0 {
|
||||
for k, v := range l.header {
|
||||
switch vv := v.argValue.(type) {
|
||||
case []any:
|
||||
for _, _v := range vv {
|
||||
header.Add(k, encoder.MustString(_v))
|
||||
}
|
||||
default:
|
||||
header.Add(k, encoder.MustString(vv))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, h := range commonHeaders {
|
||||
if header.Get(h.Name) != "" {
|
||||
continue
|
||||
}
|
||||
header.Add(h.Name, h.Value)
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any,
|
||||
commonBody []*common.CommonParamSchema) (body []byte, contentType string, err error) {
|
||||
|
||||
var bodyMap map[string]any
|
||||
|
||||
contentType, bodySchema := t.getReqBodySchema(op)
|
||||
if bodySchema != nil && len(bodySchema.Value.Properties) > 0 {
|
||||
bodyMap, err = t.injectRequestBodyDefaultValue(ctx, bodySchema.Value, bodyArgs)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
for paramName, prop := range bodySchema.Value.Properties {
|
||||
value, ok := bodyMap[paramName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
_value, err := encoder.TryCorrectValueType(paramName, prop, value)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
bodyMap[paramName] = _value
|
||||
}
|
||||
|
||||
body, err = encoder.EncodeBodyWithContentType(contentType, bodyMap)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
commonBody_ := make([]*common.CommonParamSchema, 0, len(commonBody))
|
||||
for _, v := range commonBody {
|
||||
if _, ok := bodyMap[v.Name]; ok {
|
||||
continue
|
||||
}
|
||||
commonBody_ = append(commonBody_, v)
|
||||
}
|
||||
|
||||
for _, v := range commonBody_ {
|
||||
body, err = sjson.SetRawBytes(body, v.Name, []byte(v.Value))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("[buildRequestBody] SetRawBytes failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return body, contentType, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) injectRequestBodyDefaultValue(ctx context.Context, sc *openapi3.Schema, vals map[string]any) (newVals map[string]any, err error) {
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
newVals = make(map[string]any, len(sc.Properties))
|
||||
|
||||
for paramName, prop := range sc.Properties {
|
||||
paramSchema := prop.Value
|
||||
if paramSchema.Type == openapi3.TypeObject {
|
||||
val := vals[paramName]
|
||||
if val == nil {
|
||||
val = map[string]any{}
|
||||
}
|
||||
|
||||
mapVal, ok := val.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[injectRequestBodyDefaultValue] parameter '%s' is not object", paramName)
|
||||
}
|
||||
|
||||
newMapVal, err := t.injectRequestBodyDefaultValue(ctx, paramSchema, mapVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(newMapVal) > 0 {
|
||||
newVals[paramName] = newMapVal
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if val := vals[paramName]; val != nil {
|
||||
newVals[paramName] = val
|
||||
continue
|
||||
}
|
||||
|
||||
defaultVal, err := t.getDefaultValue(ctx, paramSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if defaultVal == nil {
|
||||
if !required[paramName] {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("[injectRequestBodyDefaultValue] parameter '%s' is required", paramName)
|
||||
}
|
||||
|
||||
newVals[paramName] = defaultVal
|
||||
}
|
||||
|
||||
return newVals, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) getReqBodySchema(op *model.Openapi3Operation) (string, *openapi3.SchemaRef) {
|
||||
if op.RequestBody == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var contentTypeArray = []string{
|
||||
model.MediaTypeJson,
|
||||
model.MediaTypeProblemJson,
|
||||
model.MediaTypeFormURLEncoded,
|
||||
model.MediaTypeXYaml,
|
||||
model.MediaTypeYaml,
|
||||
}
|
||||
|
||||
for _, ct := range contentTypeArray {
|
||||
mType := op.RequestBody.Value.Content[ct]
|
||||
if mType == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return ct, mType.Schema
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@ -27,9 +27,8 @@ import (
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
)
|
||||
|
||||
func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnDefault(t *testing.T) {
|
||||
|
||||
@ -19,7 +19,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
|
||||
@ -44,7 +43,6 @@ func NewService(components *Components) PluginService {
|
||||
pluginRepo: components.PluginRepo,
|
||||
toolRepo: components.ToolRepo,
|
||||
oauthRepo: components.OAuthRepo,
|
||||
httpCli: resty.New(),
|
||||
}
|
||||
|
||||
initOnce.Do(func() {
|
||||
@ -63,5 +61,4 @@ type pluginServiceImpl struct {
|
||||
pluginRepo repository.PluginRepository
|
||||
toolRepo repository.ToolRepository
|
||||
oauthRepo repository.OAuthRepository
|
||||
httpCli *resty.Client
|
||||
}
|
||||
|
||||
25
backend/domain/plugin/service/tool/invocation.go
Normal file
25
backend/domain/plugin/service/tool/invocation.go
Normal file
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Invocation interface {
|
||||
Do(ctx context.Context, args *InvocationArgs) (request string, resp string, err error)
|
||||
}
|
||||
510
backend/domain/plugin/service/tool/invocation_args.go
Normal file
510
backend/domain/plugin/service/tool/invocation_args.go
Normal file
@ -0,0 +1,510 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package tool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
|
||||
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type groupedKeys struct {
|
||||
HeaderKeys map[string]*openapi3.Parameter
|
||||
PathKeys map[string]*openapi3.Parameter
|
||||
QueryKeys map[string]*openapi3.Parameter
|
||||
CookieKeys map[string]*openapi3.Parameter
|
||||
BodyKeys map[string]*openapi3.Schema
|
||||
FileKeys map[string]bool
|
||||
}
|
||||
|
||||
type OAuthInfo struct {
|
||||
AccessToken string
|
||||
AuthURL string
|
||||
}
|
||||
type AuthInfo struct {
|
||||
OAuth *OAuthInfo
|
||||
MetaInfo *model.AuthV2
|
||||
}
|
||||
|
||||
type InvocationArgs struct {
|
||||
groupedKeySchema groupedKeys
|
||||
Tool *entity.ToolInfo
|
||||
AuthInfo *AuthInfo
|
||||
PluginManifest *model.PluginManifest
|
||||
ServerURL string
|
||||
|
||||
UserID string
|
||||
ProjectInfo *entity.ProjectInfo
|
||||
AccessToken string
|
||||
AuthURL string
|
||||
|
||||
Header map[string]any
|
||||
Path map[string]any
|
||||
Query map[string]any
|
||||
Cookie map[string]any
|
||||
Body map[string]any
|
||||
}
|
||||
|
||||
type InvocationArgsBuilder struct {
|
||||
ArgsInJson string
|
||||
ProjectInfo *entity.ProjectInfo
|
||||
UserID string
|
||||
AccessToken string
|
||||
AuthURL string
|
||||
Plugin *entity.PluginInfo
|
||||
Tool *entity.ToolInfo
|
||||
AuthInfo *AuthInfo
|
||||
PluginManifest *model.PluginManifest
|
||||
ServerURL string
|
||||
}
|
||||
|
||||
func NewInvocationArgs(ctx context.Context, builder *InvocationArgsBuilder) (*InvocationArgs, error) {
|
||||
// json to map[string]any
|
||||
requestArgs, err := json2Map(builder.ArgsInJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
args := &InvocationArgs{
|
||||
UserID: builder.UserID,
|
||||
ProjectInfo: builder.ProjectInfo,
|
||||
AccessToken: builder.AccessToken,
|
||||
AuthURL: builder.AuthURL,
|
||||
Tool: builder.Tool,
|
||||
AuthInfo: builder.AuthInfo,
|
||||
PluginManifest: builder.PluginManifest,
|
||||
ServerURL: builder.ServerURL,
|
||||
}
|
||||
|
||||
// groupedKeySchema has all key
|
||||
// groupedKey = requestArgs.key + commonParams.key + defaultValues.key
|
||||
args.groupedKeySchema = groupedKeysByLocation(ctx, args.Tool.Operation)
|
||||
// group request args by location
|
||||
args.groupedRequestArgs(ctx, requestArgs)
|
||||
// add common params to each location
|
||||
args.setCommonParams(ctx, args.PluginManifest.CommonParams)
|
||||
// add default values if not exist
|
||||
err = args.setDefaultValues(ctx, builder.ProjectInfo, builder.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func json2Map(argumentsInJson string) (map[string]any, error) {
|
||||
decoder := sonic.ConfigDefault.NewDecoder(bytes.NewBufferString(argumentsInJson))
|
||||
decoder.UseNumber()
|
||||
|
||||
// Suppose the output of the large model is of type object
|
||||
args := map[string]any{}
|
||||
err := decoder.Decode(&args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal into map failed, input=%s, err=%v", argumentsInJson, err)
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func groupedKeysByLocation(ctx context.Context, apiSchema *model.Openapi3Operation) groupedKeys {
|
||||
headerArgs := map[string]*openapi3.Parameter{}
|
||||
pathArgs := map[string]*openapi3.Parameter{}
|
||||
queryArgs := map[string]*openapi3.Parameter{}
|
||||
cookieArgs := map[string]*openapi3.Parameter{}
|
||||
bodyArgs := map[string]*openapi3.Schema{}
|
||||
fileKey := map[string]bool{}
|
||||
|
||||
paramRefs := apiSchema.Parameters
|
||||
for _, paramRef := range paramRefs {
|
||||
valueSchema := paramRef.Value
|
||||
|
||||
if isFileSchema(valueSchema.Schema.Value) {
|
||||
fileKey[valueSchema.Name] = true
|
||||
}
|
||||
|
||||
switch valueSchema.In {
|
||||
case openapi3.ParameterInQuery:
|
||||
queryArgs[valueSchema.Name] = valueSchema
|
||||
case openapi3.ParameterInHeader:
|
||||
headerArgs[valueSchema.Name] = valueSchema
|
||||
case openapi3.ParameterInPath:
|
||||
pathArgs[valueSchema.Name] = valueSchema
|
||||
case openapi3.ParameterInCookie:
|
||||
cookieArgs[valueSchema.Name] = valueSchema
|
||||
default:
|
||||
logs.CtxWarnf(ctx, "[groupedKeysByLocation] unsupported parameter location '%s' in api schema, name=%s", valueSchema.In, valueSchema.Name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
_, bodySchema := apiSchema.GetReqBodySchema()
|
||||
|
||||
if bodySchema != nil && bodySchema.Value != nil {
|
||||
for paramName, paramSchema := range bodySchema.Value.Properties {
|
||||
if isFileSchema(paramSchema.Value) {
|
||||
fileKey[paramName] = true
|
||||
}
|
||||
|
||||
bodyArgs[paramName] = paramSchema.Value
|
||||
}
|
||||
}
|
||||
|
||||
return groupedKeys{
|
||||
HeaderKeys: headerArgs,
|
||||
PathKeys: pathArgs,
|
||||
QueryKeys: queryArgs,
|
||||
CookieKeys: cookieArgs,
|
||||
BodyKeys: bodyArgs,
|
||||
FileKeys: fileKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *InvocationArgs) groupedRequestArgs(ctx context.Context, args map[string]any) {
|
||||
groupedKeySchema := i.groupedKeySchema
|
||||
headerArgs := map[string]any{}
|
||||
pathArgs := map[string]any{}
|
||||
queryArgs := map[string]any{}
|
||||
cookieArgs := map[string]any{}
|
||||
bodyArgs := map[string]any{}
|
||||
|
||||
for k, v := range args {
|
||||
if _, ok := groupedKeySchema.HeaderKeys[k]; ok {
|
||||
headerArgs[k] = v
|
||||
} else if _, ok := groupedKeySchema.PathKeys[k]; ok {
|
||||
pathArgs[k] = v
|
||||
} else if _, ok := groupedKeySchema.QueryKeys[k]; ok {
|
||||
queryArgs[k] = v
|
||||
} else if _, ok := groupedKeySchema.CookieKeys[k]; ok {
|
||||
cookieArgs[k] = v
|
||||
} else if _, ok := groupedKeySchema.BodyKeys[k]; ok {
|
||||
bodyArgs[k] = v
|
||||
} else {
|
||||
logs.CtxWarnf(ctx, "[groupedRequestArgs] unsupported parameter key '%s' in api schema", k)
|
||||
}
|
||||
}
|
||||
|
||||
i.Header = headerArgs
|
||||
i.Path = pathArgs
|
||||
i.Query = queryArgs
|
||||
i.Cookie = cookieArgs
|
||||
i.Body = bodyArgs
|
||||
}
|
||||
|
||||
func (i *InvocationArgs) setCommonParams(ctx context.Context, commonParams map[model.HTTPParamLocation][]*api.CommonParamSchema) {
|
||||
for location, params := range commonParams {
|
||||
for _, param := range params {
|
||||
if param.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var dic map[string]any
|
||||
switch location {
|
||||
case model.ParamInHeader:
|
||||
dic = i.Header
|
||||
case model.ParamInPath:
|
||||
dic = i.Path
|
||||
case model.ParamInQuery:
|
||||
dic = i.Query
|
||||
case model.ParamInBody:
|
||||
dic = i.Body
|
||||
default:
|
||||
logs.CtxWarnf(ctx, "unsupported common parameter location '%s' in api schema, name=%s", location, param.Name)
|
||||
}
|
||||
|
||||
_, ok := dic[param.Name]
|
||||
if !ok {
|
||||
dic[param.Name] = param.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *InvocationArgs) setDefaultValues(ctx context.Context, projectInfo *entity.ProjectInfo, userID string) (err error) {
|
||||
groupedKeysSchema := i.groupedKeySchema
|
||||
|
||||
i.Header, err = setParameterDefaultValues(ctx, i.Header, groupedKeysSchema.HeaderKeys, projectInfo, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.Path, err = setParameterDefaultValues(ctx, i.Path, groupedKeysSchema.PathKeys, projectInfo, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.Query, err = setParameterDefaultValues(ctx, i.Query, groupedKeysSchema.QueryKeys, projectInfo, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.Cookie, err = setParameterDefaultValues(ctx, i.Cookie, groupedKeysSchema.CookieKeys, projectInfo, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, bodySchema := i.Tool.Operation.GetReqBodySchema()
|
||||
i.Body, err = setBodyDefaultValues(ctx, i.Body, bodySchema.Value, projectInfo, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setParameterDefaultValues(ctx context.Context, dic map[string]any, paramSchema map[string]*openapi3.Parameter, projectInfo *entity.ProjectInfo, userID string) (map[string]any, error) {
|
||||
for key, valueSchema := range paramSchema {
|
||||
if valueSchema.Schema == nil || valueSchema.Schema.Value == nil {
|
||||
logs.CtxWarnf(ctx, "[setParameterDefaultValues] parameter '%s' schema is nil", key)
|
||||
continue
|
||||
}
|
||||
|
||||
if valueSchema.Schema.Value.Type == openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("the type of '%s' parameter '%s' cannot be 'object'", valueSchema.In, key)
|
||||
}
|
||||
|
||||
if _, ok := dic[key]; !ok {
|
||||
defaultVal, err := getDefaultValue(ctx, valueSchema.Schema.Value, projectInfo, userID)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "get default value failed, key=%s, err=%v", key, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if valueSchema.Required && defaultVal == nil {
|
||||
return nil, fmt.Errorf("the '%s' parameter '%s' is required", valueSchema.In, key)
|
||||
}
|
||||
|
||||
dic[key] = defaultVal
|
||||
}
|
||||
}
|
||||
|
||||
return dic, nil
|
||||
}
|
||||
|
||||
func setBodyDefaultValues(ctx context.Context, dic map[string]any, sc *openapi3.Schema, projectInfo *entity.ProjectInfo, userID string) (map[string]any, error) {
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
newVals := make(map[string]any, len(sc.Properties))
|
||||
|
||||
for paramName, prop := range sc.Properties {
|
||||
paramSchema := prop.Value
|
||||
if paramSchema.Type == openapi3.TypeObject {
|
||||
val := dic[paramName]
|
||||
if val == nil {
|
||||
val = map[string]any{}
|
||||
}
|
||||
|
||||
mapVal, ok := val.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[injectRequestBodyDefaultValue] parameter '%s' is not object", paramName)
|
||||
}
|
||||
|
||||
newMapVal, err := setBodyDefaultValues(ctx, mapVal, paramSchema, projectInfo, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(newMapVal) > 0 {
|
||||
newVals[paramName] = newMapVal
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if val := dic[paramName]; val != nil {
|
||||
newVals[paramName] = val
|
||||
continue
|
||||
}
|
||||
|
||||
defaultVal, err := getDefaultValue(ctx, paramSchema, projectInfo, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if defaultVal == nil {
|
||||
if !required[paramName] {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("[setBodyDefaultValues] parameter '%s' is required", paramName)
|
||||
}
|
||||
|
||||
newVals[paramName] = defaultVal
|
||||
}
|
||||
|
||||
return newVals, nil
|
||||
}
|
||||
|
||||
func getDefaultValue(ctx context.Context, schema *openapi3.Schema, info *entity.ProjectInfo, userID string) (any, error) {
|
||||
vn, exist := schema.Extensions[model.APISchemaExtendVariableRef]
|
||||
if !exist {
|
||||
return schema.Default, nil
|
||||
}
|
||||
|
||||
keyword, ok := vn.(string)
|
||||
if !ok {
|
||||
logs.CtxErrorf(ctx, "invalid variable_ref type '%T'", vn)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
return nil, fmt.Errorf("project info is nil")
|
||||
}
|
||||
|
||||
meta := &variables.UserVariableMeta{
|
||||
BizType: project_memory.VariableConnector(info.ProjectType),
|
||||
BizID: strconv.FormatInt(info.ProjectID, 10),
|
||||
Version: ptr.FromOrDefault(info.ProjectVersion, ""),
|
||||
ConnectorUID: userID,
|
||||
ConnectorID: info.ConnectorID,
|
||||
}
|
||||
|
||||
vals, err := crossvariables.DefaultSVC().GetVariableInstance(ctx, meta, []string{keyword})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(vals) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return vals[0].Value, nil
|
||||
|
||||
}
|
||||
|
||||
func (i *InvocationArgs) AssembleFileURIToURL(ctx context.Context, oss storage.Storage) error {
|
||||
allFileKeys := i.groupedKeySchema.FileKeys
|
||||
for key := range allFileKeys {
|
||||
dic, ok := i.lookupArgGroup(key)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
uriObj, ok := dic[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
var uris []string
|
||||
if str, ok := uriObj.(string); ok {
|
||||
url, err := convertURItoURL(ctx, str, oss)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dic[key] = url
|
||||
|
||||
} else if arr, ok := uriObj.([]any); ok {
|
||||
for _, item := range arr {
|
||||
if str, ok := item.(string); ok {
|
||||
url, err := convertURItoURL(ctx, str, oss)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uris = append(uris, url)
|
||||
}
|
||||
}
|
||||
if len(uris) > 0 {
|
||||
dic[key] = uris
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *InvocationArgs) lookupArgGroup(key string) (map[string]any, bool) {
|
||||
if _, ok := i.Header[key]; ok {
|
||||
return i.Header, ok
|
||||
}
|
||||
if _, ok := i.Path[key]; ok {
|
||||
return i.Path, ok
|
||||
}
|
||||
|
||||
if _, ok := i.Query[key]; ok {
|
||||
return i.Query, ok
|
||||
}
|
||||
|
||||
if _, ok := i.Cookie[key]; ok {
|
||||
return i.Cookie, ok
|
||||
}
|
||||
|
||||
if _, ok := i.Body[key]; ok {
|
||||
return i.Body, ok
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func convertURItoURL(ctx context.Context, uri string, oss storage.Storage) (newArg string, err error) {
|
||||
if uri == "" {
|
||||
return "", fmt.Errorf("uri is empty")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://") {
|
||||
return uri, nil
|
||||
}
|
||||
|
||||
newArg, err = oss.GetObjectUrl(ctx, uri)
|
||||
if err != nil {
|
||||
return "", errorx.Wrapf(err, "GetObjectUrl failed, uri=%s", uri)
|
||||
}
|
||||
|
||||
return newArg, nil
|
||||
}
|
||||
|
||||
func isFileSchema(valueSchema *openapi3.Schema) bool {
|
||||
if valueSchema.Type != openapi3.TypeString {
|
||||
// file value must be string
|
||||
return false
|
||||
}
|
||||
|
||||
// file schema x-assist-type must not nil
|
||||
assistTypeObj := valueSchema.Extensions[model.APISchemaExtendAssistType]
|
||||
if assistTypeObj == nil {
|
||||
// it is not a file value
|
||||
return false
|
||||
}
|
||||
|
||||
assistType, ok := assistTypeObj.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if !model.IsValidAPIAssistType(model.APIFileAssistType(assistType)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
49
backend/domain/plugin/service/tool/invocation_custom_call.go
Normal file
49
backend/domain/plugin/service/tool/invocation_custom_call.go
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var customToolMap = make(map[string]Invocation)
|
||||
|
||||
func RegisterCustomTool(toolID string, t Invocation) error {
|
||||
if _, ok := customToolMap[toolID]; ok {
|
||||
return fmt.Errorf("custom tool path %s already registered", toolID)
|
||||
}
|
||||
|
||||
customToolMap[toolID] = t
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvokableRun(ctx context.Context, argumentsInJSON string, opts ...Option) (string, error)
|
||||
type customCallImpl struct{}
|
||||
|
||||
func NewCustomCallImpl() Invocation {
|
||||
return &customCallImpl{}
|
||||
}
|
||||
|
||||
func (c *customCallImpl) Do(ctx context.Context, args *InvocationArgs) (request string, resp string, err error) {
|
||||
if t, ok := customToolMap[fmt.Sprintf("%d", args.Tool.ID)]; ok {
|
||||
return t.Do(ctx, args)
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("custom tool not found")
|
||||
}
|
||||
369
backend/domain/plugin/service/tool/invocation_http.go
Normal file
369
backend/domain/plugin/service/tool/invocation_http.go
Normal file
@ -0,0 +1,369 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package tool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/encoder"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type httpCallImpl struct {
|
||||
ConversationID int64
|
||||
}
|
||||
|
||||
var defaultHttpCli *resty.Client = resty.New()
|
||||
|
||||
func NewHttpCallImpl(ConversationID int64) Invocation {
|
||||
return &httpCallImpl{
|
||||
ConversationID: ConversationID,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) Do(ctx context.Context, args *InvocationArgs) (request string, resp string, err error) {
|
||||
httpReq, err := h.buildHTTPRequest(ctx, args)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
errMsg, err := h.injectAuthInfo(ctx, httpReq, args)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if errMsg != "" {
|
||||
event := &model.ToolInterruptEvent{
|
||||
Event: model.InterruptEventTypeOfToolNeedOAuth,
|
||||
ToolNeedOAuth: &model.ToolNeedOAuthInterruptEvent{
|
||||
Message: errMsg,
|
||||
},
|
||||
}
|
||||
|
||||
return "", "", compose.NewInterruptAndRerunErr(event)
|
||||
}
|
||||
|
||||
var reqBodyBytes []byte
|
||||
if httpReq.GetBody != nil {
|
||||
reqBody, err := httpReq.GetBody()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer reqBody.Close()
|
||||
|
||||
reqBodyBytes, err = io.ReadAll(reqBody)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
requestStr, err := genRequestString(httpReq, reqBodyBytes)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
restyReq := defaultHttpCli.NewRequest()
|
||||
restyReq.Header = httpReq.Header
|
||||
restyReq.Method = httpReq.Method
|
||||
restyReq.URL = httpReq.URL.String()
|
||||
if reqBodyBytes != nil {
|
||||
restyReq.SetBody(reqBodyBytes)
|
||||
}
|
||||
restyReq.SetContext(ctx)
|
||||
|
||||
logs.CtxDebugf(ctx, "[execute] url=%s, header=%s, method=%s, body=%s",
|
||||
restyReq.URL, restyReq.Header, restyReq.Method, restyReq.Body)
|
||||
|
||||
httpResp, err := restyReq.Send()
|
||||
if err != nil {
|
||||
return "", "", errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KVf(errno.PluginMsgKey, "http request failed, err=%s", err))
|
||||
}
|
||||
|
||||
logs.CtxDebugf(ctx, "[execute] status=%s, response=%s", httpResp.Status(), httpResp.String())
|
||||
|
||||
if httpResp.StatusCode() != http.StatusOK {
|
||||
return "", "", errorx.New(errno.ErrPluginExecuteToolFailed,
|
||||
errorx.KVf(errno.PluginMsgKey, "http request failed, status=%s\nresp=%s", httpResp.Status(), httpResp.String()))
|
||||
}
|
||||
|
||||
return requestStr, httpResp.String(), nil
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) buildHTTPRequest(ctx context.Context, args *InvocationArgs) (httpReq *http.Request, err error) {
|
||||
tool := args.Tool
|
||||
rawURL := args.ServerURL + tool.GetSubURL()
|
||||
|
||||
reqURL, err := h.buildHTTPRequestURL(ctx, rawURL, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyBytes, contentType, err := h.buildRequestBody(ctx, tool.Operation, args.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq.Header, err = h.buildHTTPRequestHeader(ctx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(bodyBytes) > 0 {
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) injectAuthInfo(ctx context.Context, httpReq *http.Request, args *InvocationArgs) (errMsg string, err error) {
|
||||
|
||||
if args.AuthInfo.MetaInfo.Type == model.AuthzTypeOfNone {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if args.AuthInfo.MetaInfo.Type == model.AuthzTypeOfService {
|
||||
return h.injectServiceAPIToken(ctx, httpReq, args.AuthInfo.MetaInfo)
|
||||
}
|
||||
|
||||
if args.AuthInfo.MetaInfo.Type == model.AuthzTypeOfOAuth {
|
||||
return h.injectOAuthAccessToken(ctx, httpReq, args)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func genRequestString(req *http.Request, body []byte) (string, error) {
|
||||
type Request struct {
|
||||
Path string `json:"path"`
|
||||
Header map[string]string `json:"header"`
|
||||
Query map[string]string `json:"query"`
|
||||
Body *[]byte `json:"body"`
|
||||
}
|
||||
|
||||
req_ := &Request{
|
||||
Path: req.URL.Path,
|
||||
Header: map[string]string{},
|
||||
Query: map[string]string{},
|
||||
}
|
||||
|
||||
if len(req.Header) > 0 {
|
||||
for k, v := range req.Header {
|
||||
req_.Header[k] = v[0]
|
||||
}
|
||||
}
|
||||
if len(req.URL.Query()) > 0 {
|
||||
for k, v := range req.URL.Query() {
|
||||
req_.Query[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
requestStr, err := sonic.MarshalString(req_)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err)
|
||||
}
|
||||
|
||||
if len(body) > 0 {
|
||||
requestStr, err = sjson.SetRaw(requestStr, "body", string(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return requestStr, nil
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) buildHTTPRequestURL(ctx context.Context, rawURL string, args *InvocationArgs) (reqURL *url.URL, err error) {
|
||||
if len(args.Path) > 0 {
|
||||
for k, v := range args.Path {
|
||||
p := args.groupedKeySchema.PathKeys[k]
|
||||
vStr, eErr := encoder.EncodeParameter(p, v)
|
||||
if eErr != nil {
|
||||
return nil, eErr
|
||||
}
|
||||
rawURL = strings.ReplaceAll(rawURL, "{"+k+"}", vStr)
|
||||
}
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
if len(args.Query) > 0 {
|
||||
for k, val := range args.Query {
|
||||
switch v := val.(type) {
|
||||
case []any:
|
||||
for _, _v := range v {
|
||||
query.Add(k, encoder.MustString(_v))
|
||||
}
|
||||
default:
|
||||
query.Add(k, encoder.MustString(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
encodeQuery := query.Encode()
|
||||
|
||||
reqURL, err = url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(reqURL.RawQuery) > 0 && len(encodeQuery) > 0 {
|
||||
reqURL.RawQuery += "&" + encodeQuery
|
||||
} else if len(encodeQuery) > 0 {
|
||||
reqURL.RawQuery = encodeQuery
|
||||
}
|
||||
|
||||
return reqURL, nil
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any) (body []byte, contentType string, err error) {
|
||||
contentType, bodySchema := op.GetReqBodySchema()
|
||||
if bodySchema != nil && len(bodySchema.Value.Properties) > 0 {
|
||||
for paramName, prop := range bodySchema.Value.Properties {
|
||||
value, ok := bodyArgs[paramName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
_value, eErr := encoder.TryCorrectValueType(paramName, prop, value)
|
||||
if eErr != nil {
|
||||
return nil, "", eErr
|
||||
}
|
||||
|
||||
bodyArgs[paramName] = _value
|
||||
}
|
||||
|
||||
body, err = encoder.EncodeBodyWithContentType(contentType, bodyArgs)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return body, contentType, nil
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) injectServiceAPIToken(ctx context.Context, httpReq *http.Request, authInfo *model.AuthV2) (errMsg string, err error) {
|
||||
if authInfo.SubType == model.AuthzSubTypeOfServiceAPIToken {
|
||||
authOfAPIToken := authInfo.AuthOfAPIToken
|
||||
if authOfAPIToken == nil {
|
||||
return "", fmt.Errorf("auth of api token is nil")
|
||||
}
|
||||
|
||||
loc := strings.ToLower(string(authOfAPIToken.Location))
|
||||
if loc == openapi3.ParameterInQuery {
|
||||
query := httpReq.URL.Query()
|
||||
if query.Get(authOfAPIToken.Key) == "" {
|
||||
query.Set(authOfAPIToken.Key, authOfAPIToken.ServiceToken)
|
||||
httpReq.URL.RawQuery = query.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
if loc == openapi3.ParameterInHeader {
|
||||
if httpReq.Header.Get(authOfAPIToken.Key) == "" {
|
||||
httpReq.Header.Set(authOfAPIToken.Key, authOfAPIToken.ServiceToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) injectOAuthAccessToken(ctx context.Context, httpReq *http.Request, args *InvocationArgs) (errMsg string, err error) {
|
||||
authMode := model.ToolAuthModeOfRequired
|
||||
if tmp, ok := args.Tool.Operation.Extensions[model.APISchemaExtendAuthMode].(string); ok {
|
||||
authMode = model.ToolAuthMode(tmp)
|
||||
}
|
||||
|
||||
if authMode == model.ToolAuthModeOfDisabled {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
accessToken := args.AccessToken
|
||||
authInfo := args.AuthInfo.MetaInfo
|
||||
|
||||
if authInfo.SubType == model.AuthzSubTypeOfOAuthAuthorizationCode &&
|
||||
accessToken == "" && authMode != model.ToolAuthModeOfSupported {
|
||||
errMsg = authCodeInvalidTokenErrMsg[i18n.GetLocale(ctx)]
|
||||
if errMsg == "" {
|
||||
errMsg = authCodeInvalidTokenErrMsg[i18n.LocaleEN]
|
||||
}
|
||||
|
||||
errMsg = fmt.Sprintf(errMsg, args.PluginManifest.NameForHuman, args.AuthURL)
|
||||
|
||||
return errMsg, nil
|
||||
}
|
||||
|
||||
if accessToken != "" {
|
||||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var authCodeInvalidTokenErrMsg = map[i18n.Locale]string{
|
||||
i18n.LocaleZH: "%s 插件需要授权使用。授权后即代表你同意与扣子中你所选择的 AI 模型分享数据。请[点击这里](%s)进行授权。",
|
||||
i18n.LocaleEN: "The '%s' plugin requires authorization. By authorizing, you agree to share data with the AI model you selected in Coze. Please [click here](%s) to authorize.",
|
||||
}
|
||||
|
||||
func (h *httpCallImpl) buildHTTPRequestHeader(ctx context.Context, args *InvocationArgs) (http.Header, error) {
|
||||
header := http.Header{}
|
||||
if len(args.Header) > 0 {
|
||||
for k, v := range args.Header {
|
||||
switch vv := v.(type) {
|
||||
case []any:
|
||||
for _, _v := range vv {
|
||||
header.Add(k, encoder.MustString(_v))
|
||||
}
|
||||
default:
|
||||
header.Add(k, encoder.MustString(vv))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logId, _ := ctx.Value(consts.CtxLogIDKey).(string)
|
||||
header.Set("X-Tt-Logid", logId)
|
||||
header.Set("X-Aiplugin-Connector-Identifier", args.UserID)
|
||||
if args.ProjectInfo != nil {
|
||||
header.Set("X-AIPlugin-Bot-ID", conv.Int64ToStr(args.ProjectInfo.ProjectID))
|
||||
}
|
||||
if h.ConversationID > 0 {
|
||||
header.Set("X-AIPlugin-Conversation-ID", conv.Int64ToStr(h.ConversationID))
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
33
backend/domain/plugin/service/tool/invocation_mcp.go
Normal file
33
backend/domain/plugin/service/tool/invocation_mcp.go
Normal file
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type mcpCallImpl struct{}
|
||||
|
||||
func NewMcpCallImpl() Invocation {
|
||||
return &mcpCallImpl{}
|
||||
}
|
||||
|
||||
func (m *mcpCallImpl) Do(ctx context.Context, args *InvocationArgs) (request string, resp string, err error) {
|
||||
// only for tool debug scene
|
||||
return "", "", errors.New("mcp call not implemented")
|
||||
}
|
||||
@ -19,11 +19,12 @@ package compose
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
@ -36,6 +35,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
|
||||
@ -310,7 +310,6 @@ require (
|
||||
github.com/mtibben/percent v0.2.1 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
|
||||
golang.org/x/term v0.32.0 // indirect
|
||||
|
||||
@ -19,10 +19,11 @@ package logs
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
var logger FullLogger = &defaultLogger{
|
||||
|
||||
Reference in New Issue
Block a user