Files
coze-studio/backend/domain/plugin/service/tool/invocation_http.go

370 lines
10 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package tool
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/getkin/kin-openapi/openapi3"
"github.com/go-resty/resty/v2"
"github.com/tidwall/sjson"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/encoder"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// httpCallImpl executes a plugin tool by sending an HTTP request.
// NewHttpCallImpl returns it as an Invocation.
type httpCallImpl struct {
// ConversationID, when greater than zero, is forwarded to the plugin server
// in the X-AIPlugin-Conversation-ID request header.
ConversationID int64
}
// defaultHttpCli is the package-wide resty client from which every outgoing
// tool request is created.
var defaultHttpCli = resty.New()
// NewHttpCallImpl returns an Invocation that runs tools over HTTP and tags
// outgoing requests with the given conversation ID (when it is positive).
func NewHttpCallImpl(ConversationID int64) Invocation {
	impl := new(httpCallImpl)
	impl.ConversationID = ConversationID
	return impl
}
// Do performs the HTTP call for the tool described by args.
//
// On success it returns the JSON description of the outgoing request (as
// produced by genRequestString) together with the raw response body.
//
// When injectAuthInfo reports a message, nothing is sent; an
// interrupt-and-rerun error carrying a ToolNeedOAuth event with that message
// is returned instead. Transport failures and any status other than 200 OK
// are reported as errno.ErrPluginExecuteToolFailed.
func (h *httpCallImpl) Do(ctx context.Context, args *InvocationArgs) (request string, resp string, err error) {
	stdReq, err := h.buildHTTPRequest(ctx, args)
	if err != nil {
		return "", "", err
	}

	authMsg, err := h.injectAuthInfo(ctx, stdReq, args)
	switch {
	case err != nil:
		return "", "", err
	case authMsg != "":
		// Authorization is missing: hand the prompt back to the caller
		// instead of sending the request.
		return "", "", compose.NewInterruptAndRerunErr(&model.ToolInterruptEvent{
			Event: model.InterruptEventTypeOfToolNeedOAuth,
			ToolNeedOAuth: &model.ToolNeedOAuthInterruptEvent{
				Message: authMsg,
			},
		})
	}

	// Read the body through GetBody so the bytes can be both described in the
	// request string and handed to resty.
	var payload []byte
	if stdReq.GetBody != nil {
		bodyReader, getErr := stdReq.GetBody()
		if getErr != nil {
			return "", "", getErr
		}
		defer bodyReader.Close()

		if payload, err = io.ReadAll(bodyReader); err != nil {
			return "", "", err
		}
	}

	request, err = genRequestString(stdReq, payload)
	if err != nil {
		return "", "", err
	}

	// Mirror the prepared net/http request onto a resty request.
	out := defaultHttpCli.NewRequest()
	out.Method = stdReq.Method
	out.URL = stdReq.URL.String()
	out.Header = stdReq.Header
	if payload != nil {
		out.SetBody(payload)
	}
	out.SetContext(ctx)

	logs.CtxDebugf(ctx, "[execute] url=%s, header=%s, method=%s, body=%s",
		out.URL, out.Header, out.Method, out.Body)

	result, err := out.Send()
	if err != nil {
		return "", "", errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KVf(errno.PluginMsgKey, "http request failed, err=%s", err))
	}

	logs.CtxDebugf(ctx, "[execute] status=%s, response=%s", result.Status(), result.String())

	if code := result.StatusCode(); code != http.StatusOK {
		return "", "", errorx.New(errno.ErrPluginExecuteToolFailed,
			errorx.KVf(errno.PluginMsgKey, "http request failed, status=%s\nresp=%s", result.Status(), result.String()))
	}

	return request, result.String(), nil
}
// buildHTTPRequest assembles the *http.Request for the tool in args: the URL
// comes from args.ServerURL plus the tool's sub URL (with path and query
// arguments applied), the body from args.Body, and the headers from
// buildHTTPRequestHeader. Content-Type is set only when the encoded body is
// non-empty.
func (h *httpCallImpl) buildHTTPRequest(ctx context.Context, args *InvocationArgs) (httpReq *http.Request, err error) {
	target := args.Tool

	endpoint, err := h.buildHTTPRequestURL(ctx, args.ServerURL+target.GetSubURL(), args)
	if err != nil {
		return nil, err
	}

	payload, mediaType, err := h.buildRequestBody(ctx, target.Operation, args.Body)
	if err != nil {
		return nil, err
	}

	httpReq, err = http.NewRequestWithContext(ctx, target.GetMethod(), endpoint.String(), bytes.NewBuffer(payload))
	if err != nil {
		return nil, err
	}

	headers, err := h.buildHTTPRequestHeader(ctx, args)
	if err != nil {
		return nil, err
	}
	httpReq.Header = headers

	if len(payload) > 0 {
		httpReq.Header.Set("Content-Type", mediaType)
	}

	return httpReq, nil
}
// injectAuthInfo adds credentials to httpReq according to the plugin's
// authorization type. Service auth is delegated to injectServiceAPIToken and
// OAuth to injectOAuthAccessToken; every other type (including "none") leaves
// the request untouched.
//
// errMsg is passed through from the delegate unchanged.
func (h *httpCallImpl) injectAuthInfo(ctx context.Context, httpReq *http.Request, args *InvocationArgs) (errMsg string, err error) {
	switch args.AuthInfo.MetaInfo.Type {
	case model.AuthzTypeOfService:
		return h.injectServiceAPIToken(ctx, httpReq, args.AuthInfo.MetaInfo)
	case model.AuthzTypeOfOAuth:
		return h.injectOAuthAccessToken(ctx, httpReq, args)
	default:
		return "", nil
	}
}
// genRequestString renders req and its already-read body as a JSON document
// of the shape
//
//	{"path": ..., "header": {...}, "query": {...}, "body": ...}
//
// Only the first value of each header and query key is kept; keys without any
// value are skipped. "body" is null when body is empty. A body that is itself
// valid JSON is embedded verbatim; any other payload (for example a
// form-urlencoded or YAML body) is stored as a JSON string so that the result
// is always a well-formed JSON document.
func genRequestString(req *http.Request, body []byte) (string, error) {
	type Request struct {
		Path   string            `json:"path"`
		Header map[string]string `json:"header"`
		Query  map[string]string `json:"query"`
		// Body is never populated here; it only makes the marshalled document
		// contain a "body": null placeholder that is overwritten below.
		Body *[]byte `json:"body"`
	}

	// Parse the query string once instead of once per use.
	query := req.URL.Query()

	req_ := &Request{
		Path:   req.URL.Path,
		Header: make(map[string]string, len(req.Header)),
		Query:  make(map[string]string, len(query)),
	}

	for k, v := range req.Header {
		if len(v) == 0 {
			// A key with an empty value slice has nothing to report;
			// indexing v[0] would panic.
			continue
		}
		req_.Header[k] = v[0]
	}

	for k, v := range query {
		if len(v) == 0 {
			continue
		}
		req_.Query[k] = v[0]
	}

	requestStr, err := sonic.MarshalString(req_)
	if err != nil {
		return "", fmt.Errorf("[genRequestString] marshal failed, err=%w", err)
	}

	if len(body) == 0 {
		return requestStr, nil
	}

	if sonic.Valid(body) {
		// JSON payloads are embedded as-is so the structure stays inspectable.
		requestStr, err = sjson.SetRaw(requestStr, "body", string(body))
	} else {
		// SetRaw would splice non-JSON bytes in verbatim and corrupt the
		// document, so store them as a properly escaped JSON string instead.
		requestStr, err = sjson.Set(requestStr, "body", string(body))
	}
	if err != nil {
		return "", fmt.Errorf("[genRequestString] set body failed, err=%w", err)
	}

	return requestStr, nil
}
// buildHTTPRequestURL expands rawURL into the final request URL.
//
// Every "{name}" placeholder in rawURL is replaced by the corresponding entry
// of args.Path, encoded with encoder.EncodeParameter against the matching
// path-key schema. The entries of args.Query are then appended to whatever
// query string rawURL already carries; a []any value contributes one
// parameter per element.
func (h *httpCallImpl) buildHTTPRequestURL(ctx context.Context, rawURL string, args *InvocationArgs) (reqURL *url.URL, err error) {
	for name, raw := range args.Path {
		schema := args.groupedKeySchema.PathKeys[name]
		encoded, encErr := encoder.EncodeParameter(schema, raw)
		if encErr != nil {
			return nil, encErr
		}
		rawURL = strings.ReplaceAll(rawURL, "{"+name+"}", encoded)
	}

	extra := url.Values{}
	for key, raw := range args.Query {
		if list, isList := raw.([]any); isList {
			for _, item := range list {
				extra.Add(key, encoder.MustString(item))
			}
			continue
		}
		extra.Add(key, encoder.MustString(raw))
	}
	extraQuery := extra.Encode()

	reqURL, err = url.Parse(rawURL)
	if err != nil {
		return nil, err
	}

	switch {
	case extraQuery == "":
		// Nothing to add; keep the query string from rawURL as it is.
	case reqURL.RawQuery == "":
		reqURL.RawQuery = extraQuery
	default:
		reqURL.RawQuery += "&" + extraQuery
	}

	return reqURL, nil
}
// buildRequestBody encodes bodyArgs according to the operation's request-body
// schema and returns the encoded bytes with the schema's content type.
//
// When the operation declares no body schema, or the schema has no
// properties, body is nil and only the content type is returned. Arguments
// that match a schema property are first passed through
// encoder.TryCorrectValueType.
//
// NOTE: the corrected values are written back into bodyArgs, so the caller's
// map is modified in place.
func (h *httpCallImpl) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any) (body []byte, contentType string, err error) {
	contentType, schema := op.GetReqBodySchema()
	if schema == nil || len(schema.Value.Properties) == 0 {
		return nil, contentType, nil
	}

	for name, propSchema := range schema.Value.Properties {
		given, present := bodyArgs[name]
		if !present {
			continue
		}

		fixed, fixErr := encoder.TryCorrectValueType(name, propSchema, given)
		if fixErr != nil {
			return nil, "", fixErr
		}
		bodyArgs[name] = fixed
	}

	encoded, encErr := encoder.EncodeBodyWithContentType(contentType, bodyArgs)
	if encErr != nil {
		return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", encErr)
	}

	return encoded, contentType, nil
}
// injectServiceAPIToken attaches the service API token from authInfo to
// httpReq, either as a query parameter or as a header depending on the
// configured location (compared case-insensitively).
//
// A value the request already carries under the same key is never
// overwritten. Sub types other than the service API token are ignored. An
// error is returned only when the token configuration is missing; errMsg is
// always empty.
func (h *httpCallImpl) injectServiceAPIToken(ctx context.Context, httpReq *http.Request, authInfo *model.AuthV2) (errMsg string, err error) {
	if authInfo.SubType != model.AuthzSubTypeOfServiceAPIToken {
		return "", nil
	}

	tokenCfg := authInfo.AuthOfAPIToken
	if tokenCfg == nil {
		return "", fmt.Errorf("auth of api token is nil")
	}

	switch strings.ToLower(string(tokenCfg.Location)) {
	case openapi3.ParameterInQuery:
		values := httpReq.URL.Query()
		if values.Get(tokenCfg.Key) == "" {
			values.Set(tokenCfg.Key, tokenCfg.ServiceToken)
			httpReq.URL.RawQuery = values.Encode()
		}
	case openapi3.ParameterInHeader:
		if httpReq.Header.Get(tokenCfg.Key) == "" {
			httpReq.Header.Set(tokenCfg.Key, tokenCfg.ServiceToken)
		}
	}

	return "", nil
}
// injectOAuthAccessToken adds args.AccessToken to httpReq as a Bearer
// Authorization header.
//
// The tool's auth mode is read from the operation extension
// model.APISchemaExtendAuthMode and defaults to "required". With auth
// disabled the request is left untouched. For the authorization-code sub
// type, an empty token with any mode other than "supported" produces a
// localized prompt in errMsg (falling back to the English template) built
// from the plugin's display name and args.AuthURL; the request is not
// modified in that case. err is always nil.
func (h *httpCallImpl) injectOAuthAccessToken(ctx context.Context, httpReq *http.Request, args *InvocationArgs) (errMsg string, err error) {
	mode := model.ToolAuthModeOfRequired
	if raw, ok := args.Tool.Operation.Extensions[model.APISchemaExtendAuthMode].(string); ok {
		mode = model.ToolAuthMode(raw)
	}
	if mode == model.ToolAuthModeOfDisabled {
		return "", nil
	}

	token := args.AccessToken
	mustPrompt := args.AuthInfo.MetaInfo.SubType == model.AuthzSubTypeOfOAuthAuthorizationCode &&
		token == "" &&
		mode != model.ToolAuthModeOfSupported
	if mustPrompt {
		tmpl := authCodeInvalidTokenErrMsg[i18n.GetLocale(ctx)]
		if tmpl == "" {
			tmpl = authCodeInvalidTokenErrMsg[i18n.LocaleEN]
		}
		return fmt.Sprintf(tmpl, args.PluginManifest.NameForHuman, args.AuthURL), nil
	}

	if token == "" {
		return "", nil
	}

	httpReq.Header.Set("Authorization", "Bearer "+token)
	return "", nil
}
// authCodeInvalidTokenErrMsg maps a locale to the fmt template of the prompt
// shown when an OAuth authorization-code plugin is invoked without an access
// token. Each template takes two %s arguments, in order: the plugin's display
// name and the authorization URL.
var authCodeInvalidTokenErrMsg = map[i18n.Locale]string{
i18n.LocaleZH: "%s 插件需要授权使用。授权后即代表你同意与扣子中你所选择的 AI 模型分享数据。请[点击这里](%s)进行授权。",
i18n.LocaleEN: "The '%s' plugin requires authorization. By authorizing, you agree to share data with the AI model you selected in Coze. Please [click here](%s) to authorize.",
}
// buildHTTPRequestHeader builds the headers of the outgoing request.
//
// The entries of args.Header are copied first; a []any value contributes one
// header value per element. The tracing and identity headers are then set,
// replacing any value supplied under the same name: the log ID read from ctx
// (empty when absent), the user ID, the project ID when args.ProjectInfo is
// present, and the conversation ID when it is positive. The returned error is
// always nil.
func (h *httpCallImpl) buildHTTPRequestHeader(ctx context.Context, args *InvocationArgs) (http.Header, error) {
	out := http.Header{}

	for name, raw := range args.Header {
		if list, isList := raw.([]any); isList {
			for _, item := range list {
				out.Add(name, encoder.MustString(item))
			}
			continue
		}
		out.Add(name, encoder.MustString(raw))
	}

	// A missing or non-string log ID yields an empty header value.
	traceID, _ := ctx.Value(consts.CtxLogIDKey).(string)
	out.Set("X-Tt-Logid", traceID)
	out.Set("X-Aiplugin-Connector-Identifier", args.UserID)

	if project := args.ProjectInfo; project != nil {
		out.Set("X-AIPlugin-Bot-ID", conv.Int64ToStr(project.ProjectID))
	}

	if h.ConversationID > 0 {
		out.Set("X-AIPlugin-Conversation-ID", conv.Int64ToStr(h.ConversationID))
	}

	return out, nil
}