diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index 805da1eec..2c70ed79b 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -51,6 +51,7 @@ import ( message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message" "github.com/coze-dev/coze-studio/backend/domain/workflow/config" + "github.com/coze-dev/coze-studio/backend/domain/workflow/plugin" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge" modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge" @@ -79,7 +80,7 @@ import ( crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr" mockmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr/modelmock" crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - pluginmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/pluginmock" crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user" "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code" @@ -333,6 +334,10 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { mockPluginSrv := pluginmock.NewMockPluginService(ctrl) crossplugin.SetDefaultSVC(mockPluginSrv) + mockStorage := storageMock.NewMockStorage(ctrl) + mockStorage.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any()).Return("URL_ADDRESS", nil).AnyTimes() + plugin.SetOSS(mockStorage) + mockConversation := conversationmock.NewMockConversation(ctrl) crossconversation.SetDefaultSVC(mockConversation) mockMessage := messagemock.NewMockMessage(ctrl) @@ -2912,6 +2917,7 @@ func TestInputComplex(t *testing.T) { } func TestLLMWithSkills(t *testing.T) { + mockey.PatchConvey("workflow llm node with plugin", t, func() { r := newWfTestRunner(t) defer r.closeFn() @@ -3116,16 +3122,16 @@ func TestLLMWithSkills(t *testing.T) { } r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(utChatModel, nil, nil).AnyTimes() - t.Run("llm with workflow tool", func(t *testing.T) { - r.load("llm_node_with_skills/llm_workflow_as_tool.json", withID(7509120431183544356), withPublish("v0.0.1")) - id := r.load("llm_node_with_skills/llm_node_with_workflow_tool.json") - exeID := r.testRun(id, map[string]string{ - "input_string": "ok_input_string", - }) - e := r.getProcess(id, exeID) - e.assertSuccess() - assert.Equal(t, `{"output":"output_data"}`, e.output) - }) + // t.Run("llm with workflow tool", func(t *testing.T) { + // r.load("llm_node_with_skills/llm_workflow_as_tool.json", withID(7509120431183544356), withPublish("v0.0.1")) + // id := r.load("llm_node_with_skills/llm_node_with_workflow_tool.json") + // exeID := r.testRun(id, map[string]string{ + // "input_string": "ok_input_string", + // }) + // e := r.getProcess(id, exeID) + // e.assertSuccess() + // assert.Equal(t, `{"output":"output_data"}`, e.output) + // }) }) mockey.PatchConvey("workflow llm node with knowledge skill", t, func() { diff --git a/backend/api/model/crossdomain/singleagent/single_agent.go b/backend/api/model/crossdomain/singleagent/single_agent.go index 8f363e2ef..ad156776a 100644 --- a/backend/api/model/crossdomain/singleagent/single_agent.go +++ b/backend/api/model/crossdomain/singleagent/single_agent.go @@ -22,7 +22,7 @@ import ( "github.com/coze-dev/coze-studio/backend/api/model/app/bot_common" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun" - plugindto "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow" ) @@ -95,7 +95,7 @@ const ( ) type InterruptInfo struct { - AllToolInterruptData map[string]*plugindto.ToolInterruptEvent + AllToolInterruptData map[string]*model.ToolInterruptEvent AllWfInterruptData map[string]*crossworkflow.ToolInterruptEvent ToolCallID string InterruptType InterruptEventType diff --git a/backend/api/model/resource/common/resource_common.go b/backend/api/model/resource/common/resource_common.go index 9d87dbaf5..e82b3656d 100644 --- a/backend/api/model/resource/common/resource_common.go +++ b/backend/api/model/resource/common/resource_common.go @@ -6,6 +6,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "github.com/apache/thrift/lib/go/thrift" ) diff --git a/backend/application/app/app.go b/backend/application/app/app.go index afe8feef7..830b24c74 100644 --- a/backend/application/app/app.go +++ b/backend/application/app/app.go @@ -45,7 +45,7 @@ import ( "github.com/coze-dev/coze-studio/backend/application/memory" "github.com/coze-dev/coze-studio/backend/application/plugin" "github.com/coze-dev/coze-studio/backend/application/workflow" - pluginModel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginConsts "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" "github.com/coze-dev/coze-studio/backend/domain/app/entity" "github.com/coze-dev/coze-studio/backend/domain/app/repository" "github.com/coze-dev/coze-studio/backend/domain/app/service" @@ -792,16 +792,16 @@ func pluginCopyDispatchHandler(ctx context.Context, metaInfo *copyMetaInfo, res } func copyPlugin(ctx context.Context, metaInfo *copyMetaInfo, res *entity.Resource) (resp *dto.CopyPluginResponse, err error) { - var copyScene pluginModel.CopyScene + var copyScene pluginConsts.CopyScene switch metaInfo.scene { case resourceCommon.ResourceCopyScene_CopyProjectResource: - copyScene = pluginModel.CopySceneOfDuplicate + copyScene = pluginConsts.CopySceneOfDuplicate case resourceCommon.ResourceCopyScene_CopyResourceToLibrary: - copyScene = pluginModel.CopySceneOfToLibrary + copyScene = pluginConsts.CopySceneOfToLibrary case resourceCommon.ResourceCopyScene_CopyResourceFromLibrary: - copyScene = pluginModel.CopySceneOfToAPP + copyScene = pluginConsts.CopySceneOfToAPP case resourceCommon.ResourceCopyScene_CopyProject: - copyScene = pluginModel.CopySceneOfAPPDuplicate + copyScene = pluginConsts.CopySceneOfAPPDuplicate default: return nil, fmt.Errorf("unsupported copy scene '%s'", metaInfo.scene) } diff --git a/backend/application/plugin/api_management.go b/backend/application/plugin/api_management.go new file mode 100644 index 000000000..08473f7c5 --- /dev/null +++ b/backend/application/plugin/api_management.go @@ -0,0 +1,473 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package plugin + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/bytedance/sonic" + "github.com/getkin/kin-openapi/openapi3" + gonanoid "github.com/matoous/go-nanoid" + + pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop" + common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" + "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert/api" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" + "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/domain/plugin/repository" + searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" + "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +func (p *PluginApplicationService) GetPluginAPIs(ctx context.Context, req *pluginAPI.GetPluginAPIsRequest) (resp *pluginAPI.GetPluginAPIsResponse, err error) { + pl, err := p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateGetPluginAPIsRequest failed") + } + + var ( + draftTools []*entity.ToolInfo + total int64 + ) + if len(req.APIIds) > 0 { + toolIDs := make([]int64, 0, len(req.APIIds)) + for _, id := range req.APIIds { + toolID, pErr := strconv.ParseInt(id, 10, 64) + if pErr != nil { + return nil, fmt.Errorf("invalid tool id '%s'", id) + } + toolIDs = append(toolIDs, toolID) + } + + draftTools, err = p.toolRepo.MGetDraftTools(ctx, toolIDs) + if err != nil { + return nil, errorx.Wrapf(err, "MGetDraftTools failed, toolIDs=%v", toolIDs) + } + + total = int64(len(draftTools)) + + } else { + pageInfo := dto.PageInfo{ + Page: int(req.Page), + Size: int(req.Size), + SortBy: ptr.Of(dto.SortByCreatedAt), + OrderByACS: ptr.Of(false), + } + draftTools, total, err = p.toolRepo.ListPluginDraftTools(ctx, req.PluginID, pageInfo) + if err != nil { + return nil, errorx.Wrapf(err, "ListPluginDraftTools failed, pluginID=%d", req.PluginID) + } + } + + if len(draftTools) == 0 { + return &pluginAPI.GetPluginAPIsResponse{ + APIInfo: make([]*common.PluginAPIInfo, 0), + Total: 0, + }, nil + } + + draftToolIDs := slices.Transform(draftTools, func(tl *entity.ToolInfo) int64 { + return tl.ID + }) + onlineStatus, err := p.getToolOnlineStatus(ctx, draftToolIDs) + if err != nil { + return nil, err + } + + apis := make([]*common.PluginAPIInfo, 0, len(draftTools)) + for _, tool := range draftTools { + method, ok := convert.ToThriftAPIMethod(tool.GetMethod()) + if !ok { + return nil, fmt.Errorf("invalid method '%s'", tool.GetMethod()) + } + reqParams, err := tool.ToReqAPIParameter() + if err != nil { + return nil, err + } + respParams, err := tool.ToRespAPIParameter() + if err != nil { + return nil, err + } + + var apiExtend *common.APIExtend + if tmp, ok := tool.Operation.Extensions[consts.APISchemaExtendAuthMode].(string); ok { + if mode, ok := convert.ToThriftAPIAuthMode(consts.ToolAuthMode(tmp)); ok { + apiExtend = &common.APIExtend{ + AuthMode: mode, + } + } + } + + api := &common.PluginAPIInfo{ + APIID: strconv.FormatInt(tool.ID, 10), + CreateTime: strconv.FormatInt(tool.CreatedAt/1000, 10), + DebugStatus: tool.GetDebugStatus(), + Desc: tool.GetDesc(), + Disabled: func() bool { + return tool.IsDeactivated() + }(), + Method: method, + Name: tool.GetName(), + OnlineStatus: onlineStatus[tool.ID], + Path: tool.GetSubURL(), + PluginID: strconv.FormatInt(tool.PluginID, 10), + RequestParams: reqParams, + ResponseParams: respParams, + StatisticData: common.NewPluginStatisticData(), + APIExtend: apiExtend, + } + example := pl.GetToolExample(ctx, tool.GetName()) + if example != nil { + api.DebugExample = &common.DebugExample{ + ReqExample: example.RequestExample, + RespExample: example.ResponseExample, + } + api.DebugExampleStatus = common.DebugExampleStatus_Enable + } + + apis = append(apis, api) + } + + resp = &pluginAPI.GetPluginAPIsResponse{ + APIInfo: apis, + Total: int32(total), + } + + return resp, nil +} + +func (p *PluginApplicationService) CreateAPI(ctx context.Context, req *pluginAPI.CreateAPIRequest) (resp *pluginAPI.CreateAPIResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateCreateAPIRequest failed") + } + + defaultSubURL := gonanoid.MustID(6) + + tool := &entity.ToolInfo{ + PluginID: req.PluginID, + ActivatedStatus: ptr.Of(consts.ActivateTool), + DebugStatus: ptr.Of(common.APIDebugStatus_DebugWaiting), + SubURL: ptr.Of("/" + defaultSubURL), + Method: ptr.Of(http.MethodGet), + Operation: model.NewOpenapi3Operation(&openapi3.Operation{ + Summary: req.Desc, + OperationID: req.Name, + Parameters: []*openapi3.ParameterRef{}, + RequestBody: model.DefaultOpenapi3RequestBody(), + Responses: model.DefaultOpenapi3Responses(), + Extensions: map[string]any{}, + }), + } + + toolID, err := p.toolRepo.CreateDraftTool(ctx, tool) + if err != nil { + return nil, errorx.Wrapf(err, "CreateDraftTool failed, pluginID=%d", req.PluginID) + } + + resp = &pluginAPI.CreateAPIResponse{ + APIID: strconv.FormatInt(toolID, 10), + } + + return resp, nil +} + +func (p *PluginApplicationService) UpdateAPI(ctx context.Context, req *pluginAPI.UpdateAPIRequest) (resp *pluginAPI.UpdateAPIResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateUpdateAPIRequest failed") + } + + op, err := api.APIParamsToOpenapiOperation(req.RequestParams, req.ResponseParams) + if err != nil { + return nil, err + } + + var method *string + if m, ok := convert.ToHTTPMethod(req.GetMethod()); ok { + method = &m + } + + updateReq := &dto.UpdateDraftToolRequest{ + PluginID: req.PluginID, + ToolID: req.APIID, + Name: req.Name, + Desc: req.Desc, + SubURL: req.Path, + Method: method, + Parameters: op.Parameters, + RequestBody: op.RequestBody, + Responses: op.Responses, + Disabled: req.Disabled, + SaveExample: req.SaveExample, + DebugExample: req.DebugExample, + APIExtend: req.APIExtend, + } + err = p.DomainSVC.UpdateDraftTool(ctx, updateReq) + if err != nil { + return nil, errorx.Wrapf(err, "UpdateDraftTool failed, pluginID=%d, toolID=%d", updateReq.PluginID, updateReq.ToolID) + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Updated, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResID: req.PluginID, + UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), + }, + }) + if err != nil { + logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) + } + + resp = &pluginAPI.UpdateAPIResponse{} + + return resp, nil +} + +func (p *PluginApplicationService) DeleteAPI(ctx context.Context, req *pluginAPI.DeleteAPIRequest) (resp *pluginAPI.DeleteAPIResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateDeleteAPIRequest failed") + } + + err = p.toolRepo.DeleteDraftTool(ctx, req.APIID) + if err != nil { + return nil, errorx.Wrapf(err, "DeleteDraftTool failed, toolID=%d", req.APIID) + } + + resp = &pluginAPI.DeleteAPIResponse{} + + return resp, nil +} + +func (p *PluginApplicationService) BatchCreateAPI(ctx context.Context, req *pluginAPI.BatchCreateAPIRequest) (resp *pluginAPI.BatchCreateAPIResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateBatchCreateAPIRequest failed") + } + + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(req.Openapi)) + if err != nil { + return nil, errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error())) + } + + res, err := p.DomainSVC.CreateDraftToolsWithCode(ctx, &dto.CreateDraftToolsWithCodeRequest{ + PluginID: req.PluginID, + OpenapiDoc: ptr.Of(model.Openapi3T(*doc)), + ConflictAndUpdate: req.ReplaceSamePaths, + }) + if err != nil { + return nil, errorx.Wrapf(err, "CreateDraftToolsWithCode failed, pluginID=%d", req.PluginID) + } + + duplicated := slices.Transform(res.DuplicatedTools, func(e dto.UniqueToolAPI) *common.PluginAPIInfo { + method, _ := convert.ToThriftAPIMethod(e.Method) + return &common.PluginAPIInfo{ + Path: e.SubURL, + Method: method, + } + }) + + resp = &pluginAPI.BatchCreateAPIResponse{ + PathsDuplicated: duplicated, + } + + if len(duplicated) > 0 { + resp.Code = errno.ErrPluginDuplicatedTool + } + + return resp, nil +} + +func (p *PluginApplicationService) GetUpdatedAPIs(ctx context.Context, req *pluginAPI.GetUpdatedAPIsRequest) (resp *pluginAPI.GetUpdatedAPIsResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateGetUpdatedAPIsRequest failed") + } + + draftTools, err := p.toolRepo.GetPluginAllDraftTools(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", req.PluginID) + } + onlineTools, err := p.toolRepo.GetPluginAllOnlineTools(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "GetPluginAllOnlineTools failed, pluginID=%d", req.PluginID) + } + + var updatedToolName, createdToolName, delToolName []string + + draftMap := slices.ToMap(draftTools, func(e *entity.ToolInfo) (string, *entity.ToolInfo) { + return e.GetName(), e + }) + onlineMap := slices.ToMap(onlineTools, func(e *entity.ToolInfo) (string, *entity.ToolInfo) { + return e.GetName(), e + }) + + for name := range draftMap { + if _, ok := onlineMap[name]; !ok { + createdToolName = append(createdToolName, name) + } + } + + for name, ot := range onlineMap { + dt, ok := draftMap[name] + if !ok { + delToolName = append(delToolName, name) + continue + } + + if ot.GetMethod() != dt.GetMethod() || + ot.GetSubURL() != dt.GetSubURL() || + ot.GetDesc() != dt.GetDesc() { + updatedToolName = append(updatedToolName, name) + continue + } + + os, err := sonic.MarshalString(ot.Operation) + if err != nil { + logs.CtxErrorf(ctx, "marshal online tool operation failed, toolID=%d, err=%v", ot.ID, err) + + updatedToolName = append(updatedToolName, name) + continue + } + ds, err := sonic.MarshalString(dt.Operation) + if err != nil { + logs.CtxErrorf(ctx, "marshal draft tool operation failed, toolID=%d, err=%v", ot.ID, err) + + updatedToolName = append(updatedToolName, name) + continue + } + + if os != ds { + updatedToolName = append(updatedToolName, name) + } + } + + resp = &pluginAPI.GetUpdatedAPIsResponse{ + UpdatedAPINames: updatedToolName, + CreatedAPINames: createdToolName, + DeletedAPINames: delToolName, + } + + return resp, nil +} + +func (p *PluginApplicationService) DebugAPI(ctx context.Context, req *pluginAPI.DebugAPIRequest) (resp *pluginAPI.DebugAPIResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateDebugAPIRequest failed") + } + + const defaultErrReason = "internal server error" + + userID := ctxutil.GetUIDFromCtx(ctx) + if userID == nil { + return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) + } + + resp = &pluginAPI.DebugAPIResponse{ + Success: false, + RawReq: "{}", + RawResp: "{}", + Resp: "{}", + } + + opts := []model.ExecuteToolOpt{} + switch req.Operation { + case common.DebugOperation_Debug: + opts = append(opts, model.WithInvalidRespProcessStrategy(consts.InvalidResponseProcessStrategyOfReturnErr)) + case common.DebugOperation_Parse: + opts = append(opts, model.WithAutoGenRespSchema(), + model.WithInvalidRespProcessStrategy(consts.InvalidResponseProcessStrategyOfReturnRaw), + ) + } + + res, err := p.DomainSVC.ExecuteTool(ctx, &model.ExecuteToolRequest{ + UserID: conv.Int64ToStr(*userID), + PluginID: req.PluginID, + ToolID: req.APIID, + ExecScene: consts.ExecSceneOfToolDebug, + ExecDraftTool: true, + ArgumentsInJson: req.Parameters, + }, opts...) + if err != nil { + var e errorx.StatusError + if errors.As(err, &e) { + resp.Reason = e.Msg() + return resp, nil + } + + logs.CtxErrorf(ctx, "ExecuteTool failed, err=%v", err) + resp.Reason = defaultErrReason + + return resp, nil + } + + resp = &pluginAPI.DebugAPIResponse{ + Success: true, + Resp: res.TrimmedResp, + RawReq: res.Request, + RawResp: res.RawResp, + ResponseParams: []*common.APIParameter{}, + } + + if req.Operation == common.DebugOperation_Parse { + res.Tool.Operation.Responses = res.RespSchema + } + + respParams, err := res.Tool.ToRespAPIParameter() + if err != nil { + logs.CtxErrorf(ctx, "ToRespAPIParameter failed, err=%v", err) + resp.Success = false + resp.Reason = defaultErrReason + } else { + resp.ResponseParams = respParams + } + + return resp, nil +} + +func (p *PluginApplicationService) getToolOnlineStatus(ctx context.Context, toolIDs []int64) (map[int64]common.OnlineStatus, error) { + onlineTools, err := p.toolRepo.MGetOnlineTools(ctx, toolIDs, repository.WithToolID()) + if err != nil { + return nil, errorx.Wrapf(err, "MGetOnlineTools failed, toolIDs=%v", toolIDs) + } + + onlineStatus := make(map[int64]common.OnlineStatus, len(onlineTools)) + for _, tool := range onlineTools { + onlineStatus[tool.ID] = common.OnlineStatus_ONLINE + } + + return onlineStatus, nil +} diff --git a/backend/application/plugin/auth.go b/backend/application/plugin/auth.go new file mode 100644 index 000000000..38228d48c --- /dev/null +++ b/backend/application/plugin/auth.go @@ -0,0 +1,161 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package plugin + +import ( + "context" + "encoding/json" + "net/url" + "os" + + botOpenAPI "github.com/coze-dev/coze-studio/backend/api/model/app/bot_open_api" + pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop" + common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" + pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" + "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +func (p *PluginApplicationService) GetOAuthSchema(ctx context.Context, req *pluginAPI.GetOAuthSchemaRequest) (resp *pluginAPI.GetOAuthSchemaResponse, err error) { + return &pluginAPI.GetOAuthSchemaResponse{ + OauthSchema: pluginConf.GetOAuthSchema(), + }, nil +} + +func (p *PluginApplicationService) GetOAuthStatus(ctx context.Context, req *pluginAPI.GetOAuthStatusRequest) (resp *pluginAPI.GetOAuthStatusResponse, err error) { + userID := ctxutil.GetUIDFromCtx(ctx) + if userID == nil { + return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) + } + + res, err := p.DomainSVC.GetOAuthStatus(ctx, *userID, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "GetOAuthStatus failed, pluginID=%d", req.PluginID) + } + resp = &pluginAPI.GetOAuthStatusResponse{ + IsOauth: res.IsOauth, + Status: res.Status, + Content: res.OAuthURL, + } + + return resp, nil +} + +func (p *PluginApplicationService) OauthAuthorizationCode(ctx context.Context, req *botOpenAPI.OauthAuthorizationCodeReq) (resp *botOpenAPI.OauthAuthorizationCodeResp, err error) { + stateStr, err := url.QueryUnescape(req.State) + if err != nil { + return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) + } + + secret := os.Getenv(encrypt.StateSecretEnv) + if secret == "" { + secret = encrypt.DefaultStateSecret + } + + stateBytes, err := encrypt.DecryptByAES(stateStr, secret) + if err != nil { + return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) + } + + state := &dto.OAuthState{} + err = json.Unmarshal(stateBytes, state) + if err != nil { + return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) + } + + err = p.DomainSVC.OAuthCode(ctx, req.Code, state) + if err != nil { + return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "authorize failed")) + } + + resp = &botOpenAPI.OauthAuthorizationCodeResp{} + + return resp, nil +} + +func (p *PluginApplicationService) GetQueriedOAuthPluginList(ctx context.Context, req *pluginAPI.GetQueriedOAuthPluginListRequest) (resp *pluginAPI.GetQueriedOAuthPluginListResponse, err error) { + userID := ctxutil.GetUIDFromCtx(ctx) + if userID == nil { + return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) + } + + status, err := p.DomainSVC.GetAgentPluginsOAuthStatus(ctx, *userID, req.BotID) + if err != nil { + return nil, errorx.Wrapf(err, "GetAgentPluginsOAuthStatus failed, userID=%d, agentID=%d", *userID, req.BotID) + } + + if len(status) == 0 { + return &pluginAPI.GetQueriedOAuthPluginListResponse{ + OauthPluginList: []*pluginAPI.OAuthPluginInfo{}, + }, nil + } + + oauthPluginList := make([]*pluginAPI.OAuthPluginInfo, 0, len(status)) + for _, s := range status { + oauthPluginList = append(oauthPluginList, &pluginAPI.OAuthPluginInfo{ + PluginID: s.PluginID, + Status: s.Status, + Name: s.PluginName, + PluginIcon: s.PluginIconURL, + }) + } + + resp = &pluginAPI.GetQueriedOAuthPluginListResponse{ + OauthPluginList: oauthPluginList, + } + + return resp, nil +} + +func (p *PluginApplicationService) RevokeAuthToken(ctx context.Context, req *pluginAPI.RevokeAuthTokenRequest) (resp *pluginAPI.RevokeAuthTokenResponse, err error) { + userID := ctxutil.GetUIDFromCtx(ctx) + if userID == nil { + return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) + } + + err = p.DomainSVC.RevokeAccessToken(ctx, &dto.AuthorizationCodeMeta{ + UserID: conv.Int64ToStr(*userID), + PluginID: req.PluginID, + IsDraft: req.GetBotID() == 0, + }) + if err != nil { + return nil, errorx.Wrapf(err, "RevokeAccessToken failed, pluginID=%d", req.PluginID) + } + + resp = &pluginAPI.RevokeAuthTokenResponse{} + + return resp, nil +} + +func (p *PluginApplicationService) GetUserAuthority(ctx context.Context, req *pluginAPI.GetUserAuthorityRequest) (resp *pluginAPI.GetUserAuthorityResponse, err error) { + resp = &pluginAPI.GetUserAuthorityResponse{ + Data: &common.GetUserAuthorityData{ + CanEdit: true, + CanRead: true, + CanDelete: true, + CanDebug: true, + CanPublish: true, + CanReadChangelog: true, + }, + } + + return resp, nil +} diff --git a/backend/application/plugin/info.go b/backend/application/plugin/info.go new file mode 100644 index 000000000..89f000cd2 --- /dev/null +++ b/backend/application/plugin/info.go @@ -0,0 +1,340 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package plugin + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/bytedance/sonic" + "github.com/getkin/kin-openapi/openapi3" + "gopkg.in/yaml.v3" + + pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop" + common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" + "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" + "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/domain/plugin/repository" + searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +func (p *PluginApplicationService) GetPluginInfo(ctx context.Context, req *pluginAPI.GetPluginInfoRequest) (resp *pluginAPI.GetPluginInfoResponse, err error) { + draftPlugin, err := p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateGetPluginInfoRequest failed") + } + + metaInfo, err := p.getPluginMetaInfo(ctx, draftPlugin) + if err != nil { + return nil, err + } + + codeInfo, err := p.getPluginCodeInfo(ctx, draftPlugin) + if err != nil { + return nil, err + } + + _, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, req.PluginID, repository.WithPluginID()) + if err != nil { + return nil, errorx.Wrapf(err, "GetOnlinePlugin failed, pluginID=%d", req.PluginID) + } + + resp = &pluginAPI.GetPluginInfoResponse{ + MetaInfo: metaInfo, + CodeInfo: codeInfo, + Creator: common.NewCreator(), + StatisticData: common.NewPluginStatisticData(), + PluginType: draftPlugin.PluginType, + CreationMethod: common.CreationMethod_COZE, + Published: exist, + } + + return resp, nil +} + +func (p *PluginApplicationService) getPluginCodeInfo(ctx context.Context, draftPlugin *entity.PluginInfo) (*common.CodeInfo, error) { + tools, err := p.toolRepo.GetPluginAllDraftTools(ctx, draftPlugin.ID) + if err != nil { + return nil, errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", draftPlugin.ID) + } + + paths := openapi3.Paths{} + for _, tool := range tools { + if tool.IsDeactivated() { + continue + } + item := &openapi3.PathItem{} + item.SetOperation(tool.GetMethod(), tool.Operation.Operation) + paths[tool.GetSubURL()] = item + } + draftPlugin.OpenapiDoc.Paths = paths + + manifestStr, err := sonic.MarshalString(draftPlugin.Manifest) + if err != nil { + return nil, fmt.Errorf("marshal manifest failed, err=%v", err) + } + + docBytes, err := yaml.Marshal(draftPlugin.OpenapiDoc) + if err != nil { + return nil, fmt.Errorf("marshal openapi doc failed, err=%v", err) + } + + codeInfo := &common.CodeInfo{ + OpenapiDesc: string(docBytes), + PluginDesc: manifestStr, + } + + return codeInfo, nil +} + +func (p *PluginApplicationService) getPluginMetaInfo(ctx context.Context, draftPlugin *entity.PluginInfo) (*common.PluginMetaInfo, error) { + commonParams := make(map[common.ParameterLocation][]*common.CommonParamSchema, len(draftPlugin.Manifest.CommonParams)) + for loc, params := range draftPlugin.Manifest.CommonParams { + location, ok := convert.ToThriftHTTPParamLocation(loc) + if !ok { + return nil, fmt.Errorf("invalid location '%s'", loc) + } + commonParams[location] = make([]*common.CommonParamSchema, 0, len(params)) + for _, param := range params { + commonParams[location] = append(commonParams[location], &common.CommonParamSchema{ + Name: param.Name, + Value: param.Value, + }) + } + } + + iconURL, err := p.oss.GetObjectUrl(ctx, draftPlugin.GetIconURI()) + if err != nil { + logs.CtxWarnf(ctx, "get icon url with '%s' failed, err=%v", draftPlugin.GetIconURI(), err) + } + + metaInfo := &common.PluginMetaInfo{ + Name: draftPlugin.GetName(), + Desc: draftPlugin.GetDesc(), + URL: draftPlugin.GetServerURL(), + Icon: &common.PluginIcon{ + URI: draftPlugin.GetIconURI(), + URL: iconURL, + }, + CommonParams: commonParams, + } + + err = p.fillAuthInfoInMetaInfo(ctx, draftPlugin, metaInfo) + if err != nil { + return nil, errorx.Wrapf(err, "fillAuthInfoInMetaInfo failed, pluginID=%d", draftPlugin.ID) + } + + return metaInfo, nil +} + +func (p *PluginApplicationService) fillAuthInfoInMetaInfo(ctx context.Context, draftPlugin *entity.PluginInfo, metaInfo *common.PluginMetaInfo) (err error) { + authInfo := draftPlugin.GetAuthInfo() + authType, ok := convert.ToThriftAuthType(authInfo.Type) + if !ok { + return fmt.Errorf("invalid auth type '%s'", authInfo.Type) + } + + var subAuthType *int32 + if authInfo.SubType != "" { + _subAuthType, ok := convert.ToThriftAuthSubType(authInfo.SubType) + if !ok { + return fmt.Errorf("invalid sub authz type '%s'", authInfo.SubType) + } + subAuthType = &_subAuthType + } + + metaInfo.AuthType = append(metaInfo.AuthType, authType) + metaInfo.SubAuthType = subAuthType + + if authType == common.AuthorizationType_None { + return nil + } + + if authType == common.AuthorizationType_Service { + var loc common.AuthorizationServiceLocation + _loc := consts.HTTPParamLocation(strings.ToLower(string(authInfo.AuthOfAPIToken.Location))) + if _loc == consts.ParamInHeader { + loc = common.AuthorizationServiceLocation_Header + } else if _loc == consts.ParamInQuery { + loc = common.AuthorizationServiceLocation_Query + } else { + return fmt.Errorf("invalid location '%s'", authInfo.AuthOfAPIToken.Location) + } + + metaInfo.Location = ptr.Of(loc) + metaInfo.Key = ptr.Of(authInfo.AuthOfAPIToken.Key) + metaInfo.ServiceToken = ptr.Of(authInfo.AuthOfAPIToken.ServiceToken) + } + + if authType == common.AuthorizationType_OAuth { + metaInfo.OauthInfo = &authInfo.Payload + } + + return nil +} + +func (p *PluginApplicationService) UpdatePlugin(ctx context.Context, req *pluginAPI.UpdatePluginRequest) (resp *pluginAPI.UpdatePluginResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateUpdatePluginRequest failed") + } + + userID := ctxutil.GetUIDFromCtx(ctx) + + loader := openapi3.NewLoader() + _doc, err := loader.LoadFromData([]byte(req.Openapi)) + if err != nil { + return nil, errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error())) + } + + doc := ptr.Of(model.Openapi3T(*_doc)) + + manifest := &model.PluginManifest{} + err = sonic.UnmarshalString(req.AiPlugin, manifest) + if err != nil { + return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, err.Error())) + } + + err = p.DomainSVC.UpdateDraftPluginWithCode(ctx, &dto.UpdateDraftPluginWithCodeRequest{ + UserID: *userID, + PluginID: req.PluginID, + OpenapiDoc: doc, + Manifest: manifest, + }) + if err != nil { + return nil, errorx.Wrapf(err, "UpdateDraftPluginWithCode failed, pluginID=%d", req.PluginID) + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Updated, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResID: req.PluginID, + Name: &manifest.NameForHuman, + UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), + }, + }) + if err != nil { + logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) + } + + resp = &pluginAPI.UpdatePluginResponse{ + Data: &common.UpdatePluginData{ + Res: true, + }, + } + + return resp, nil +} + +func (p *PluginApplicationService) UpdatePluginMeta(ctx context.Context, req *pluginAPI.UpdatePluginMetaRequest) (resp *pluginAPI.UpdatePluginMetaResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateUpdatePluginMetaRequest failed") + } + + authInfo, err := getUpdateAuthInfo(ctx, req) + if err != nil { + return nil, err + } + + updateReq := &dto.UpdateDraftPluginRequest{ + PluginID: req.PluginID, + Name: req.Name, + Desc: req.Desc, + URL: req.URL, + Icon: req.Icon, + CommonParams: req.CommonParams, + AuthInfo: authInfo, + } + err = p.DomainSVC.UpdateDraftPlugin(ctx, updateReq) + if err != nil { + return nil, errorx.Wrapf(err, "UpdateDraftPlugin failed, pluginID=%d", req.PluginID) + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Updated, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResID: req.PluginID, + Name: req.Name, + UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), + }, + }) + if err != nil { + logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) + } + + resp = &pluginAPI.UpdatePluginMetaResponse{} + + return resp, nil +} + +func getUpdateAuthInfo(ctx context.Context, req *pluginAPI.UpdatePluginMetaRequest) (authInfo *dto.PluginAuthInfo, err error) { + if req.AuthType == nil { + return nil, nil + } + + _authType, ok := convert.ToAuthType(req.GetAuthType()) + if !ok { + return nil, fmt.Errorf("invalid auth type '%d'", req.GetAuthType()) + } + authType := &_authType + + var authSubType *consts.AuthzSubType + if req.SubAuthType != nil { + _authSubType, ok := convert.ToAuthSubType(req.GetSubAuthType()) + if !ok { + return nil, fmt.Errorf("invalid sub authz type '%d'", req.GetSubAuthType()) + } + authSubType = &_authSubType + } + + var location *consts.HTTPParamLocation + if req.Location != nil { + if *req.Location == common.AuthorizationServiceLocation_Header { + location = ptr.Of(consts.ParamInHeader) + } else if *req.Location == common.AuthorizationServiceLocation_Query { + location = ptr.Of(consts.ParamInQuery) + } else { + return nil, fmt.Errorf("invalid location '%d'", req.GetLocation()) + } + } + + authInfo = &dto.PluginAuthInfo{ + AuthzType: authType, + Location: location, + Key: req.Key, + ServiceToken: req.ServiceToken, + OAuthInfo: req.OauthInfo, + AuthzSubType: authSubType, + AuthzPayload: req.AuthPayload, + } + + return authInfo, nil +} diff --git a/backend/application/plugin/init.go b/backend/application/plugin/init.go index ecf433af3..d2426838d 100644 --- a/backend/application/plugin/init.go +++ b/backend/application/plugin/init.go @@ -24,7 +24,6 @@ import ( "gorm.io/gorm" "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" - pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" "github.com/coze-dev/coze-studio/backend/domain/plugin/repository" "github.com/coze-dev/coze-studio/backend/domain/plugin/service" search "github.com/coze-dev/coze-studio/backend/domain/search/service" @@ -45,7 +44,7 @@ type ServiceComponents struct { } func InitService(ctx context.Context, components *ServiceComponents) (*PluginApplicationService, error) { - err := pluginConf.InitConfig(ctx) + err := conf.InitConfig(ctx) if err != nil { return nil, err } diff --git a/backend/application/plugin/lifecycle.go b/backend/application/plugin/lifecycle.go new file mode 100644 index 000000000..dcd2f065e --- /dev/null +++ b/backend/application/plugin/lifecycle.go @@ -0,0 +1,250 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package plugin + +import ( + "context" + "time" + + pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop" + common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" + "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/logs" +) + +func (p *PluginApplicationService) PublishPlugin(ctx context.Context, req *pluginAPI.PublishPluginRequest) (resp *pluginAPI.PublishPluginResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validatePublishPluginRequest failed") + } + + err = p.DomainSVC.PublishPlugin(ctx, &model.PublishPluginRequest{ + PluginID: req.PluginID, + Version: req.VersionName, + VersionDesc: req.VersionDesc, + }) + if err != nil { + return nil, errorx.Wrapf(err, "PublishPlugin failed, pluginID=%d", req.PluginID) + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Updated, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResID: req.PluginID, + PublishStatus: ptr.Of(resCommon.PublishStatus_Published), + PublishTimeMS: ptr.Of(time.Now().UnixMilli()), + }, + }) + if err != nil { + logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) + } + + resp = &pluginAPI.PublishPluginResponse{} + + return resp, nil +} + +func (p *PluginApplicationService) DelPlugin(ctx context.Context, req *pluginAPI.DelPluginRequest) (resp *pluginAPI.DelPluginResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateDelPluginRequest failed") + } + + err = p.DomainSVC.DeleteDraftPlugin(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "DeleteDraftPlugin failed, pluginID=%d", req.PluginID) + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Deleted, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResID: req.PluginID, + UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), + }, + }) + if err != nil { + return nil, errorx.Wrapf(err, "publish resource '%d' failed", req.PluginID) + } + + resp = &pluginAPI.DelPluginResponse{} + + return resp, nil +} + +func (p *PluginApplicationService) GetPluginNextVersion(ctx context.Context, req *pluginAPI.GetPluginNextVersionRequest) (resp *pluginAPI.GetPluginNextVersionResponse, err error) { + _, err = p.validateDraftPluginAccess(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "validateGetPluginNextVersionRequest failed") + } + + nextVersion, err := p.DomainSVC.GetPluginNextVersion(ctx, req.PluginID) + if err != nil { + return nil, errorx.Wrapf(err, "GetPluginNextVersion failed, pluginID=%d", req.PluginID) + } + resp = &pluginAPI.GetPluginNextVersionResponse{ + NextVersionName: nextVersion, + } + return resp, nil +} + +func (p *PluginApplicationService) GetDevPluginList(ctx context.Context, req *pluginAPI.GetDevPluginListRequest) (resp *pluginAPI.GetDevPluginListResponse, err error) { + pageInfo := dto.PageInfo{ + Name: req.Name, + Page: int(req.GetPage()), + Size: int(req.GetSize()), + OrderByACS: ptr.Of(false), + } + if req.GetOrderBy() == common.OrderBy_UpdateTime { + pageInfo.SortBy = ptr.Of(dto.SortByUpdatedAt) + } else { + pageInfo.SortBy = ptr.Of(dto.SortByCreatedAt) + } + + res, err := p.DomainSVC.ListDraftPlugins(ctx, &dto.ListDraftPluginsRequest{ + SpaceID: req.SpaceID, + APPID: req.ProjectID, + PageInfo: pageInfo, + }) + if err != nil { + return nil, errorx.Wrapf(err, "ListDraftPlugins failed, spaceID=%d, appID=%d", req.SpaceID, req.ProjectID) + } + + pluginList := make([]*common.PluginInfoForPlayground, 0, len(res.Plugins)) + for _, pl := range res.Plugins { + tools, err := p.toolRepo.GetPluginAllDraftTools(ctx, pl.ID) + if err != nil { + return nil, errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", pl.ID) + } + + pluginInfo, err := p.toPluginInfoForPlayground(ctx, pl, tools) + if err != nil { + return nil, err + } + + pluginInfo.VersionTs = "0" // when you get the plugin information in the project, version ts is set to 0 by default + pluginList = append(pluginList, pluginInfo) + } + + resp = &pluginAPI.GetDevPluginListResponse{ + PluginList: pluginList, + Total: res.Total, + } + + return resp, nil +} + +func (p *PluginApplicationService) DeleteAPPAllPlugins(ctx context.Context, appID int64) (err error) { + pluginIDs, err := p.DomainSVC.DeleteAPPAllPlugins(ctx, appID) + if err != nil { + return errorx.Wrapf(err, "DeleteAPPAllPlugins failed, appID=%d", appID) + } + + for _, id := range pluginIDs { + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Deleted, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResID: id, + }, + }) + if err != nil { + return errorx.Wrapf(err, "publish resource '%d' failed", id) + } + } + + return nil +} + +func (p *PluginApplicationService) CopyPlugin(ctx context.Context, req *dto.CopyPluginRequest) (resp *dto.CopyPluginResponse, err error) { + res, err := p.DomainSVC.CopyPlugin(ctx, &dto.CopyPluginRequest{ + UserID: req.UserID, + PluginID: req.PluginID, + CopyScene: req.CopyScene, + TargetAPPID: req.TargetAPPID, + }) + if err != nil { + return nil, errorx.Wrapf(err, "CopyPlugin failed, pluginID=%d", req.PluginID) + } + + plugin := res.Plugin + + now := time.Now().UnixMilli() + resDoc := &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResSubType: ptr.Of(int32(plugin.PluginType)), + ResID: plugin.ID, + Name: ptr.Of(plugin.GetName()), + SpaceID: &plugin.SpaceID, + APPID: plugin.APPID, + OwnerID: &req.UserID, + PublishStatus: ptr.Of(resCommon.PublishStatus_UnPublished), + CreateTimeMS: ptr.Of(now), + } + if plugin.Published() { + resDoc.PublishStatus = ptr.Of(resCommon.PublishStatus_Published) + resDoc.PublishTimeMS = ptr.Of(now) + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Created, + Resource: resDoc, + }) + if err != nil { + return nil, errorx.Wrapf(err, "publish resource '%d' failed", plugin.ID) + } + + resp = &dto.CopyPluginResponse{ + Plugin: res.Plugin, + Tools: res.Tools, + } + + return resp, nil +} + +func (p *PluginApplicationService) MoveAPPPluginToLibrary(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error) { + plugin, err = p.DomainSVC.MoveAPPPluginToLibrary(ctx, pluginID) + if err != nil { + return nil, errorx.Wrapf(err, "MoveAPPPluginToLibrary failed, pluginID=%d", pluginID) + } + + now := time.Now().UnixMilli() + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Updated, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResID: pluginID, + APPID: ptr.Of(int64(0)), + PublishStatus: ptr.Of(resCommon.PublishStatus_Published), + PublishTimeMS: ptr.Of(now), + UpdateTimeMS: ptr.Of(now), + }, + }) + if err != nil { + return nil, errorx.Wrapf(err, "publish resource '%d' failed", pluginID) + } + + return plugin, nil +} diff --git a/backend/application/plugin/playground.go b/backend/application/plugin/playground.go new file mode 100644 index 000000000..dc8e2e9c2 --- /dev/null +++ b/backend/application/plugin/playground.go @@ -0,0 +1,181 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package plugin + +import ( + "context" + "fmt" + "strconv" + + pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop" + common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" + "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/logs" +) + +func (p *PluginApplicationService) GetPlaygroundPluginList(ctx context.Context, req *pluginAPI.GetPlaygroundPluginListRequest) (resp *pluginAPI.GetPlaygroundPluginListResponse, err error) { + var ( + plugins []*entity.PluginInfo + total int64 + ) + if len(req.PluginIds) > 0 { + plugins, total, err = p.getPlaygroundPluginListByIDs(ctx, req.PluginIds) + } else { + plugins, total, err = p.getPlaygroundPluginList(ctx, req) + } + + if err != nil { + return nil, errorx.Wrapf(err, "getPlaygroundPluginList failed, req=%v", req) + } + + pluginList := make([]*common.PluginInfoForPlayground, 0, len(plugins)) + for _, pl := range plugins { + tools, err := p.toolRepo.GetPluginAllOnlineTools(ctx, pl.ID) + if err != nil { + return nil, errorx.Wrapf(err, "GetPluginAllOnlineTools failed, pluginID=%d", pl.ID) + } + + pluginInfo, err := p.toPluginInfoForPlayground(ctx, pl, tools) + if err != nil { + return nil, err + } + + pluginList = append(pluginList, pluginInfo) + } + + resp = &pluginAPI.GetPlaygroundPluginListResponse{ + Data: &common.GetPlaygroundPluginListData{ + Total: int32(total), + PluginList: pluginList, + }, + } + + return resp, nil +} + +func (p *PluginApplicationService) getPlaygroundPluginListByIDs(ctx context.Context, pluginIDs []string) (plugins []*entity.PluginInfo, total int64, err error) { + ids := make([]int64, 0, len(pluginIDs)) + for _, id := range pluginIDs { + pluginID, pErr := strconv.ParseInt(id, 10, 64) + if pErr != nil { + return nil, 0, fmt.Errorf("invalid pluginID '%s'", id) + } + ids = append(ids, pluginID) + } + + plugins, err = p.pluginRepo.MGetOnlinePlugins(ctx, ids) + if err != nil { + return nil, 0, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", pluginIDs) + } + + total = int64(len(plugins)) + + return plugins, total, nil +} + +func (p *PluginApplicationService) getPlaygroundPluginList(ctx context.Context, req *pluginAPI.GetPlaygroundPluginListRequest) (plugins []*entity.PluginInfo, total int64, err error) { + pageInfo := dto.PageInfo{ + Name: req.Name, + Page: int(req.GetPage()), + Size: int(req.GetSize()), + SortBy: func() *dto.SortField { + if req.GetOrderBy() == 0 { + return ptr.Of(dto.SortByUpdatedAt) + } + return ptr.Of(dto.SortByCreatedAt) + }(), + OrderByACS: ptr.Of(false), + } + plugins, total, err = p.DomainSVC.ListCustomOnlinePlugins(ctx, req.GetSpaceID(), pageInfo) + if err != nil { + return nil, 0, errorx.Wrapf(err, "ListCustomOnlinePlugins failed, spaceID=%d", req.GetSpaceID()) + } + + return plugins, total, nil +} + +func (p *PluginApplicationService) toPluginInfoForPlayground(ctx context.Context, pl *entity.PluginInfo, tools []*entity.ToolInfo) (*common.PluginInfoForPlayground, error) { + pluginAPIs := make([]*common.PluginApi, 0, len(tools)) + for _, tl := range tools { + params, err := tl.ToPluginParameters() + if err != nil { + return nil, err + } + + pluginAPIs = append(pluginAPIs, &common.PluginApi{ + APIID: strconv.FormatInt(tl.ID, 10), + Name: tl.GetName(), + Desc: tl.GetDesc(), + PluginID: strconv.FormatInt(pl.ID, 10), + PluginName: pl.GetName(), + RunMode: common.RunMode_Sync, + Parameters: params, + }) + } + + var creator *common.Creator + userInfo, err := p.userSVC.GetUserInfo(ctx, pl.DeveloperID) + if err != nil { + logs.CtxErrorf(ctx, "get user info failed, err=%v", err) + creator = common.NewCreator() + } else { + creator = &common.Creator{ + ID: strconv.FormatInt(pl.DeveloperID, 10), + Name: userInfo.Name, + AvatarURL: userInfo.IconURL, + UserUniqueName: userInfo.UniqueName, + } + } + + iconURL, err := p.oss.GetObjectUrl(ctx, pl.GetIconURI()) + if err != nil { + logs.Errorf("get plugin icon url failed, err=%v", err) + } + + authType, ok := convert.ToThriftAuthType(pl.GetAuthInfo().Type) + if !ok { + return nil, fmt.Errorf("invalid auth type '%s'", pl.GetAuthInfo().Type) + } + + pluginInfo := &common.PluginInfoForPlayground{ + Auth: int32(authType), + CreateTime: strconv.FormatInt(pl.CreatedAt/1000, 10), + CreationMethod: common.CreationMethod_COZE, + Creator: creator, + DescForHuman: pl.GetDesc(), + ID: strconv.FormatInt(pl.ID, 10), + IsOfficial: pl.IsOfficial(), + MaterialID: strconv.FormatInt(pl.ID, 10), + Name: pl.GetName(), + PluginIcon: iconURL, + PluginType: pl.PluginType, + SpaceID: strconv.FormatInt(pl.SpaceID, 10), + StatisticData: common.NewPluginStatisticData(), + Status: common.PluginStatus_SUBMITTED, + UpdateTime: strconv.FormatInt(pl.UpdatedAt/1000, 10), + ProjectID: strconv.FormatInt(pl.GetAPPID(), 10), + VersionName: pl.GetVersion(), + VersionTs: pl.GetVersion(), // Compatible with front-end logic, in theory VersionName should be used + PluginApis: pluginAPIs, + } + + return pluginInfo, nil +} diff --git a/backend/application/plugin/plugin.go b/backend/application/plugin/plugin.go index 4819c49df..4293bb1d5 100644 --- a/backend/application/plugin/plugin.go +++ b/backend/application/plugin/plugin.go @@ -18,46 +18,26 @@ package plugin import ( "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "os" - "strconv" "strings" "time" - "github.com/bytedance/sonic" - "github.com/getkin/kin-openapi/openapi3" - gonanoid "github.com/matoous/go-nanoid" - "gopkg.in/yaml.v3" - - botOpenAPI "github.com/coze-dev/coze-studio/backend/api/model/app/bot_open_api" productCommon "github.com/coze-dev/coze-studio/backend/api/model/marketplace/product_common" productAPI "github.com/coze-dev/coze-studio/backend/api/model/marketplace/product_public_api" pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" - "github.com/coze-dev/coze-studio/backend/application/base/pluginutil" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" - pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert/api" "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" - "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/repository" "github.com/coze-dev/coze-studio/backend/domain/plugin/service" - searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" search "github.com/coze-dev/coze-studio/backend/domain/search/service" user "github.com/coze-dev/coze-studio/backend/domain/user/service" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" "github.com/coze-dev/coze-studio/backend/pkg/errorx" - "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" - "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" - commonConsts "github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -73,697 +53,6 @@ type PluginApplicationService struct { pluginRepo repository.PluginRepository } -func (p *PluginApplicationService) GetOAuthSchema(ctx context.Context, req *pluginAPI.GetOAuthSchemaRequest) (resp *pluginAPI.GetOAuthSchemaResponse, err error) { - return &pluginAPI.GetOAuthSchemaResponse{ - OauthSchema: pluginConf.GetOAuthSchema(), - }, nil -} - -func (p *PluginApplicationService) GetPlaygroundPluginList(ctx context.Context, req *pluginAPI.GetPlaygroundPluginListRequest) (resp *pluginAPI.GetPlaygroundPluginListResponse, err error) { - var ( - plugins []*entity.PluginInfo - total int64 - ) - if len(req.PluginIds) > 0 { - plugins, total, err = p.getPlaygroundPluginListByIDs(ctx, req.PluginIds) - } else { - plugins, total, err = p.getPlaygroundPluginList(ctx, req) - } - - if err != nil { - return nil, errorx.Wrapf(err, "getPlaygroundPluginList failed, req=%v", req) - } - - pluginList := make([]*common.PluginInfoForPlayground, 0, len(plugins)) - for _, pl := range plugins { - tools, err := p.toolRepo.GetPluginAllOnlineTools(ctx, pl.ID) - if err != nil { - return nil, errorx.Wrapf(err, "GetPluginAllOnlineTools failed, pluginID=%d", pl.ID) - } - - pluginInfo, err := p.toPluginInfoForPlayground(ctx, pl, tools) - if err != nil { - return nil, err - } - - pluginList = append(pluginList, pluginInfo) - } - - resp = &pluginAPI.GetPlaygroundPluginListResponse{ - Data: &common.GetPlaygroundPluginListData{ - Total: int32(total), - PluginList: pluginList, - }, - } - - return resp, nil -} - -func (p *PluginApplicationService) getPlaygroundPluginListByIDs(ctx context.Context, pluginIDs []string) (plugins []*entity.PluginInfo, total int64, err error) { - ids := make([]int64, 0, len(pluginIDs)) - for _, id := range pluginIDs { - pluginID, err := strconv.ParseInt(id, 10, 64) - if err != nil { - return nil, 0, fmt.Errorf("invalid pluginID '%s'", id) - } - ids = append(ids, pluginID) - } - - plugins, err = p.pluginRepo.MGetOnlinePlugins(ctx, ids) - if err != nil { - return nil, 0, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", pluginIDs) - } - - total = int64(len(plugins)) - - return plugins, total, nil -} - -func (p *PluginApplicationService) getPlaygroundPluginList(ctx context.Context, req *pluginAPI.GetPlaygroundPluginListRequest) (plugins []*entity.PluginInfo, total int64, err error) { - pageInfo := entity.PageInfo{ - Name: req.Name, - Page: int(req.GetPage()), - Size: int(req.GetSize()), - SortBy: func() *entity.SortField { - if req.GetOrderBy() == 0 { - return ptr.Of(entity.SortByUpdatedAt) - } - return ptr.Of(entity.SortByCreatedAt) - }(), - OrderByACS: ptr.Of(false), - } - plugins, total, err = p.DomainSVC.ListCustomOnlinePlugins(ctx, req.GetSpaceID(), pageInfo) - if err != nil { - return nil, 0, errorx.Wrapf(err, "ListCustomOnlinePlugins failed, spaceID=%d", req.GetSpaceID()) - } - - return plugins, total, nil -} - -func (p *PluginApplicationService) toPluginInfoForPlayground(ctx context.Context, pl *entity.PluginInfo, tools []*entity.ToolInfo) (*common.PluginInfoForPlayground, error) { - pluginAPIs := make([]*common.PluginApi, 0, len(tools)) - for _, tl := range tools { - params, err := tl.ToPluginParameters() - if err != nil { - return nil, err - } - - pluginAPIs = append(pluginAPIs, &common.PluginApi{ - APIID: strconv.FormatInt(tl.ID, 10), - Name: tl.GetName(), - Desc: tl.GetDesc(), - PluginID: strconv.FormatInt(pl.ID, 10), - PluginName: pl.GetName(), - RunMode: common.RunMode_Sync, - Parameters: params, - }) - } - - var creator *common.Creator - userInfo, err := p.userSVC.GetUserInfo(context.Background(), pl.DeveloperID) - if err != nil { - logs.CtxErrorf(ctx, "get user info failed, err=%v", err) - creator = common.NewCreator() - } else { - creator = &common.Creator{ - ID: strconv.FormatInt(pl.DeveloperID, 10), - Name: userInfo.Name, - AvatarURL: userInfo.IconURL, - UserUniqueName: userInfo.UniqueName, - } - } - - iconURL, err := p.oss.GetObjectUrl(ctx, pl.GetIconURI()) - if err != nil { - logs.Errorf("get plugin icon url failed, err=%v", err) - } - - authType, ok := model.ToThriftAuthType(pl.GetAuthInfo().Type) - if !ok { - return nil, fmt.Errorf("invalid auth type '%s'", pl.GetAuthInfo().Type) - } - - pluginInfo := &common.PluginInfoForPlayground{ - Auth: int32(authType), - CreateTime: strconv.FormatInt(pl.CreatedAt/1000, 10), - CreationMethod: common.CreationMethod_COZE, - Creator: creator, - DescForHuman: pl.GetDesc(), - ID: strconv.FormatInt(pl.ID, 10), - IsOfficial: pl.IsOfficial(), - MaterialID: strconv.FormatInt(pl.ID, 10), - Name: pl.GetName(), - PluginIcon: iconURL, - PluginType: pl.PluginType, - SpaceID: strconv.FormatInt(pl.SpaceID, 10), - StatisticData: common.NewPluginStatisticData(), - Status: common.PluginStatus_SUBMITTED, - UpdateTime: strconv.FormatInt(pl.UpdatedAt/1000, 10), - ProjectID: strconv.FormatInt(pl.GetAPPID(), 10), - VersionName: pl.GetVersion(), - VersionTs: pl.GetVersion(), // Compatible with front-end logic, in theory VersionName should be used - PluginApis: pluginAPIs, - } - - return pluginInfo, nil -} - -func (p *PluginApplicationService) RegisterPluginMeta(ctx context.Context, req *pluginAPI.RegisterPluginMetaRequest) (resp *pluginAPI.RegisterPluginMetaResponse, err error) { - userID := ctxutil.GetUIDFromCtx(ctx) - if userID == nil { - return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) - } - - _authType, ok := model.ToAuthType(req.GetAuthType()) - if !ok { - return nil, fmt.Errorf("invalid auth type '%d'", req.GetAuthType()) - } - authType := ptr.Of(_authType) - - var authSubType *model.AuthzSubType - if req.SubAuthType != nil { - _authSubType, ok := model.ToAuthSubType(req.GetSubAuthType()) - if !ok { - return nil, fmt.Errorf("invalid sub authz type '%d'", req.GetSubAuthType()) - } - authSubType = ptr.Of(_authSubType) - } - - var loc model.HTTPParamLocation - if *authType == model.AuthzTypeOfService { - if req.GetLocation() == common.AuthorizationServiceLocation_Query { - loc = model.ParamInQuery - } else if req.GetLocation() == common.AuthorizationServiceLocation_Header { - loc = model.ParamInHeader - } else { - return nil, fmt.Errorf("invalid location '%s'", req.GetLocation()) - } - } - - r := &dto.CreateDraftPluginRequest{ - PluginType: req.GetPluginType(), - SpaceID: req.GetSpaceID(), - DeveloperID: *userID, - IconURI: req.Icon.URI, - ProjectID: req.ProjectID, - Name: req.GetName(), - Desc: req.GetDesc(), - ServerURL: req.GetURL(), - CommonParams: req.CommonParams, - AuthInfo: &dto.PluginAuthInfo{ - AuthzType: authType, - Location: ptr.Of(loc), - Key: req.Key, - ServiceToken: req.ServiceToken, - OAuthInfo: req.OauthInfo, - AuthzSubType: authSubType, - AuthzPayload: req.AuthPayload, - }, - } - pluginID, err := p.DomainSVC.CreateDraftPlugin(ctx, r) - if err != nil { - return nil, errorx.Wrapf(err, "CreateDraftPlugin failed") - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Created, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResSubType: ptr.Of(int32(req.GetPluginType())), - ResID: pluginID, - Name: &req.Name, - SpaceID: &req.SpaceID, - APPID: req.ProjectID, - OwnerID: userID, - PublishStatus: ptr.Of(resCommon.PublishStatus_UnPublished), - CreateTimeMS: ptr.Of(time.Now().UnixMilli()), - }, - }) - if err != nil { - return nil, fmt.Errorf("publish resource '%d' failed, err=%v", pluginID, err) - } - - resp = &pluginAPI.RegisterPluginMetaResponse{ - PluginID: pluginID, - } - - return resp, nil -} - -func (p *PluginApplicationService) RegisterPlugin(ctx context.Context, req *pluginAPI.RegisterPluginRequest) (resp *pluginAPI.RegisterPluginResponse, err error) { - userID := ctxutil.GetUIDFromCtx(ctx) - if userID == nil { - return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) - } - - mf := &entity.PluginManifest{} - err = sonic.UnmarshalString(req.AiPlugin, &mf) - if err != nil { - return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, err.Error())) - } - - mf.LogoURL = commonConsts.DefaultPluginIcon - - doc, err := openapi3.NewLoader().LoadFromData([]byte(req.Openapi)) - if err != nil { - return nil, errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error())) - } - - res, err := p.DomainSVC.CreateDraftPluginWithCode(ctx, &dto.CreateDraftPluginWithCodeRequest{ - SpaceID: req.GetSpaceID(), - DeveloperID: *userID, - ProjectID: req.ProjectID, - Manifest: mf, - OpenapiDoc: ptr.Of(model.Openapi3T(*doc)), - }) - if err != nil { - return nil, errorx.Wrapf(err, "CreateDraftPluginWithCode failed") - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Created, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResSubType: ptr.Of(int32(res.Plugin.PluginType)), - ResID: res.Plugin.ID, - Name: ptr.Of(res.Plugin.GetName()), - APPID: req.ProjectID, - SpaceID: &req.SpaceID, - OwnerID: userID, - PublishStatus: ptr.Of(resCommon.PublishStatus_UnPublished), - CreateTimeMS: ptr.Of(time.Now().UnixMilli()), - }, - }) - if err != nil { - return nil, fmt.Errorf("publish resource '%d' failed, err=%v", res.Plugin.ID, err) - } - - resp = &pluginAPI.RegisterPluginResponse{ - Data: &common.RegisterPluginData{ - PluginID: res.Plugin.ID, - Openapi: req.Openapi, - }, - } - - return resp, nil -} - -func (p *PluginApplicationService) GetPluginAPIs(ctx context.Context, req *pluginAPI.GetPluginAPIsRequest) (resp *pluginAPI.GetPluginAPIsResponse, err error) { - pl, err := p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateGetPluginAPIsRequest failed") - } - - var ( - draftTools []*entity.ToolInfo - total int64 - ) - if len(req.APIIds) > 0 { - toolIDs := make([]int64, 0, len(req.APIIds)) - for _, id := range req.APIIds { - toolID, err := strconv.ParseInt(id, 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid tool id '%s'", id) - } - toolIDs = append(toolIDs, toolID) - } - - draftTools, err = p.toolRepo.MGetDraftTools(ctx, toolIDs) - if err != nil { - return nil, errorx.Wrapf(err, "MGetDraftTools failed, toolIDs=%v", toolIDs) - } - - total = int64(len(draftTools)) - - } else { - pageInfo := entity.PageInfo{ - Page: int(req.Page), - Size: int(req.Size), - SortBy: ptr.Of(entity.SortByCreatedAt), - OrderByACS: ptr.Of(false), - } - draftTools, total, err = p.toolRepo.ListPluginDraftTools(ctx, req.PluginID, pageInfo) - if err != nil { - return nil, errorx.Wrapf(err, "ListPluginDraftTools failed, pluginID=%d", req.PluginID) - } - } - - if len(draftTools) == 0 { - return &pluginAPI.GetPluginAPIsResponse{ - APIInfo: make([]*common.PluginAPIInfo, 0), - Total: 0, - }, nil - } - - draftToolIDs := slices.Transform(draftTools, func(tl *entity.ToolInfo) int64 { - return tl.ID - }) - onlineStatus, err := p.getToolOnlineStatus(ctx, draftToolIDs) - if err != nil { - return nil, err - } - - apis := make([]*common.PluginAPIInfo, 0, len(draftTools)) - for _, tool := range draftTools { - method, ok := model.ToThriftAPIMethod(tool.GetMethod()) - if !ok { - return nil, fmt.Errorf("invalid method '%s'", tool.GetMethod()) - } - reqParams, err := tool.ToReqAPIParameter() - if err != nil { - return nil, err - } - respParams, err := tool.ToRespAPIParameter() - if err != nil { - return nil, err - } - - var apiExtend *common.APIExtend - if tmp, ok := tool.Operation.Extensions[model.APISchemaExtendAuthMode].(string); ok { - if mode, ok := model.ToThriftAPIAuthMode(model.ToolAuthMode(tmp)); ok { - apiExtend = &common.APIExtend{ - AuthMode: mode, - } - } - } - - api := &common.PluginAPIInfo{ - APIID: strconv.FormatInt(tool.ID, 10), - CreateTime: strconv.FormatInt(tool.CreatedAt/1000, 10), - DebugStatus: tool.GetDebugStatus(), - Desc: tool.GetDesc(), - Disabled: func() bool { - if tool.GetActivatedStatus() == model.DeactivateTool { - return true - } - return false - }(), - Method: method, - Name: tool.GetName(), - OnlineStatus: onlineStatus[tool.ID], - Path: tool.GetSubURL(), - PluginID: strconv.FormatInt(tool.PluginID, 10), - RequestParams: reqParams, - ResponseParams: respParams, - StatisticData: common.NewPluginStatisticData(), - APIExtend: apiExtend, - } - example := pl.GetToolExample(ctx, tool.GetName()) - if example != nil { - api.DebugExample = &common.DebugExample{ - ReqExample: example.RequestExample, - RespExample: example.ResponseExample, - } - api.DebugExampleStatus = common.DebugExampleStatus_Enable - } - - apis = append(apis, api) - } - - resp = &pluginAPI.GetPluginAPIsResponse{ - APIInfo: apis, - Total: int32(total), - } - - return resp, nil -} - -func (p *PluginApplicationService) getToolOnlineStatus(ctx context.Context, toolIDs []int64) (map[int64]common.OnlineStatus, error) { - onlineTools, err := p.toolRepo.MGetOnlineTools(ctx, toolIDs, repository.WithToolID()) - if err != nil { - return nil, errorx.Wrapf(err, "MGetOnlineTools failed, toolIDs=%v", toolIDs) - } - - onlineStatus := make(map[int64]common.OnlineStatus, len(onlineTools)) - for _, tool := range onlineTools { - onlineStatus[tool.ID] = common.OnlineStatus_ONLINE - } - - return onlineStatus, nil -} - -func (p *PluginApplicationService) GetPluginInfo(ctx context.Context, req *pluginAPI.GetPluginInfoRequest) (resp *pluginAPI.GetPluginInfoResponse, err error) { - draftPlugin, err := p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateGetPluginInfoRequest failed") - } - - metaInfo, err := p.getPluginMetaInfo(ctx, draftPlugin) - if err != nil { - return nil, err - } - - codeInfo, err := p.getPluginCodeInfo(ctx, draftPlugin) - if err != nil { - return nil, err - } - - _, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, req.PluginID, repository.WithPluginID()) - if err != nil { - return nil, errorx.Wrapf(err, "GetOnlinePlugin failed, pluginID=%d", req.PluginID) - } - - resp = &pluginAPI.GetPluginInfoResponse{ - MetaInfo: metaInfo, - CodeInfo: codeInfo, - Creator: common.NewCreator(), - StatisticData: common.NewPluginStatisticData(), - PluginType: draftPlugin.PluginType, - CreationMethod: common.CreationMethod_COZE, - Published: exist, - } - - return resp, nil -} - -func (p *PluginApplicationService) getPluginCodeInfo(ctx context.Context, draftPlugin *entity.PluginInfo) (*common.CodeInfo, error) { - tools, err := p.toolRepo.GetPluginAllDraftTools(ctx, draftPlugin.ID) - if err != nil { - return nil, errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", draftPlugin.ID) - } - - paths := openapi3.Paths{} - for _, tool := range tools { - if tool.GetActivatedStatus() == model.DeactivateTool { - continue - } - item := &openapi3.PathItem{} - item.SetOperation(tool.GetMethod(), tool.Operation.Operation) - paths[tool.GetSubURL()] = item - } - draftPlugin.OpenapiDoc.Paths = paths - - manifestStr, err := sonic.MarshalString(draftPlugin.Manifest) - if err != nil { - return nil, fmt.Errorf("marshal manifest failed, err=%v", err) - } - - docBytes, err := yaml.Marshal(draftPlugin.OpenapiDoc) - if err != nil { - return nil, fmt.Errorf("marshal openapi doc failed, err=%v", err) - } - - codeInfo := &common.CodeInfo{ - OpenapiDesc: string(docBytes), - PluginDesc: manifestStr, - } - - return codeInfo, nil -} - -func (p *PluginApplicationService) getPluginMetaInfo(ctx context.Context, draftPlugin *entity.PluginInfo) (*common.PluginMetaInfo, error) { - commonParams := make(map[common.ParameterLocation][]*common.CommonParamSchema, len(draftPlugin.Manifest.CommonParams)) - for loc, params := range draftPlugin.Manifest.CommonParams { - location, ok := model.ToThriftHTTPParamLocation(loc) - if !ok { - return nil, fmt.Errorf("invalid location '%s'", loc) - } - commonParams[location] = make([]*common.CommonParamSchema, 0, len(params)) - for _, param := range params { - commonParams[location] = append(commonParams[location], &common.CommonParamSchema{ - Name: param.Name, - Value: param.Value, - }) - } - } - - iconURL, err := p.oss.GetObjectUrl(ctx, draftPlugin.GetIconURI()) - if err != nil { - logs.CtxWarnf(ctx, "get icon url with '%s' failed, err=%v", draftPlugin.GetIconURI(), err) - } - - metaInfo := &common.PluginMetaInfo{ - Name: draftPlugin.GetName(), - Desc: draftPlugin.GetDesc(), - URL: draftPlugin.GetServerURL(), - Icon: &common.PluginIcon{ - URI: draftPlugin.GetIconURI(), - URL: iconURL, - }, - CommonParams: commonParams, - } - - err = p.fillAuthInfoInMetaInfo(ctx, draftPlugin, metaInfo) - if err != nil { - return nil, errorx.Wrapf(err, "fillAuthInfoInMetaInfo failed, pluginID=%d", draftPlugin.ID) - } - - return metaInfo, nil -} - -func (p *PluginApplicationService) fillAuthInfoInMetaInfo(ctx context.Context, draftPlugin *entity.PluginInfo, metaInfo *common.PluginMetaInfo) (err error) { - authInfo := draftPlugin.GetAuthInfo() - authType, ok := model.ToThriftAuthType(authInfo.Type) - if !ok { - return fmt.Errorf("invalid auth type '%s'", authInfo.Type) - } - - var subAuthType *int32 - if authInfo.SubType != "" { - _subAuthType, ok := model.ToThriftAuthSubType(authInfo.SubType) - if !ok { - return fmt.Errorf("invalid sub authz type '%s'", authInfo.SubType) - } - subAuthType = &_subAuthType - } - - metaInfo.AuthType = append(metaInfo.AuthType, authType) - metaInfo.SubAuthType = subAuthType - - if authType == common.AuthorizationType_None { - return nil - } - - if authType == common.AuthorizationType_Service { - var loc common.AuthorizationServiceLocation - _loc := model.HTTPParamLocation(strings.ToLower(string(authInfo.AuthOfAPIToken.Location))) - if _loc == model.ParamInHeader { - loc = common.AuthorizationServiceLocation_Header - } else if _loc == model.ParamInQuery { - loc = common.AuthorizationServiceLocation_Query - } else { - return fmt.Errorf("invalid location '%s'", authInfo.AuthOfAPIToken.Location) - } - - metaInfo.Location = ptr.Of(loc) - metaInfo.Key = ptr.Of(authInfo.AuthOfAPIToken.Key) - metaInfo.ServiceToken = ptr.Of(authInfo.AuthOfAPIToken.ServiceToken) - } - - if authType == common.AuthorizationType_OAuth { - metaInfo.OauthInfo = &authInfo.Payload - } - - return nil -} - -func (p *PluginApplicationService) GetUpdatedAPIs(ctx context.Context, req *pluginAPI.GetUpdatedAPIsRequest) (resp *pluginAPI.GetUpdatedAPIsResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateGetUpdatedAPIsRequest failed") - } - - draftTools, err := p.toolRepo.GetPluginAllDraftTools(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", req.PluginID) - } - onlineTools, err := p.toolRepo.GetPluginAllOnlineTools(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "GetPluginAllOnlineTools failed, pluginID=%d", req.PluginID) - } - - var updatedToolName, createdToolName, delToolName []string - - draftMap := slices.ToMap(draftTools, func(e *entity.ToolInfo) (string, *entity.ToolInfo) { - return e.GetName(), e - }) - onlineMap := slices.ToMap(onlineTools, func(e *entity.ToolInfo) (string, *entity.ToolInfo) { - return e.GetName(), e - }) - - for name := range draftMap { - if _, ok := onlineMap[name]; !ok { - createdToolName = append(createdToolName, name) - } - } - - for name, ot := range onlineMap { - dt, ok := draftMap[name] - if !ok { - delToolName = append(delToolName, name) - continue - } - - if ot.GetMethod() != dt.GetMethod() || - ot.GetSubURL() != dt.GetSubURL() || - ot.GetDesc() != dt.GetDesc() { - updatedToolName = append(updatedToolName, name) - continue - } - - os, err := sonic.MarshalString(ot.Operation) - if err != nil { - logs.CtxErrorf(ctx, "marshal online tool operation failed, toolID=%d, err=%v", ot.ID, err) - - updatedToolName = append(updatedToolName, name) - continue - } - ds, err := sonic.MarshalString(dt.Operation) - if err != nil { - logs.CtxErrorf(ctx, "marshal draft tool operation failed, toolID=%d, err=%v", ot.ID, err) - - updatedToolName = append(updatedToolName, name) - continue - } - - if os != ds { - updatedToolName = append(updatedToolName, name) - } - } - - resp = &pluginAPI.GetUpdatedAPIsResponse{ - UpdatedAPINames: updatedToolName, - CreatedAPINames: createdToolName, - DeletedAPINames: delToolName, - } - - return resp, nil -} - -func (p *PluginApplicationService) GetUserAuthority(ctx context.Context, req *pluginAPI.GetUserAuthorityRequest) (resp *pluginAPI.GetUserAuthorityResponse, err error) { - resp = &pluginAPI.GetUserAuthorityResponse{ - Data: &common.GetUserAuthorityData{ - CanEdit: true, - CanRead: true, - CanDelete: true, - CanDebug: true, - CanPublish: true, - CanReadChangelog: true, - }, - } - - return resp, nil -} - -func (p *PluginApplicationService) GetOAuthStatus(ctx context.Context, req *pluginAPI.GetOAuthStatusRequest) (resp *pluginAPI.GetOAuthStatusResponse, err error) { - userID := ctxutil.GetUIDFromCtx(ctx) - if userID == nil { - return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) - } - - res, err := p.DomainSVC.GetOAuthStatus(ctx, *userID, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "GetOAuthStatus failed, pluginID=%d", req.PluginID) - } - resp = &pluginAPI.GetOAuthStatusResponse{ - IsOauth: res.IsOauth, - Status: res.Status, - Content: res.OAuthURL, - } - - return resp, nil -} - func (p *PluginApplicationService) CheckAndLockPluginEdit(ctx context.Context, req *pluginAPI.CheckAndLockPluginEditRequest) (resp *pluginAPI.CheckAndLockPluginEditResponse, err error) { resp = &pluginAPI.CheckAndLockPluginEditResponse{ Data: &common.CheckAndLockPluginEditData{ @@ -774,313 +63,6 @@ func (p *PluginApplicationService) CheckAndLockPluginEdit(ctx context.Context, r return resp, nil } -func (p *PluginApplicationService) CreateAPI(ctx context.Context, req *pluginAPI.CreateAPIRequest) (resp *pluginAPI.CreateAPIResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateCreateAPIRequest failed") - } - - defaultSubURL := gonanoid.MustID(6) - - tool := &entity.ToolInfo{ - PluginID: req.PluginID, - ActivatedStatus: ptr.Of(model.ActivateTool), - DebugStatus: ptr.Of(common.APIDebugStatus_DebugWaiting), - SubURL: ptr.Of("/" + defaultSubURL), - Method: ptr.Of(http.MethodGet), - Operation: model.NewOpenapi3Operation(&openapi3.Operation{ - Summary: req.Desc, - OperationID: req.Name, - Parameters: []*openapi3.ParameterRef{}, - RequestBody: entity.DefaultOpenapi3RequestBody(), - Responses: entity.DefaultOpenapi3Responses(), - Extensions: map[string]any{}, - }), - } - - toolID, err := p.toolRepo.CreateDraftTool(ctx, tool) - if err != nil { - return nil, errorx.Wrapf(err, "CreateDraftTool failed, pluginID=%d", req.PluginID) - } - - resp = &pluginAPI.CreateAPIResponse{ - APIID: strconv.FormatInt(toolID, 10), - } - - return resp, nil -} - -func (p *PluginApplicationService) UpdateAPI(ctx context.Context, req *pluginAPI.UpdateAPIRequest) (resp *pluginAPI.UpdateAPIResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateUpdateAPIRequest failed") - } - - op, err := pluginutil.APIParamsToOpenapiOperation(req.RequestParams, req.ResponseParams) - if err != nil { - return nil, err - } - - var method *string - if m, ok := model.ToHTTPMethod(req.GetMethod()); ok { - method = &m - } - - updateReq := &dto.UpdateDraftToolRequest{ - PluginID: req.PluginID, - ToolID: req.APIID, - Name: req.Name, - Desc: req.Desc, - SubURL: req.Path, - Method: method, - Parameters: op.Parameters, - RequestBody: op.RequestBody, - Responses: op.Responses, - Disabled: req.Disabled, - SaveExample: req.SaveExample, - DebugExample: req.DebugExample, - APIExtend: req.APIExtend, - } - err = p.DomainSVC.UpdateDraftTool(ctx, updateReq) - if err != nil { - return nil, errorx.Wrapf(err, "UpdateDraftTool failed, pluginID=%d, toolID=%d", updateReq.PluginID, updateReq.ToolID) - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Updated, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResID: req.PluginID, - UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), - }, - }) - if err != nil { - logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) - } - - resp = &pluginAPI.UpdateAPIResponse{} - - return resp, nil -} - -func (p *PluginApplicationService) UpdatePlugin(ctx context.Context, req *pluginAPI.UpdatePluginRequest) (resp *pluginAPI.UpdatePluginResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateUpdatePluginRequest failed") - } - - userID := ctxutil.GetUIDFromCtx(ctx) - - loader := openapi3.NewLoader() - _doc, err := loader.LoadFromData([]byte(req.Openapi)) - if err != nil { - return nil, errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error())) - } - - doc := ptr.Of(model.Openapi3T(*_doc)) - - manifest := &entity.PluginManifest{} - err = sonic.UnmarshalString(req.AiPlugin, manifest) - if err != nil { - return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, err.Error())) - } - - err = p.DomainSVC.UpdateDraftPluginWithCode(ctx, &dto.UpdateDraftPluginWithCodeRequest{ - UserID: *userID, - PluginID: req.PluginID, - OpenapiDoc: doc, - Manifest: manifest, - }) - if err != nil { - return nil, errorx.Wrapf(err, "UpdateDraftPluginWithCode failed, pluginID=%d", req.PluginID) - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Updated, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResID: req.PluginID, - Name: &manifest.NameForHuman, - UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), - }, - }) - if err != nil { - logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) - } - - resp = &pluginAPI.UpdatePluginResponse{ - Data: &common.UpdatePluginData{ - Res: true, - }, - } - - return resp, nil -} - -func (p *PluginApplicationService) DeleteAPI(ctx context.Context, req *pluginAPI.DeleteAPIRequest) (resp *pluginAPI.DeleteAPIResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateDeleteAPIRequest failed") - } - - err = p.toolRepo.DeleteDraftTool(ctx, req.APIID) - if err != nil { - return nil, errorx.Wrapf(err, "DeleteDraftTool failed, toolID=%d", req.APIID) - } - - resp = &pluginAPI.DeleteAPIResponse{} - - return resp, nil -} - -func (p *PluginApplicationService) DelPlugin(ctx context.Context, req *pluginAPI.DelPluginRequest) (resp *pluginAPI.DelPluginResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateDelPluginRequest failed") - } - - err = p.DomainSVC.DeleteDraftPlugin(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "DeleteDraftPlugin failed, pluginID=%d", req.PluginID) - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Deleted, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResID: req.PluginID, - UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), - }, - }) - if err != nil { - return nil, errorx.Wrapf(err, "publish resource '%d' failed", req.PluginID) - } - - resp = &pluginAPI.DelPluginResponse{} - - return resp, nil -} - -func (p *PluginApplicationService) PublishPlugin(ctx context.Context, req *pluginAPI.PublishPluginRequest) (resp *pluginAPI.PublishPluginResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validatePublishPluginRequest failed") - } - - err = p.DomainSVC.PublishPlugin(ctx, &model.PublishPluginRequest{ - PluginID: req.PluginID, - Version: req.VersionName, - VersionDesc: req.VersionDesc, - }) - if err != nil { - return nil, errorx.Wrapf(err, "PublishPlugin failed, pluginID=%d", req.PluginID) - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Updated, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResID: req.PluginID, - PublishStatus: ptr.Of(resCommon.PublishStatus_Published), - PublishTimeMS: ptr.Of(time.Now().UnixMilli()), - }, - }) - if err != nil { - logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) - } - - resp = &pluginAPI.PublishPluginResponse{} - - return resp, nil -} - -func (p *PluginApplicationService) UpdatePluginMeta(ctx context.Context, req *pluginAPI.UpdatePluginMetaRequest) (resp *pluginAPI.UpdatePluginMetaResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateUpdatePluginMetaRequest failed") - } - - authInfo, err := getUpdateAuthInfo(ctx, req) - if err != nil { - return nil, err - } - - updateReq := &dto.UpdateDraftPluginRequest{ - PluginID: req.PluginID, - Name: req.Name, - Desc: req.Desc, - URL: req.URL, - Icon: req.Icon, - CommonParams: req.CommonParams, - AuthInfo: authInfo, - } - err = p.DomainSVC.UpdateDraftPlugin(ctx, updateReq) - if err != nil { - return nil, errorx.Wrapf(err, "UpdateDraftPlugin failed, pluginID=%d", req.PluginID) - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Updated, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResID: req.PluginID, - Name: req.Name, - UpdateTimeMS: ptr.Of(time.Now().UnixMilli()), - }, - }) - if err != nil { - logs.CtxErrorf(ctx, "publish resource '%d' failed, err=%v", req.PluginID, err) - } - - resp = &pluginAPI.UpdatePluginMetaResponse{} - - return resp, nil -} - -func getUpdateAuthInfo(ctx context.Context, req *pluginAPI.UpdatePluginMetaRequest) (authInfo *dto.PluginAuthInfo, err error) { - if req.AuthType == nil { - return nil, nil - } - - _authType, ok := model.ToAuthType(req.GetAuthType()) - if !ok { - return nil, fmt.Errorf("invalid auth type '%d'", req.GetAuthType()) - } - authType := &_authType - - var authSubType *model.AuthzSubType - if req.SubAuthType != nil { - _authSubType, ok := model.ToAuthSubType(req.GetSubAuthType()) - if !ok { - return nil, fmt.Errorf("invalid sub authz type '%d'", req.GetSubAuthType()) - } - authSubType = &_authSubType - } - - var location *model.HTTPParamLocation - if req.Location != nil { - if *req.Location == common.AuthorizationServiceLocation_Header { - location = ptr.Of(model.ParamInHeader) - } else if *req.Location == common.AuthorizationServiceLocation_Query { - location = ptr.Of(model.ParamInQuery) - } else { - return nil, fmt.Errorf("invalid location '%d'", req.GetLocation()) - } - } - - authInfo = &dto.PluginAuthInfo{ - AuthzType: authType, - Location: location, - Key: req.Key, - ServiceToken: req.ServiceToken, - OAuthInfo: req.OauthInfo, - AuthzSubType: authSubType, - AuthzPayload: req.AuthPayload, - } - - return authInfo, nil -} - func (p *PluginApplicationService) GetBotDefaultParams(ctx context.Context, req *pluginAPI.GetBotDefaultParamsRequest) (resp *pluginAPI.GetBotDefaultParamsResponse, err error) { _, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, req.PluginID, repository.WithPluginID()) if err != nil { @@ -1113,7 +95,7 @@ func (p *PluginApplicationService) GetBotDefaultParams(ctx context.Context, req } func (p *PluginApplicationService) UpdateBotDefaultParams(ctx context.Context, req *pluginAPI.UpdateBotDefaultParamsRequest) (resp *pluginAPI.UpdateBotDefaultParamsResponse, err error) { - op, err := pluginutil.APIParamsToOpenapiOperation(req.RequestParams, req.ResponseParams) + op, err := api.APIParamsToOpenapiOperation(req.RequestParams, req.ResponseParams) if err != nil { return nil, err } @@ -1135,81 +117,6 @@ func (p *PluginApplicationService) UpdateBotDefaultParams(ctx context.Context, r return resp, nil } -func (p *PluginApplicationService) DebugAPI(ctx context.Context, req *pluginAPI.DebugAPIRequest) (resp *pluginAPI.DebugAPIResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateDebugAPIRequest failed") - } - - const defaultErrReason = "internal server error" - - userID := ctxutil.GetUIDFromCtx(ctx) - if userID == nil { - return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) - } - - resp = &pluginAPI.DebugAPIResponse{ - Success: false, - RawReq: "{}", - RawResp: "{}", - Resp: "{}", - } - - opts := []model.ExecuteToolOpt{} - switch req.Operation { - case common.DebugOperation_Debug: - opts = append(opts, model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnErr)) - case common.DebugOperation_Parse: - opts = append(opts, model.WithAutoGenRespSchema(), - model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnRaw), - ) - } - - res, err := p.DomainSVC.ExecuteTool(ctx, &model.ExecuteToolRequest{ - UserID: conv.Int64ToStr(*userID), - PluginID: req.PluginID, - ToolID: req.APIID, - ExecScene: model.ExecSceneOfToolDebug, - ExecDraftTool: true, - ArgumentsInJson: req.Parameters, - }, opts...) - if err != nil { - var e errorx.StatusError - if errors.As(err, &e) { - resp.Reason = e.Msg() - return resp, nil - } - - logs.CtxErrorf(ctx, "ExecuteTool failed, err=%v", err) - resp.Reason = defaultErrReason - - return resp, nil - } - - resp = &pluginAPI.DebugAPIResponse{ - Success: true, - Resp: res.TrimmedResp, - RawReq: res.Request, - RawResp: res.RawResp, - ResponseParams: []*common.APIParameter{}, - } - - if req.Operation == common.DebugOperation_Parse { - res.Tool.Operation.Responses = res.RespSchema - } - - respParams, err := res.Tool.ToRespAPIParameter() - if err != nil { - logs.CtxErrorf(ctx, "ToRespAPIParameter failed, err=%v", err) - resp.Success = false - resp.Reason = defaultErrReason - } else { - resp.ResponseParams = respParams - } - - return resp, nil -} - func (p *PluginApplicationService) UnlockPluginEdit(ctx context.Context, req *pluginAPI.UnlockPluginEditRequest) (resp *pluginAPI.UnlockPluginEditResponse, err error) { resp = &pluginAPI.UnlockPluginEditResponse{ Released: true, @@ -1346,7 +253,7 @@ func (p *PluginApplicationService) buildPluginProductExtraInfo(ctx context.Conte authMode := ptr.Of(productAPI.PluginAuthMode_NoAuth) if authInfo != nil { - if authInfo.Type == model.AuthzTypeOfService || authInfo.Type == model.AuthzTypeOfOAuth { + if authInfo.Type == consts.AuthzTypeOfService || authInfo.Type == consts.AuthzTypeOfOAuth { authMode = ptr.Of(productAPI.PluginAuthMode_Required) err := plugin.Manifest.Validate(false) if err != nil { @@ -1390,256 +297,6 @@ func (p *PluginApplicationService) PublicGetProductDetail(ctx context.Context, r return resp, nil } -func (p *PluginApplicationService) GetPluginNextVersion(ctx context.Context, req *pluginAPI.GetPluginNextVersionRequest) (resp *pluginAPI.GetPluginNextVersionResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateGetPluginNextVersionRequest failed") - } - - nextVersion, err := p.DomainSVC.GetPluginNextVersion(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "GetPluginNextVersion failed, pluginID=%d", req.PluginID) - } - resp = &pluginAPI.GetPluginNextVersionResponse{ - NextVersionName: nextVersion, - } - return resp, nil -} - -func (p *PluginApplicationService) GetDevPluginList(ctx context.Context, req *pluginAPI.GetDevPluginListRequest) (resp *pluginAPI.GetDevPluginListResponse, err error) { - pageInfo := entity.PageInfo{ - Name: req.Name, - Page: int(req.GetPage()), - Size: int(req.GetSize()), - OrderByACS: ptr.Of(false), - } - if req.GetOrderBy() == common.OrderBy_UpdateTime { - pageInfo.SortBy = ptr.Of(entity.SortByUpdatedAt) - } else { - pageInfo.SortBy = ptr.Of(entity.SortByCreatedAt) - } - - res, err := p.DomainSVC.ListDraftPlugins(ctx, &dto.ListDraftPluginsRequest{ - SpaceID: req.SpaceID, - APPID: req.ProjectID, - PageInfo: pageInfo, - }) - if err != nil { - return nil, errorx.Wrapf(err, "ListDraftPlugins failed, spaceID=%d, appID=%d", req.SpaceID, req.ProjectID) - } - - pluginList := make([]*common.PluginInfoForPlayground, 0, len(res.Plugins)) - for _, pl := range res.Plugins { - tools, err := p.toolRepo.GetPluginAllDraftTools(ctx, pl.ID) - if err != nil { - return nil, errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", pl.ID) - } - - pluginInfo, err := p.toPluginInfoForPlayground(ctx, pl, tools) - if err != nil { - return nil, err - } - - pluginInfo.VersionTs = "0" // when you get the plugin information in the project, version ts is set to 0 by default - pluginList = append(pluginList, pluginInfo) - } - - resp = &pluginAPI.GetDevPluginListResponse{ - PluginList: pluginList, - Total: res.Total, - } - - return resp, nil -} - -func (p *PluginApplicationService) DeleteAPPAllPlugins(ctx context.Context, appID int64) (err error) { - pluginIDs, err := p.DomainSVC.DeleteAPPAllPlugins(ctx, appID) - if err != nil { - return errorx.Wrapf(err, "DeleteAPPAllPlugins failed, appID=%d", appID) - } - - for _, id := range pluginIDs { - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Deleted, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResID: id, - }, - }) - if err != nil { - return errorx.Wrapf(err, "publish resource '%d' failed", id) - } - } - - return nil -} - -func (p *PluginApplicationService) Convert2OpenAPI(ctx context.Context, req *pluginAPI.Convert2OpenAPIRequest) (resp *pluginAPI.Convert2OpenAPIResponse, err error) { - res := p.DomainSVC.ConvertToOpenapi3Doc(ctx, &dto.ConvertToOpenapi3DocRequest{ - RawInput: req.Data, - PluginServerURL: req.PluginURL, - }) - - if res.ErrMsg != "" { - return &pluginAPI.Convert2OpenAPIResponse{ - Code: errno.ErrPluginInvalidThirdPartyCode, - Msg: res.ErrMsg, - DuplicateAPIInfos: []*common.DuplicateAPIInfo{}, - PluginDataFormat: ptr.Of(res.Format), - }, nil - } - - doc, err := yaml.Marshal(res.OpenapiDoc) - if err != nil { - return nil, fmt.Errorf("marshal openapi doc failed, err=%v", err) - } - mf, err := json.Marshal(res.Manifest) - if err != nil { - return nil, fmt.Errorf("marshal manifest failed, err=%v", err) - } - - resp = &pluginAPI.Convert2OpenAPIResponse{ - PluginDataFormat: ptr.Of(res.Format), - Openapi: ptr.Of(string(doc)), - AiPlugin: ptr.Of(string(mf)), - DuplicateAPIInfos: []*common.DuplicateAPIInfo{}, - } - - return resp, nil -} - -func (p *PluginApplicationService) BatchCreateAPI(ctx context.Context, req *pluginAPI.BatchCreateAPIRequest) (resp *pluginAPI.BatchCreateAPIResponse, err error) { - _, err = p.validateDraftPluginAccess(ctx, req.PluginID) - if err != nil { - return nil, errorx.Wrapf(err, "validateBatchCreateAPIRequest failed") - } - - loader := openapi3.NewLoader() - doc, err := loader.LoadFromData([]byte(req.Openapi)) - if err != nil { - return nil, errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error())) - } - - res, err := p.DomainSVC.CreateDraftToolsWithCode(ctx, &dto.CreateDraftToolsWithCodeRequest{ - PluginID: req.PluginID, - OpenapiDoc: ptr.Of(model.Openapi3T(*doc)), - ConflictAndUpdate: req.ReplaceSamePaths, - }) - if err != nil { - return nil, errorx.Wrapf(err, "CreateDraftToolsWithCode failed, pluginID=%d", req.PluginID) - } - - duplicated := slices.Transform(res.DuplicatedTools, func(e entity.UniqueToolAPI) *common.PluginAPIInfo { - method, _ := model.ToThriftAPIMethod(e.Method) - return &common.PluginAPIInfo{ - Path: e.SubURL, - Method: method, - } - }) - - resp = &pluginAPI.BatchCreateAPIResponse{ - PathsDuplicated: duplicated, - } - - if len(duplicated) > 0 { - resp.Code = errno.ErrPluginDuplicatedTool - } - - return resp, nil -} - -func (p *PluginApplicationService) RevokeAuthToken(ctx context.Context, req *pluginAPI.RevokeAuthTokenRequest) (resp *pluginAPI.RevokeAuthTokenResponse, err error) { - userID := ctxutil.GetUIDFromCtx(ctx) - if userID == nil { - return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) - } - - err = p.DomainSVC.RevokeAccessToken(ctx, &entity.AuthorizationCodeMeta{ - UserID: conv.Int64ToStr(*userID), - PluginID: req.PluginID, - IsDraft: req.GetBotID() == 0, - }) - if err != nil { - return nil, errorx.Wrapf(err, "RevokeAccessToken failed, pluginID=%d", req.PluginID) - } - - resp = &pluginAPI.RevokeAuthTokenResponse{} - - return resp, nil -} - -func (p *PluginApplicationService) CopyPlugin(ctx context.Context, req *dto.CopyPluginRequest) (resp *dto.CopyPluginResponse, err error) { - res, err := p.DomainSVC.CopyPlugin(ctx, &dto.CopyPluginRequest{ - UserID: req.UserID, - PluginID: req.PluginID, - CopyScene: req.CopyScene, - TargetAPPID: req.TargetAPPID, - }) - if err != nil { - return nil, errorx.Wrapf(err, "CopyPlugin failed, pluginID=%d", req.PluginID) - } - - plugin := res.Plugin - - now := time.Now().UnixMilli() - resDoc := &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResSubType: ptr.Of(int32(plugin.PluginType)), - ResID: plugin.ID, - Name: ptr.Of(plugin.GetName()), - SpaceID: &plugin.SpaceID, - APPID: plugin.APPID, - OwnerID: &req.UserID, - PublishStatus: ptr.Of(resCommon.PublishStatus_UnPublished), - CreateTimeMS: ptr.Of(now), - } - if plugin.Published() { - resDoc.PublishStatus = ptr.Of(resCommon.PublishStatus_Published) - resDoc.PublishTimeMS = ptr.Of(now) - } - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Created, - Resource: resDoc, - }) - if err != nil { - return nil, errorx.Wrapf(err, "publish resource '%d' failed", plugin.ID) - } - - resp = &dto.CopyPluginResponse{ - Plugin: res.Plugin, - Tools: res.Tools, - } - - return resp, nil -} - -func (p *PluginApplicationService) MoveAPPPluginToLibrary(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error) { - plugin, err = p.DomainSVC.MoveAPPPluginToLibrary(ctx, pluginID) - if err != nil { - return nil, errorx.Wrapf(err, "MoveAPPPluginToLibrary failed, pluginID=%d", pluginID) - } - - now := time.Now().UnixMilli() - - err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ - OpType: searchEntity.Updated, - Resource: &searchEntity.ResourceDocument{ - ResType: resCommon.ResType_Plugin, - ResID: pluginID, - APPID: ptr.Of(int64(0)), - PublishStatus: ptr.Of(resCommon.PublishStatus_Published), - PublishTimeMS: ptr.Of(now), - UpdateTimeMS: ptr.Of(now), - }, - }) - if err != nil { - return nil, errorx.Wrapf(err, "publish resource '%d' failed", pluginID) - } - - return plugin, nil -} - func (p *PluginApplicationService) validateDraftPluginAccess(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error) { uid := ctxutil.GetUIDFromCtx(ctx) if uid == nil { @@ -1657,69 +314,3 @@ func (p *PluginApplicationService) validateDraftPluginAccess(ctx context.Context return plugin, nil } - -func (p *PluginApplicationService) OauthAuthorizationCode(ctx context.Context, req *botOpenAPI.OauthAuthorizationCodeReq) (resp *botOpenAPI.OauthAuthorizationCodeResp, err error) { - stateStr, err := url.QueryUnescape(req.State) - if err != nil { - return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) - } - - secret := os.Getenv(encrypt.StateSecretEnv) - if secret == "" { - secret = encrypt.DefaultStateSecret - } - - stateBytes, err := encrypt.DecryptByAES(stateStr, secret) - if err != nil { - return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) - } - - state := &entity.OAuthState{} - err = json.Unmarshal(stateBytes, state) - if err != nil { - return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) - } - - err = p.DomainSVC.OAuthCode(ctx, req.Code, state) - if err != nil { - return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "authorize failed")) - } - - resp = &botOpenAPI.OauthAuthorizationCodeResp{} - - return resp, nil -} - -func (p *PluginApplicationService) GetQueriedOAuthPluginList(ctx context.Context, req *pluginAPI.GetQueriedOAuthPluginListRequest) (resp *pluginAPI.GetQueriedOAuthPluginListResponse, err error) { - userID := ctxutil.GetUIDFromCtx(ctx) - if userID == nil { - return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) - } - - status, err := p.DomainSVC.GetAgentPluginsOAuthStatus(ctx, *userID, req.BotID) - if err != nil { - return nil, errorx.Wrapf(err, "GetAgentPluginsOAuthStatus failed, userID=%d, agentID=%d", *userID, req.BotID) - } - - if len(status) == 0 { - return &pluginAPI.GetQueriedOAuthPluginListResponse{ - OauthPluginList: []*pluginAPI.OAuthPluginInfo{}, - }, nil - } - - oauthPluginList := make([]*pluginAPI.OAuthPluginInfo, 0, len(status)) - for _, s := range status { - oauthPluginList = append(oauthPluginList, &pluginAPI.OAuthPluginInfo{ - PluginID: s.PluginID, - Status: s.Status, - Name: s.PluginName, - PluginIcon: s.PluginIconURL, - }) - } - - resp = &pluginAPI.GetQueriedOAuthPluginListResponse{ - OauthPluginList: oauthPluginList, - } - - return resp, nil -} diff --git a/backend/application/plugin/registration.go b/backend/application/plugin/registration.go new file mode 100644 index 000000000..a44bbb4ad --- /dev/null +++ b/backend/application/plugin/registration.go @@ -0,0 +1,216 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package plugin + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/bytedance/sonic" + "github.com/getkin/kin-openapi/openapi3" + "gopkg.in/yaml.v3" + + pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop" + common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" + "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" + searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + commonConsts "github.com/coze-dev/coze-studio/backend/types/consts" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +func (p *PluginApplicationService) RegisterPluginMeta(ctx context.Context, req *pluginAPI.RegisterPluginMetaRequest) (resp *pluginAPI.RegisterPluginMetaResponse, err error) { + userID := ctxutil.GetUIDFromCtx(ctx) + if userID == nil { + return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) + } + + _authType, ok := convert.ToAuthType(req.GetAuthType()) + if !ok { + return nil, fmt.Errorf("invalid auth type '%d'", req.GetAuthType()) + } + authType := ptr.Of(_authType) + + var authSubType *consts.AuthzSubType + if req.SubAuthType != nil { + _authSubType, ok := convert.ToAuthSubType(req.GetSubAuthType()) + if !ok { + return nil, fmt.Errorf("invalid sub authz type '%d'", req.GetSubAuthType()) + } + authSubType = ptr.Of(_authSubType) + } + + var loc consts.HTTPParamLocation + if *authType == consts.AuthzTypeOfService { + if req.GetLocation() == common.AuthorizationServiceLocation_Query { + loc = consts.ParamInQuery + } else if req.GetLocation() == common.AuthorizationServiceLocation_Header { + loc = consts.ParamInHeader + } else { + return nil, fmt.Errorf("invalid location '%s'", req.GetLocation()) + } + } + + r := &dto.CreateDraftPluginRequest{ + PluginType: req.GetPluginType(), + SpaceID: req.GetSpaceID(), + DeveloperID: *userID, + IconURI: req.Icon.URI, + ProjectID: req.ProjectID, + Name: req.GetName(), + Desc: req.GetDesc(), + ServerURL: req.GetURL(), + CommonParams: req.CommonParams, + AuthInfo: &dto.PluginAuthInfo{ + AuthzType: authType, + Location: ptr.Of(loc), + Key: req.Key, + ServiceToken: req.ServiceToken, + OAuthInfo: req.OauthInfo, + AuthzSubType: authSubType, + AuthzPayload: req.AuthPayload, + }, + } + pluginID, err := p.DomainSVC.CreateDraftPlugin(ctx, r) + if err != nil { + return nil, errorx.Wrapf(err, "CreateDraftPlugin failed") + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Created, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResSubType: ptr.Of(int32(req.GetPluginType())), + ResID: pluginID, + Name: &req.Name, + SpaceID: &req.SpaceID, + APPID: req.ProjectID, + OwnerID: userID, + PublishStatus: ptr.Of(resCommon.PublishStatus_UnPublished), + CreateTimeMS: ptr.Of(time.Now().UnixMilli()), + }, + }) + if err != nil { + return nil, fmt.Errorf("publish resource '%d' failed, err=%v", pluginID, err) + } + + resp = &pluginAPI.RegisterPluginMetaResponse{ + PluginID: pluginID, + } + + return resp, nil +} + +func (p *PluginApplicationService) RegisterPlugin(ctx context.Context, req *pluginAPI.RegisterPluginRequest) (resp *pluginAPI.RegisterPluginResponse, err error) { + userID := ctxutil.GetUIDFromCtx(ctx) + if userID == nil { + return nil, errorx.New(errno.ErrPluginPermissionCode, errorx.KV(errno.PluginMsgKey, "session is required")) + } + + mf := &model.PluginManifest{} + err = sonic.UnmarshalString(req.AiPlugin, &mf) + if err != nil { + return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, err.Error())) + } + + mf.LogoURL = commonConsts.DefaultPluginIcon + + doc, err := openapi3.NewLoader().LoadFromData([]byte(req.Openapi)) + if err != nil { + return nil, errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error())) + } + + res, err := p.DomainSVC.CreateDraftPluginWithCode(ctx, &dto.CreateDraftPluginWithCodeRequest{ + SpaceID: req.GetSpaceID(), + DeveloperID: *userID, + ProjectID: req.ProjectID, + Manifest: mf, + OpenapiDoc: ptr.Of(model.Openapi3T(*doc)), + }) + if err != nil { + return nil, errorx.Wrapf(err, "CreateDraftPluginWithCode failed") + } + + err = p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{ + OpType: searchEntity.Created, + Resource: &searchEntity.ResourceDocument{ + ResType: resCommon.ResType_Plugin, + ResSubType: ptr.Of(int32(res.Plugin.PluginType)), + ResID: res.Plugin.ID, + Name: ptr.Of(res.Plugin.GetName()), + APPID: req.ProjectID, + SpaceID: &req.SpaceID, + OwnerID: userID, + PublishStatus: ptr.Of(resCommon.PublishStatus_UnPublished), + CreateTimeMS: ptr.Of(time.Now().UnixMilli()), + }, + }) + if err != nil { + return nil, fmt.Errorf("publish resource '%d' failed, err=%v", res.Plugin.ID, err) + } + + resp = &pluginAPI.RegisterPluginResponse{ + Data: &common.RegisterPluginData{ + PluginID: res.Plugin.ID, + Openapi: req.Openapi, + }, + } + + return resp, nil +} + +func (p *PluginApplicationService) Convert2OpenAPI(ctx context.Context, req *pluginAPI.Convert2OpenAPIRequest) (resp *pluginAPI.Convert2OpenAPIResponse, err error) { + res := p.DomainSVC.ConvertToOpenapi3Doc(ctx, &dto.ConvertToOpenapi3DocRequest{ + RawInput: req.Data, + PluginServerURL: req.PluginURL, + }) + + if res.ErrMsg != "" { + return &pluginAPI.Convert2OpenAPIResponse{ + Code: errno.ErrPluginInvalidThirdPartyCode, + Msg: res.ErrMsg, + DuplicateAPIInfos: []*common.DuplicateAPIInfo{}, + PluginDataFormat: ptr.Of(res.Format), + }, nil + } + + doc, err := yaml.Marshal(res.OpenapiDoc) + if err != nil { + return nil, fmt.Errorf("marshal openapi doc failed, err=%v", err) + } + mf, err := json.Marshal(res.Manifest) + if err != nil { + return nil, fmt.Errorf("marshal manifest failed, err=%v", err) + } + + resp = &pluginAPI.Convert2OpenAPIResponse{ + PluginDataFormat: ptr.Of(res.Format), + Openapi: ptr.Of(string(doc)), + AiPlugin: ptr.Of(string(mf)), + DuplicateAPIInfos: []*common.DuplicateAPIInfo{}, + } + + return resp, nil +} diff --git a/backend/application/singleagent/duplicate.go b/backend/application/singleagent/duplicate.go index fccd63471..ad1fa6c0f 100644 --- a/backend/application/singleagent/duplicate.go +++ b/backend/application/singleagent/duplicate.go @@ -23,7 +23,6 @@ import ( intelligence "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/common" "github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory" "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" - crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity" @@ -138,8 +137,8 @@ func duplicateVariables(ctx context.Context, appContext *ServiceComponents, oldA return newAgent, nil } -func duplicatePlugin(ctx context.Context, _ *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) { - err := crossplugin.DefaultSVC().DuplicateDraftAgentTools(ctx, oldAgent.AgentID, newAgent.AgentID) +func duplicatePlugin(ctx context.Context, appContext *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) { + err := appContext.PluginDomainSVC.DuplicateDraftAgentTools(ctx, oldAgent.AgentID, newAgent.AgentID) if err != nil { return nil, err } diff --git a/backend/application/singleagent/get.go b/backend/application/singleagent/get.go index 12724dad0..30aa92594 100644 --- a/backend/application/singleagent/get.go +++ b/backend/application/singleagent/get.go @@ -27,9 +27,9 @@ import ( workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" "github.com/coze-dev/coze-studio/backend/api/model/playground" "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - plugin_develop_common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" "github.com/coze-dev/coze-studio/backend/api/model/workflow" - plugindto "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service" pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" @@ -200,12 +200,12 @@ func (s *SingleAgentApplicationService) fetchKnowledgeDetails(ctx context.Contex } func (s *SingleAgentApplicationService) fetchToolDetails(ctx context.Context, agentInfo *entity.SingleAgent, req *playground.GetDraftBotInfoAgwRequest) ([]*pluginEntity.ToolInfo, error) { - return s.appContext.PluginDomainSVC.MGetAgentTools(ctx, &plugindto.MGetAgentToolsRequest{ + return s.appContext.PluginDomainSVC.MGetAgentTools(ctx, &model.MGetAgentToolsRequest{ SpaceID: agentInfo.SpaceID, AgentID: req.GetBotID(), IsDraft: true, - VersionAgentTools: slices.Transform(agentInfo.Plugin, func(a *bot_common.PluginInfo) pluginEntity.VersionAgentTool { - return pluginEntity.VersionAgentTool{ + VersionAgentTools: slices.Transform(agentInfo.Plugin, func(a *bot_common.PluginInfo) model.VersionAgentTool { + return model.VersionAgentTool{ ToolID: a.GetApiId(), } }), @@ -213,7 +213,7 @@ func (s *SingleAgentApplicationService) fetchToolDetails(ctx context.Context, ag } func (s *SingleAgentApplicationService) fetchPluginDetails(ctx context.Context, agentInfo *entity.SingleAgent, toolInfos []*pluginEntity.ToolInfo) ([]*pluginEntity.PluginInfo, error) { - vPlugins := make([]pluginEntity.VersionPlugin, 0, len(agentInfo.Plugin)) + vPlugins := make([]model.VersionPlugin, 0, len(agentInfo.Plugin)) vPluginMap := make(map[string]bool, len(agentInfo.Plugin)) for _, v := range toolInfos { k := fmt.Sprintf("%d:%s", v.PluginID, v.GetVersion()) @@ -221,7 +221,7 @@ func (s *SingleAgentApplicationService) fetchPluginDetails(ctx context.Context, continue } vPluginMap[k] = true - vPlugins = append(vPlugins, pluginEntity.VersionPlugin{ + vPlugins = append(vPlugins, model.VersionPlugin{ PluginID: v.PluginID, Version: v.GetVersion(), }) @@ -330,7 +330,7 @@ func (s *SingleAgentApplicationService) pluginInfoDo2Vo(ctx context.Context, plu }) } -func parametersDo2Vo(op *plugindto.Openapi3Operation) []*playground.PluginParameter { +func parametersDo2Vo(op *model.Openapi3Operation) []*playground.PluginParameter { var convertReqBody func(paramName string, isRequired bool, sc *openapi3.Schema) *playground.PluginParameter convertReqBody = func(paramName string, isRequired bool, sc *openapi3.Schema) *playground.PluginParameter { if disabledParam(sc) { @@ -338,7 +338,7 @@ func parametersDo2Vo(op *plugindto.Openapi3Operation) []*playground.PluginParame } var assistType *int64 - if v, ok := sc.Extensions[plugindto.APISchemaExtendAssistType]; ok { + if v, ok := sc.Extensions[consts.APISchemaExtendAssistType]; ok { if _v, ok := v.(string); ok { assistType = toParameterAssistType(_v) } @@ -409,7 +409,7 @@ func parametersDo2Vo(op *plugindto.Openapi3Operation) []*playground.PluginParame } var assistType *int64 - if v, ok := schemaVal.Extensions[plugindto.APISchemaExtendAssistType]; ok { + if v, ok := schemaVal.Extensions[consts.APISchemaExtendAssistType]; ok { if _v, ok := v.(string); ok { assistType = toParameterAssistType(_v) } @@ -455,27 +455,27 @@ func toParameterAssistType(assistType string) *int64 { if assistType == "" { return nil } - switch plugindto.APIFileAssistType(assistType) { - case plugindto.AssistTypeFile: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_CODE)) - case plugindto.AssistTypeImage: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_IMAGE)) - case plugindto.AssistTypeDoc: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_DOC)) - case plugindto.AssistTypePPT: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_PPT)) - case plugindto.AssistTypeCode: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_CODE)) - case plugindto.AssistTypeExcel: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_EXCEL)) - case plugindto.AssistTypeZIP: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_ZIP)) - case plugindto.AssistTypeVideo: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_VIDEO)) - case plugindto.AssistTypeAudio: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_AUDIO)) - case plugindto.AssistTypeTXT: - return ptr.Of(int64(plugin_develop_common.AssistParameterType_TXT)) + switch consts.APIFileAssistType(assistType) { + case consts.AssistTypeFile: + return ptr.Of(int64(common.AssistParameterType_CODE)) + case consts.AssistTypeImage: + return ptr.Of(int64(common.AssistParameterType_IMAGE)) + case consts.AssistTypeDoc: + return ptr.Of(int64(common.AssistParameterType_DOC)) + case consts.AssistTypePPT: + return ptr.Of(int64(common.AssistParameterType_PPT)) + case consts.AssistTypeCode: + return ptr.Of(int64(common.AssistParameterType_CODE)) + case consts.AssistTypeExcel: + return ptr.Of(int64(common.AssistParameterType_EXCEL)) + case consts.AssistTypeZIP: + return ptr.Of(int64(common.AssistParameterType_ZIP)) + case consts.AssistTypeVideo: + return ptr.Of(int64(common.AssistParameterType_VIDEO)) + case consts.AssistTypeAudio: + return ptr.Of(int64(common.AssistParameterType_AUDIO)) + case consts.AssistTypeTXT: + return ptr.Of(int64(common.AssistParameterType_TXT)) default: return nil } diff --git a/backend/application/singleagent/publish.go b/backend/application/singleagent/publish.go index bc51113fb..0f0fd0393 100644 --- a/backend/application/singleagent/publish.go +++ b/backend/application/singleagent/publish.go @@ -235,7 +235,7 @@ func publishAgentPlugins(ctx context.Context, appContext *ServiceComponents, pub func publishShortcutCommand(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) { logs.CtxInfof(ctx, "publishShortcutCommand agentID: %d, shortcutCommand: %v", agent.AgentID, agent.ShortcutCommand) - if agent.ShortcutCommand == nil || len(agent.ShortcutCommand) == 0 { + if len(agent.ShortcutCommand) == 0 { return agent, nil } cmdIDs := slices.Transform(agent.ShortcutCommand, func(a string) int64 { diff --git a/backend/application/singleagent/single_agent.go b/backend/application/singleagent/single_agent.go index 5ab364647..7f43b1bbd 100644 --- a/backend/application/singleagent/single_agent.go +++ b/backend/application/singleagent/single_agent.go @@ -22,10 +22,6 @@ import ( "strconv" "time" - shortcutCmd "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service" - "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" - "github.com/coze-dev/coze-studio/backend/types/consts" - "github.com/bytedance/sonic" "github.com/getkin/kin-openapi/openapi3" @@ -39,31 +35,33 @@ import ( "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent" crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginConsts "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service" variableEntity "github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity" - shortcutEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity" - searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" + shortcutEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity" + shortcutCmd "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/errno" ) type SingleAgentApplicationService struct { appContext *ServiceComponents DomainSVC singleagent.SingleAgent - ShortcutCMDSVC shortcutCmd.ShortcutCmd + ShortcutCmdSvc shortcutCmd.ShortcutCmd } func newApplicationService(s *ServiceComponents, domain singleagent.SingleAgent) *SingleAgentApplicationService { return &SingleAgentApplicationService{ appContext: s, DomainSVC: domain, - ShortcutCMDSVC: s.ShortcutCMDDomainSVC, + ShortcutCmdSvc: s.ShortcutCMDDomainSVC, } } @@ -158,7 +156,7 @@ func (s *SingleAgentApplicationService) UpdatePromptDisable(ctx context.Context, } if len(draft.Database) == 0 { - return nil, fmt.Errorf("agent %d has no database", agentID) // TODO (@fanlv): error code + return nil, fmt.Errorf("agent %d has no database", agentID) } dbInfos := draft.Database @@ -172,7 +170,7 @@ func (s *SingleAgentApplicationService) UpdatePromptDisable(ctx context.Context, } if !found { - return nil, fmt.Errorf("database %d not found in agent %d", req.GetDatabaseID(), agentID) // TODO (@fanlv): error code + return nil, fmt.Errorf("database %d not found in agent %d", req.GetDatabaseID(), agentID) } draft.Database = dbInfos @@ -355,7 +353,7 @@ func (s *SingleAgentApplicationService) applyAgentUpdates(target *entity.SingleA target.BackgroundImageInfoList = patch.BackgroundImageInfoList } - if patch.Agents != nil && len(patch.Agents) > 0 && patch.Agents[0].JumpConfig != nil { + if len(patch.Agents) > 0 && patch.Agents[0].JumpConfig != nil { target.JumpConfig = patch.Agents[0].JumpConfig } @@ -469,10 +467,10 @@ func disabledParam(schemaVal *openapi3.Schema) bool { return false } globalDisable, localDisable := false, false - if v, ok := schemaVal.Extensions[plugin.APISchemaExtendLocalDisable]; ok { + if v, ok := schemaVal.Extensions[pluginConsts.APISchemaExtendLocalDisable]; ok { localDisable = v.(bool) } - if v, ok := schemaVal.Extensions[plugin.APISchemaExtendGlobalDisable]; ok { + if v, ok := schemaVal.Extensions[pluginConsts.APISchemaExtendGlobalDisable]; ok { globalDisable = v.(bool) } return globalDisable || localDisable @@ -614,7 +612,7 @@ func (s *SingleAgentApplicationService) ListAgentPublishHistory(ctx context.Cont Name: creator.Name, AvatarURL: creator.IconURL, Self: uid == v.CreatorID, - // UserUniqueName: creator. UserUniqueName,//TODO (@fanlv): Change the user domain after it is completed + // UserUniqueName: creator. UserUniqueName, // UserLabel TODO }, PublishID: &v.PublishID, @@ -701,7 +699,7 @@ func (s *SingleAgentApplicationService) getAgentInfo(ctx context.Context, botID } if len(agentInfo.ShortcutCommand) > 0 { - shortcutInfos, err := s.ShortcutCMDSVC.ListCMD(ctx, &shortcutEntity.ListMeta{ + shortcutInfos, err := s.ShortcutCmdSvc.ListCMD(ctx, &shortcutEntity.ListMeta{ ObjectID: agentInfo.AgentID, IsOnline: 1, CommandIDs: slices.Transform(agentInfo.ShortcutCommand, func(s string) int64 { diff --git a/backend/application/workflow/init.go b/backend/application/workflow/init.go index b6107dbd2..650aaf0b7 100644 --- a/backend/application/workflow/init.go +++ b/backend/application/workflow/init.go @@ -34,6 +34,8 @@ import ( dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service" variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service" plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service" + wrapPlugin "github.com/coze-dev/coze-studio/backend/domain/workflow/plugin" + search "github.com/coze-dev/coze-studio/backend/domain/search/service" "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/config" @@ -97,6 +99,7 @@ func InitService(_ context.Context, components *ServiceComponents) (*Application workflow.SetRepository(workflowRepo) workflowDomainSVC := service.NewWorkflowService(workflowRepo) + wrapPlugin.SetOSS(components.Tos) code.SetCodeRunner(components.CodeRunner) callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler()) diff --git a/backend/application/workflow/workflow.go b/backend/application/workflow/workflow.go index 47f822e87..de0c4c3e9 100644 --- a/backend/application/workflow/workflow.go +++ b/backend/application/workflow/workflow.go @@ -45,14 +45,14 @@ import ( appplugin "github.com/coze-dev/coze-studio/backend/application/plugin" "github.com/coze-dev/coze-studio/backend/application/user" crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge" - crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - plugindto "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginConsts "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user" "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" search "github.com/coze-dev/coze-studio/backend/domain/search/entity" domainWorkflow "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/plugin" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" "github.com/coze-dev/coze-studio/backend/infra/contract/imagex" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" @@ -976,7 +976,7 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w return 0, nil, err } - pluginMap := make(map[int64]*plugindto.PluginEntity) + pluginMap := make(map[int64]*vo.PluginEntity) pluginToolMap := make(map[int64]int64) if len(ds.PluginIDs) > 0 { @@ -985,13 +985,13 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w response, err := appplugin.PluginApplicationSVC.CopyPlugin(ctx, &dto.CopyPluginRequest{ PluginID: id, UserID: ctxutil.MustGetUIDFromCtx(ctx), - CopyScene: plugindto.CopySceneOfToLibrary, + CopyScene: pluginConsts.CopySceneOfToLibrary, }) if err != nil { return 0, nil, err } pInfo := response.Plugin - pluginMap[id] = &plugindto.PluginEntity{ + pluginMap[id] = &vo.PluginEntity{ PluginID: pInfo.ID, PluginVersion: pInfo.Version, } @@ -1109,9 +1109,9 @@ func (w *ApplicationService) DuplicateWorkflowsByAppID(ctx context.Context, sour } }() - pluginMap := make(map[int64]*plugindto.PluginEntity) + pluginMap := make(map[int64]*vo.PluginEntity) for o, n := range externalResource.PluginMap { - pluginMap[o] = &plugindto.PluginEntity{ + pluginMap[o] = &vo.PluginEntity{ PluginID: n, } } @@ -1209,7 +1209,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w return 0, nil, err } - pluginMap := make(map[int64]*plugindto.PluginEntity) + pluginMap := make(map[int64]*vo.PluginEntity) if len(ds.PluginIDs) > 0 { for idx := range ds.PluginIDs { id := ds.PluginIDs[idx] @@ -1217,7 +1217,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w if err != nil { return 0, nil, err } - pluginMap[id] = &plugindto.PluginEntity{ + pluginMap[id] = &vo.PluginEntity{ PluginID: pInfo.ID, PluginVersion: pInfo.Version, } @@ -2650,8 +2650,8 @@ func (w *ApplicationService) GetApiDetail(ctx context.Context, req *workflow.Get return nil, err } - toolInfoResponse, err := crossplugin.DefaultSVC().GetPluginToolsInfo(ctx, &plugindto.ToolsInfoRequest{ - PluginEntity: plugindto.PluginEntity{ + toolInfoResponse, err := plugin.GetPluginToolsInfo(ctx, &plugin.ToolsInfoRequest{ + PluginEntity: vo.PluginEntity{ PluginID: pluginID, PluginVersion: req.PluginVersion, }, @@ -2717,8 +2717,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req } var ( - pluginSvc = crossplugin.DefaultSVC() - pluginToolsInfoReqs = make(map[int64]*plugindto.ToolsInfoRequest) + pluginToolsInfoReqs = make(map[int64]*plugin.ToolsInfoRequest) pluginDetailMap = make(map[string]*workflow.PluginDetail) toolsDetailInfo = make(map[string]*workflow.APIDetail) workflowDetailMap = make(map[string]*workflow.WorkflowDetail) @@ -2740,8 +2739,8 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req if r, ok := pluginToolsInfoReqs[pluginID]; ok { r.ToolIDs = append(r.ToolIDs, toolID) } else { - pluginToolsInfoReqs[pluginID] = &plugindto.ToolsInfoRequest{ - PluginEntity: plugindto.PluginEntity{ + pluginToolsInfoReqs[pluginID] = &plugin.ToolsInfoRequest{ + PluginEntity: vo.PluginEntity{ PluginID: pluginID, PluginVersion: pl.PluginVersion, }, @@ -2752,7 +2751,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req } for _, r := range pluginToolsInfoReqs { - resp, err := pluginSvc.GetPluginToolsInfo(ctx, r) + resp, err := plugin.GetPluginToolsInfo(ctx, r) if err != nil { return nil, err } @@ -2925,7 +2924,6 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req var fcPluginSetting *workflow.FCPluginSetting if req.GetPluginFcSetting() != nil { var ( - pluginSvc = crossplugin.DefaultSVC() pluginFcSetting = req.GetPluginFcSetting() isDraft = pluginFcSetting.GetIsDraft() ) @@ -2940,15 +2938,15 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req return nil, err } - pluginReq := &plugindto.ToolsInfoRequest{ - PluginEntity: plugindto.PluginEntity{ + pluginReq := &plugin.ToolsInfoRequest{ + PluginEntity: vo.PluginEntity{ PluginID: pluginID, }, ToolIDs: []int64{toolID}, IsDraft: isDraft, } - pInfo, err := pluginSvc.GetPluginToolsInfo(ctx, pluginReq) + pInfo, err := plugin.GetPluginToolsInfo(ctx, pluginReq) if err != nil { return nil, err } diff --git a/backend/crossdomain/contract/plugin/dto/consts.go b/backend/crossdomain/contract/plugin/consts/consts.go similarity index 92% rename from backend/crossdomain/contract/plugin/dto/consts.go rename to backend/crossdomain/contract/plugin/consts/consts.go index dd1fa7f75..658add7df 100644 --- a/backend/crossdomain/contract/plugin/dto/consts.go +++ b/backend/crossdomain/contract/plugin/consts/consts.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package dto +package consts import "github.com/getkin/kin-openapi/openapi3" @@ -127,3 +127,12 @@ type InterruptEventType string const ( InterruptEventTypeOfToolNeedOAuth InterruptEventType = "tool_need_oauth" ) + +// MIME Type +const ( + MediaTypeJson = "application/json" + MediaTypeProblemJson = "application/problem+json" + MediaTypeFormURLEncoded = "application/x-www-form-urlencoded" + MediaTypeXYaml = "application/x-yaml" + MediaTypeYaml = "application/yaml" +) diff --git a/backend/application/base/pluginutil/api.go b/backend/crossdomain/contract/plugin/convert/api/api.go similarity index 83% rename from backend/application/base/pluginutil/api.go rename to backend/crossdomain/contract/plugin/convert/api/api.go index d39152d9d..fb0f344e2 100644 --- a/backend/application/base/pluginutil/api.go +++ b/backend/crossdomain/contract/plugin/convert/api/api.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package pluginutil +package api import ( "net/http" @@ -22,10 +22,10 @@ import ( "github.com/getkin/kin-openapi/openapi3" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" - common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -56,7 +56,7 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) ( var mType *openapi3.MediaType if hasSetReqBody { - mType = op.RequestBody.Value.Content[plugin.MediaTypeJson] + mType = op.RequestBody.Value.Content[consts.MediaTypeJson] } else { hasSetReqBody = true mType = &openapi3.MediaType{ @@ -70,7 +70,7 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) ( op.RequestBody = &openapi3.RequestBodyRef{ Value: &openapi3.RequestBody{ Content: map[string]*openapi3.MediaType{ - plugin.MediaTypeJson: mType, + consts.MediaTypeJson: mType, }, }, } @@ -94,7 +94,7 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) ( op.Parameters = []*openapi3.ParameterRef{} } if !hasSetReqBody { - op.RequestBody = entity.DefaultOpenapi3RequestBody() + op.RequestBody = model.DefaultOpenapi3RequestBody() } } @@ -107,7 +107,7 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) ( strconv.Itoa(http.StatusOK): { Value: &openapi3.Response{ Content: map[string]*openapi3.MediaType{ - plugin.MediaTypeJson: { + consts.MediaTypeJson: { Schema: &openapi3.SchemaRef{ Value: &openapi3.Schema{ Type: openapi3.TypeObject, @@ -127,7 +127,7 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) ( } resp, _ := op.Responses[strconv.Itoa(http.StatusOK)] - mType, _ := resp.Value.Content[plugin.MediaTypeJson] // only support application/json + mType, _ := resp.Value.Content[consts.MediaTypeJson] // only support application/json mType.Schema.Value.Properties[apiParam.Name] = &openapi3.SchemaRef{ Value: _apiParam, } @@ -138,14 +138,14 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) ( } if respParams != nil && !hasSetRespBody { - op.Responses = entity.DefaultOpenapi3Responses() + op.Responses = model.DefaultOpenapi3Responses() } return op, nil } func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, error) { - paramType, ok := plugin.ToOpenapiParamType(apiParam.Type) + paramType, ok := convert.ToOpenapiParamType(apiParam.Type) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the type '%s' of field '%s' is invalid", apiParam.Type, apiParam.Name)) @@ -160,7 +160,7 @@ func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, err Type: paramType, Default: apiParam.GlobalDefault, Extensions: map[string]interface{}{ - plugin.APISchemaExtendGlobalDisable: apiParam.GlobalDisable, + consts.APISchemaExtendGlobalDisable: apiParam.GlobalDisable, }, } @@ -175,7 +175,7 @@ func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, err } arrayItem := apiParam.SubParameters[0] - arrayItemType, ok := plugin.ToOpenapiParamType(arrayItem.Type) + arrayItemType, ok := convert.ToOpenapiParamType(arrayItem.Type) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the item type '%s' of field '%s' is invalid", arrayItemType, apiParam.Name)) @@ -193,13 +193,13 @@ func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, err } if arrayItem.GetAssistType() > 0 { - aType, ok := plugin.ToAPIAssistType(arrayItem.GetAssistType()) + aType, ok := convert.ToAPIAssistType(arrayItem.GetAssistType()) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", arrayItem.GetAssistType(), apiParam.Name)) } - itemSchema.Extensions[plugin.APISchemaExtendAssistType] = aType - format, ok := plugin.AssistTypeToFormat(aType) + itemSchema.Extensions[consts.APISchemaExtendAssistType] = aType + format, ok := convert.AssistTypeToFormat(aType) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name)) @@ -216,20 +216,20 @@ func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, err paramSchema.Default = apiParam.LocalDefault } if apiParam.LocalDisable { - paramSchema.Extensions[plugin.APISchemaExtendLocalDisable] = true + paramSchema.Extensions[consts.APISchemaExtendLocalDisable] = true } if apiParam.VariableRef != nil && *apiParam.VariableRef != "" { - paramSchema.Extensions[plugin.APISchemaExtendVariableRef] = apiParam.VariableRef + paramSchema.Extensions[consts.APISchemaExtendVariableRef] = apiParam.VariableRef } if apiParam.GetAssistType() > 0 { - aType, ok := plugin.ToAPIAssistType(apiParam.GetAssistType()) + aType, ok := convert.ToAPIAssistType(apiParam.GetAssistType()) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", apiParam.GetAssistType(), apiParam.Name)) } - paramSchema.Extensions[plugin.APISchemaExtendAssistType] = aType - format, ok := plugin.AssistTypeToFormat(aType) + paramSchema.Extensions[consts.APISchemaExtendAssistType] = aType + format, ok := convert.AssistTypeToFormat(aType) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name)) @@ -237,7 +237,7 @@ func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, err paramSchema.Format = format } - loc, ok := plugin.ToHTTPParamLocation(apiParam.Location) + loc, ok := convert.ToHTTPParamLocation(apiParam.Location) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the location '%s' of field '%s' is invalid ", apiParam.Location, apiParam.Name)) @@ -257,7 +257,7 @@ func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, err } func toOpenapi3Schema(apiParam *common.APIParameter) (*openapi3.Schema, error) { - paramType, ok := plugin.ToOpenapiParamType(apiParam.Type) + paramType, ok := convert.ToOpenapiParamType(apiParam.Type) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the type '%s' of field '%s' is invalid", apiParam.Type, apiParam.Name)) @@ -268,27 +268,27 @@ func toOpenapi3Schema(apiParam *common.APIParameter) (*openapi3.Schema, error) { Type: paramType, Default: apiParam.GlobalDefault, Extensions: map[string]interface{}{ - plugin.APISchemaExtendGlobalDisable: apiParam.GlobalDisable, + consts.APISchemaExtendGlobalDisable: apiParam.GlobalDisable, }, } if apiParam.LocalDefault != nil && *apiParam.LocalDefault != "" { sc.Default = apiParam.LocalDefault } if apiParam.LocalDisable { - sc.Extensions[plugin.APISchemaExtendLocalDisable] = true + sc.Extensions[consts.APISchemaExtendLocalDisable] = true } if apiParam.VariableRef != nil && *apiParam.VariableRef != "" { - sc.Extensions[plugin.APISchemaExtendVariableRef] = apiParam.VariableRef + sc.Extensions[consts.APISchemaExtendVariableRef] = apiParam.VariableRef } if apiParam.GetAssistType() > 0 { - aType, ok := plugin.ToAPIAssistType(apiParam.GetAssistType()) + aType, ok := convert.ToAPIAssistType(apiParam.GetAssistType()) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", apiParam.GetAssistType(), apiParam.Name)) } - sc.Extensions[plugin.APISchemaExtendAssistType] = aType - format, ok := plugin.AssistTypeToFormat(aType) + sc.Extensions[consts.APISchemaExtendAssistType] = aType + format, ok := convert.AssistTypeToFormat(aType) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name)) @@ -321,7 +321,7 @@ func toOpenapi3Schema(apiParam *common.APIParameter) (*openapi3.Schema, error) { } arrayItem := apiParam.SubParameters[0] - itemType, ok := plugin.ToOpenapiParamType(arrayItem.Type) + itemType, ok := convert.ToOpenapiParamType(arrayItem.Type) if !ok { return nil, errorx.New(errno.ErrPluginInvalidParamCode, errorx.KVf(errno.PluginMsgKey, "the item type '%s' of field '%s' is invalid", itemType, apiParam.Name)) diff --git a/backend/crossdomain/contract/plugin/convert/auth.go b/backend/crossdomain/contract/plugin/convert/auth.go new file mode 100644 index 000000000..b57eb5830 --- /dev/null +++ b/backend/crossdomain/contract/plugin/convert/auth.go @@ -0,0 +1,98 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package convert + +import ( + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" +) + +var authTypes = map[common.AuthorizationType]consts.AuthzType{ + common.AuthorizationType_None: consts.AuthzTypeOfNone, + common.AuthorizationType_Service: consts.AuthzTypeOfService, + common.AuthorizationType_OAuth: consts.AuthzTypeOfOAuth, + common.AuthorizationType_Standard: consts.AuthzTypeOfOAuth, // deprecated, the same as OAuth +} + +func ToAuthType(typ common.AuthorizationType) (consts.AuthzType, bool) { + _type, ok := authTypes[typ] + return _type, ok +} + +var thriftAuthTypes = func() map[consts.AuthzType]common.AuthorizationType { + types := make(map[consts.AuthzType]common.AuthorizationType, len(authTypes)) + for k, v := range authTypes { + if v == consts.AuthzTypeOfOAuth { + types[v] = common.AuthorizationType_OAuth + } else { + types[v] = k + } + } + return types +}() + +func ToThriftAuthType(typ consts.AuthzType) (common.AuthorizationType, bool) { + _type, ok := thriftAuthTypes[typ] + return _type, ok +} + +var subAuthTypes = map[int32]consts.AuthzSubType{ + int32(common.ServiceAuthSubType_ApiKey): consts.AuthzSubTypeOfServiceAPIToken, + int32(common.ServiceAuthSubType_OAuthAuthorizationCode): consts.AuthzSubTypeOfOAuthAuthorizationCode, +} + +func ToAuthSubType(typ int32) (consts.AuthzSubType, bool) { + _type, ok := subAuthTypes[typ] + return _type, ok +} + +var thriftSubAuthTypes = func() map[consts.AuthzSubType]int32 { + types := make(map[consts.AuthzSubType]int32, len(subAuthTypes)) + for k, v := range subAuthTypes { + types[v] = int32(k) + } + return types +}() + +func ToThriftAuthSubType(typ consts.AuthzSubType) (int32, bool) { + _type, ok := thriftSubAuthTypes[typ] + return _type, ok +} + +var apiAuthModes = map[common.PluginToolAuthType]consts.ToolAuthMode{ + common.PluginToolAuthType_Required: consts.ToolAuthModeOfRequired, + common.PluginToolAuthType_Supported: consts.ToolAuthModeOfSupported, + common.PluginToolAuthType_Disable: consts.ToolAuthModeOfDisabled, +} + +func ToAPIAuthMode(mode common.PluginToolAuthType) (consts.ToolAuthMode, bool) { + _mode, ok := apiAuthModes[mode] + return _mode, ok +} + +var thriftAPIAuthModes = func() map[consts.ToolAuthMode]common.PluginToolAuthType { + modes := make(map[consts.ToolAuthMode]common.PluginToolAuthType, len(apiAuthModes)) + for k, v := range apiAuthModes { + modes[v] = k + } + return modes +}() + +func ToThriftAPIAuthMode(mode consts.ToolAuthMode) (common.PluginToolAuthType, bool) { + _mode, ok := thriftAPIAuthModes[mode] + return _mode, ok +} diff --git a/backend/crossdomain/contract/plugin/convert/format.go b/backend/crossdomain/contract/plugin/convert/format.go new file mode 100644 index 000000000..41483613b --- /dev/null +++ b/backend/crossdomain/contract/plugin/convert/format.go @@ -0,0 +1,71 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package convert + +import ( + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" +) + +var assistTypeToFormat = map[consts.APIFileAssistType]string{ + consts.AssistTypeFile: "file_url", + consts.AssistTypeImage: "image_url", + consts.AssistTypeDoc: "doc_url", + consts.AssistTypePPT: "ppt_url", + consts.AssistTypeCode: "code_url", + consts.AssistTypeExcel: "excel_url", + consts.AssistTypeZIP: "zip_url", + consts.AssistTypeVideo: "video_url", + consts.AssistTypeAudio: "audio_url", + consts.AssistTypeTXT: "txt_url", +} + +func AssistTypeToFormat(typ consts.APIFileAssistType) (string, bool) { + format, ok := assistTypeToFormat[typ] + return format, ok +} + +var formatToAssistType = func() map[string]consts.APIFileAssistType { + types := make(map[string]consts.APIFileAssistType, len(assistTypeToFormat)) + for k, v := range assistTypeToFormat { + types[v] = k + } + return types +}() + +func FormatToAssistType(format string) (consts.APIFileAssistType, bool) { + typ, ok := formatToAssistType[format] + return typ, ok +} + +var assistTypeToThriftFormat = map[consts.APIFileAssistType]common.PluginParamTypeFormat{ + consts.AssistTypeFile: common.PluginParamTypeFormat_FileUrl, + consts.AssistTypeImage: common.PluginParamTypeFormat_ImageUrl, + consts.AssistTypeDoc: common.PluginParamTypeFormat_DocUrl, + consts.AssistTypePPT: common.PluginParamTypeFormat_PptUrl, + consts.AssistTypeCode: common.PluginParamTypeFormat_CodeUrl, + consts.AssistTypeExcel: common.PluginParamTypeFormat_ExcelUrl, + consts.AssistTypeZIP: common.PluginParamTypeFormat_ZipUrl, + consts.AssistTypeVideo: common.PluginParamTypeFormat_VideoUrl, + consts.AssistTypeAudio: common.PluginParamTypeFormat_AudioUrl, + consts.AssistTypeTXT: common.PluginParamTypeFormat_TxtUrl, +} + +func AssistTypeToThriftFormat(typ consts.APIFileAssistType) (common.PluginParamTypeFormat, bool) { + format, ok := assistTypeToThriftFormat[typ] + return format, ok +} diff --git a/backend/crossdomain/contract/plugin/convert/http.go b/backend/crossdomain/contract/plugin/convert/http.go new file mode 100644 index 000000000..f2dbc877b --- /dev/null +++ b/backend/crossdomain/contract/plugin/convert/http.go @@ -0,0 +1,49 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package convert + +import ( + "net/http" + + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" +) + +var httpMethods = map[common.APIMethod]string{ + common.APIMethod_GET: http.MethodGet, + common.APIMethod_POST: http.MethodPost, + common.APIMethod_PUT: http.MethodPut, + common.APIMethod_DELETE: http.MethodDelete, + common.APIMethod_PATCH: http.MethodPatch, +} + +var thriftAPIMethods = func() map[string]common.APIMethod { + methods := make(map[string]common.APIMethod, len(httpMethods)) + for k, v := range httpMethods { + methods[v] = k + } + return methods +}() + +func ToThriftAPIMethod(method string) (common.APIMethod, bool) { + _method, ok := thriftAPIMethods[method] + return _method, ok +} + +func ToHTTPMethod(method common.APIMethod) (string, bool) { + _method, ok := httpMethods[method] + return _method, ok +} diff --git a/backend/crossdomain/contract/plugin/convert/param.go b/backend/crossdomain/contract/plugin/convert/param.go new file mode 100644 index 000000000..fedc39dc0 --- /dev/null +++ b/backend/crossdomain/contract/plugin/convert/param.go @@ -0,0 +1,112 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package convert + +import ( + "github.com/getkin/kin-openapi/openapi3" + + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" +) + +var httpParamLocations = map[common.ParameterLocation]consts.HTTPParamLocation{ + common.ParameterLocation_Path: consts.ParamInPath, + common.ParameterLocation_Query: consts.ParamInQuery, + common.ParameterLocation_Body: consts.ParamInBody, + common.ParameterLocation_Header: consts.ParamInHeader, +} + +func ToHTTPParamLocation(loc common.ParameterLocation) (consts.HTTPParamLocation, bool) { + _loc, ok := httpParamLocations[loc] + return _loc, ok +} + +var thriftHTTPParamLocations = func() map[consts.HTTPParamLocation]common.ParameterLocation { + locations := make(map[consts.HTTPParamLocation]common.ParameterLocation, len(httpParamLocations)) + for k, v := range httpParamLocations { + locations[v] = k + } + return locations +}() + +func ToThriftHTTPParamLocation(loc consts.HTTPParamLocation) (common.ParameterLocation, bool) { + _loc, ok := thriftHTTPParamLocations[loc] + return _loc, ok +} + +var openapiTypes = map[common.ParameterType]string{ + common.ParameterType_String: openapi3.TypeString, + common.ParameterType_Integer: openapi3.TypeInteger, + common.ParameterType_Number: openapi3.TypeNumber, + common.ParameterType_Object: openapi3.TypeObject, + common.ParameterType_Array: openapi3.TypeArray, + common.ParameterType_Bool: openapi3.TypeBoolean, +} + +func ToOpenapiParamType(typ common.ParameterType) (string, bool) { + _typ, ok := openapiTypes[typ] + return _typ, ok +} + +var thriftParameterTypes = func() map[string]common.ParameterType { + types := make(map[string]common.ParameterType, len(openapiTypes)) + for k, v := range openapiTypes { + types[v] = k + } + return types +}() + +func ToThriftParamType(typ string) (common.ParameterType, bool) { + _typ, ok := thriftParameterTypes[typ] + return _typ, ok +} + +var apiAssistTypes = map[common.AssistParameterType]consts.APIFileAssistType{ + common.AssistParameterType_DEFAULT: consts.AssistTypeFile, + common.AssistParameterType_IMAGE: consts.AssistTypeImage, + common.AssistParameterType_DOC: consts.AssistTypeDoc, + common.AssistParameterType_PPT: consts.AssistTypePPT, + common.AssistParameterType_CODE: consts.AssistTypeCode, + common.AssistParameterType_EXCEL: consts.AssistTypeExcel, + common.AssistParameterType_ZIP: consts.AssistTypeZIP, + common.AssistParameterType_VIDEO: consts.AssistTypeVideo, + common.AssistParameterType_AUDIO: consts.AssistTypeAudio, + common.AssistParameterType_TXT: consts.AssistTypeTXT, +} + +func ToAPIAssistType(typ common.AssistParameterType) (consts.APIFileAssistType, bool) { + _typ, ok := apiAssistTypes[typ] + return _typ, ok +} + +var thriftAPIAssistTypes = func() map[consts.APIFileAssistType]common.AssistParameterType { + types := make(map[consts.APIFileAssistType]common.AssistParameterType, len(apiAssistTypes)) + for k, v := range apiAssistTypes { + types[v] = k + } + return types +}() + +func ToThriftAPIAssistType(typ consts.APIFileAssistType) (common.AssistParameterType, bool) { + _typ, ok := thriftAPIAssistTypes[typ] + return _typ, ok +} + +func IsValidAPIAssistType(typ consts.APIFileAssistType) bool { + _, ok := thriftAPIAssistTypes[typ] + return ok +} diff --git a/backend/crossdomain/contract/plugin/convert/plugin.go b/backend/crossdomain/contract/plugin/convert/plugin.go new file mode 100644 index 000000000..76d2556fa --- /dev/null +++ b/backend/crossdomain/contract/plugin/convert/plugin.go @@ -0,0 +1,44 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package convert + +import ( + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" +) + +var pluginTypes = map[common.PluginType]consts.PluginType{ + common.PluginType_PLUGIN: consts.PluginTypeOfCloud, +} + +func ToPluginType(typ common.PluginType) (consts.PluginType, bool) { + _type, ok := pluginTypes[typ] + return _type, ok +} + +var thriftPluginTypes = func() map[consts.PluginType]common.PluginType { + types := make(map[consts.PluginType]common.PluginType, len(pluginTypes)) + for k, v := range pluginTypes { + types[v] = k + } + return types +}() + +func ToThriftPluginType(typ consts.PluginType) (common.PluginType, bool) { + _type, ok := thriftPluginTypes[typ] + return _type, ok +} diff --git a/backend/crossdomain/contract/plugin/dto/convert.go b/backend/crossdomain/contract/plugin/dto/convert.go deleted file mode 100644 index 7e1597da1..000000000 --- a/backend/crossdomain/contract/plugin/dto/convert.go +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dto - -import ( - "net/http" - - "github.com/getkin/kin-openapi/openapi3" - - common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" -) - -var httpParamLocations = map[common.ParameterLocation]HTTPParamLocation{ - common.ParameterLocation_Path: ParamInPath, - common.ParameterLocation_Query: ParamInQuery, - common.ParameterLocation_Body: ParamInBody, - common.ParameterLocation_Header: ParamInHeader, -} - -func ToHTTPParamLocation(loc common.ParameterLocation) (HTTPParamLocation, bool) { - _loc, ok := httpParamLocations[loc] - return _loc, ok -} - -var thriftHTTPParamLocations = func() map[HTTPParamLocation]common.ParameterLocation { - locations := make(map[HTTPParamLocation]common.ParameterLocation, len(httpParamLocations)) - for k, v := range httpParamLocations { - locations[v] = k - } - return locations -}() - -func ToThriftHTTPParamLocation(loc HTTPParamLocation) (common.ParameterLocation, bool) { - _loc, ok := thriftHTTPParamLocations[loc] - return _loc, ok -} - -var openapiTypes = map[common.ParameterType]string{ - common.ParameterType_String: openapi3.TypeString, - common.ParameterType_Integer: openapi3.TypeInteger, - common.ParameterType_Number: openapi3.TypeNumber, - common.ParameterType_Object: openapi3.TypeObject, - common.ParameterType_Array: openapi3.TypeArray, - common.ParameterType_Bool: openapi3.TypeBoolean, -} - -func ToOpenapiParamType(typ common.ParameterType) (string, bool) { - _typ, ok := openapiTypes[typ] - return _typ, ok -} - -var thriftParameterTypes = func() map[string]common.ParameterType { - types := make(map[string]common.ParameterType, len(openapiTypes)) - for k, v := range openapiTypes { - types[v] = k - } - return types -}() - -func ToThriftParamType(typ string) (common.ParameterType, bool) { - _typ, ok := thriftParameterTypes[typ] - return _typ, ok -} - -var apiAssistTypes = map[common.AssistParameterType]APIFileAssistType{ - common.AssistParameterType_DEFAULT: AssistTypeFile, - common.AssistParameterType_IMAGE: AssistTypeImage, - common.AssistParameterType_DOC: AssistTypeDoc, - common.AssistParameterType_PPT: AssistTypePPT, - common.AssistParameterType_CODE: AssistTypeCode, - common.AssistParameterType_EXCEL: AssistTypeExcel, - common.AssistParameterType_ZIP: AssistTypeZIP, - common.AssistParameterType_VIDEO: AssistTypeVideo, - common.AssistParameterType_AUDIO: AssistTypeAudio, - common.AssistParameterType_TXT: AssistTypeTXT, -} - -// TODO(fanlv): move to other package - -func ToAPIAssistType(typ common.AssistParameterType) (APIFileAssistType, bool) { - _typ, ok := apiAssistTypes[typ] - return _typ, ok -} - -var thriftAPIAssistTypes = func() map[APIFileAssistType]common.AssistParameterType { - types := make(map[APIFileAssistType]common.AssistParameterType, len(apiAssistTypes)) - for k, v := range apiAssistTypes { - types[v] = k - } - return types -}() - -func ToThriftAPIAssistType(typ APIFileAssistType) (common.AssistParameterType, bool) { - _typ, ok := thriftAPIAssistTypes[typ] - return _typ, ok -} - -func IsValidAPIAssistType(typ APIFileAssistType) bool { - _, ok := thriftAPIAssistTypes[typ] - return ok -} - -var httpMethods = map[common.APIMethod]string{ - common.APIMethod_GET: http.MethodGet, - common.APIMethod_POST: http.MethodPost, - common.APIMethod_PUT: http.MethodPut, - common.APIMethod_DELETE: http.MethodDelete, - common.APIMethod_PATCH: http.MethodPatch, -} - -var thriftAPIMethods = func() map[string]common.APIMethod { - methods := make(map[string]common.APIMethod, len(httpMethods)) - for k, v := range httpMethods { - methods[v] = k - } - return methods -}() - -func ToThriftAPIMethod(method string) (common.APIMethod, bool) { - _method, ok := thriftAPIMethods[method] - return _method, ok -} - -func ToHTTPMethod(method common.APIMethod) (string, bool) { - _method, ok := httpMethods[method] - return _method, ok -} - -var assistTypeToFormat = map[APIFileAssistType]string{ - AssistTypeFile: "file_url", - AssistTypeImage: "image_url", - AssistTypeDoc: "doc_url", - AssistTypePPT: "ppt_url", - AssistTypeCode: "code_url", - AssistTypeExcel: "excel_url", - AssistTypeZIP: "zip_url", - AssistTypeVideo: "video_url", - AssistTypeAudio: "audio_url", - AssistTypeTXT: "txt_url", -} - -func AssistTypeToFormat(typ APIFileAssistType) (string, bool) { - format, ok := assistTypeToFormat[typ] - return format, ok -} - -var formatToAssistType = func() map[string]APIFileAssistType { - types := make(map[string]APIFileAssistType, len(assistTypeToFormat)) - for k, v := range assistTypeToFormat { - types[v] = k - } - return types -}() - -func FormatToAssistType(format string) (APIFileAssistType, bool) { - typ, ok := formatToAssistType[format] - return typ, ok -} - -var assistTypeToThriftFormat = map[APIFileAssistType]common.PluginParamTypeFormat{ - AssistTypeFile: common.PluginParamTypeFormat_FileUrl, - AssistTypeImage: common.PluginParamTypeFormat_ImageUrl, - AssistTypeDoc: common.PluginParamTypeFormat_DocUrl, - AssistTypePPT: common.PluginParamTypeFormat_PptUrl, - AssistTypeCode: common.PluginParamTypeFormat_CodeUrl, - AssistTypeExcel: common.PluginParamTypeFormat_ExcelUrl, - AssistTypeZIP: common.PluginParamTypeFormat_ZipUrl, - AssistTypeVideo: common.PluginParamTypeFormat_VideoUrl, - AssistTypeAudio: common.PluginParamTypeFormat_AudioUrl, - AssistTypeTXT: common.PluginParamTypeFormat_TxtUrl, -} - -func AssistTypeToThriftFormat(typ APIFileAssistType) (common.PluginParamTypeFormat, bool) { - format, ok := assistTypeToThriftFormat[typ] - return format, ok -} - -var authTypes = map[common.AuthorizationType]AuthzType{ - common.AuthorizationType_None: AuthzTypeOfNone, - common.AuthorizationType_Service: AuthzTypeOfService, - common.AuthorizationType_OAuth: AuthzTypeOfOAuth, - common.AuthorizationType_Standard: AuthzTypeOfOAuth, // deprecated, the same as OAuth -} - -func ToAuthType(typ common.AuthorizationType) (AuthzType, bool) { - _type, ok := authTypes[typ] - return _type, ok -} - -var thriftAuthTypes = func() map[AuthzType]common.AuthorizationType { - types := make(map[AuthzType]common.AuthorizationType, len(authTypes)) - for k, v := range authTypes { - if v == AuthzTypeOfOAuth { - types[v] = common.AuthorizationType_OAuth - } else { - types[v] = k - } - } - return types -}() - -func ToThriftAuthType(typ AuthzType) (common.AuthorizationType, bool) { - _type, ok := thriftAuthTypes[typ] - return _type, ok -} - -var subAuthTypes = map[int32]AuthzSubType{ - int32(common.ServiceAuthSubType_ApiKey): AuthzSubTypeOfServiceAPIToken, - int32(common.ServiceAuthSubType_OAuthAuthorizationCode): AuthzSubTypeOfOAuthAuthorizationCode, -} - -func ToAuthSubType(typ int32) (AuthzSubType, bool) { - _type, ok := subAuthTypes[typ] - return _type, ok -} - -var thriftSubAuthTypes = func() map[AuthzSubType]int32 { - types := make(map[AuthzSubType]int32, len(subAuthTypes)) - for k, v := range subAuthTypes { - types[v] = int32(k) - } - return types -}() - -func ToThriftAuthSubType(typ AuthzSubType) (int32, bool) { - _type, ok := thriftSubAuthTypes[typ] - return _type, ok -} - -var pluginTypes = map[common.PluginType]PluginType{ - common.PluginType_PLUGIN: PluginTypeOfCloud, -} - -func ToPluginType(typ common.PluginType) (PluginType, bool) { - _type, ok := pluginTypes[typ] - return _type, ok -} - -var thriftPluginTypes = func() map[PluginType]common.PluginType { - types := make(map[PluginType]common.PluginType, len(pluginTypes)) - for k, v := range pluginTypes { - types[v] = k - } - return types -}() - -func ToThriftPluginType(typ PluginType) (common.PluginType, bool) { - _type, ok := thriftPluginTypes[typ] - return _type, ok -} - -var apiAuthModes = map[common.PluginToolAuthType]ToolAuthMode{ - common.PluginToolAuthType_Required: ToolAuthModeOfRequired, - common.PluginToolAuthType_Supported: ToolAuthModeOfSupported, - common.PluginToolAuthType_Disable: ToolAuthModeOfDisabled, -} - -func ToAPIAuthMode(mode common.PluginToolAuthType) (ToolAuthMode, bool) { - _mode, ok := apiAuthModes[mode] - return _mode, ok -} - -var thriftAPIAuthModes = func() map[ToolAuthMode]common.PluginToolAuthType { - modes := make(map[ToolAuthMode]common.PluginToolAuthType, len(apiAuthModes)) - for k, v := range apiAuthModes { - modes[v] = k - } - return modes -}() - -func ToThriftAPIAuthMode(mode ToolAuthMode) (common.PluginToolAuthType, bool) { - _mode, ok := thriftAPIAuthModes[mode] - return _mode, ok -} diff --git a/backend/crossdomain/contract/plugin/model/default.go b/backend/crossdomain/contract/plugin/model/default.go new file mode 100644 index 000000000..0da0bf81e --- /dev/null +++ b/backend/crossdomain/contract/plugin/model/default.go @@ -0,0 +1,98 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "net/http" + "strconv" + + "github.com/getkin/kin-openapi/openapi3" + + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" +) + +func NewDefaultPluginManifest() *PluginManifest { + return &PluginManifest{ + SchemaVersion: "v1", + API: APIDesc{ + Type: consts.PluginTypeOfCloud, + }, + Auth: &AuthV2{ + Type: consts.AuthzTypeOfNone, + }, + CommonParams: map[consts.HTTPParamLocation][]*common.CommonParamSchema{ + consts.ParamInBody: {}, + consts.ParamInHeader: { + { + Name: "User-Agent", + Value: "Coze/1.0", + }, + }, + consts.ParamInQuery: {}, + }, + } +} + +func NewDefaultOpenapiDoc() *Openapi3T { + return &Openapi3T{ + OpenAPI: "3.0.1", + Info: &openapi3.Info{ + Version: "v1", + }, + Paths: openapi3.Paths{}, + Servers: openapi3.Servers{}, + } +} + +func DefaultOpenapi3Responses() openapi3.Responses { + return openapi3.Responses{ + strconv.Itoa(http.StatusOK): { + Value: &openapi3.Response{ + Description: ptr.Of("description is required"), + Content: openapi3.Content{ + consts.MediaTypeJson: &openapi3.MediaType{ + Schema: &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: openapi3.TypeObject, + Properties: map[string]*openapi3.SchemaRef{}, + }, + }, + }, + }, + }, + }, + } +} + +func DefaultOpenapi3RequestBody() *openapi3.RequestBodyRef { + return &openapi3.RequestBodyRef{ + Value: &openapi3.RequestBody{ + Content: map[string]*openapi3.MediaType{ + consts.MediaTypeJson: { + Schema: &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: openapi3.TypeObject, + Properties: map[string]*openapi3.SchemaRef{}, + }, + }, + }, + }, + }, + } +} diff --git a/backend/crossdomain/contract/plugin/dto/openapi.go b/backend/crossdomain/contract/plugin/model/openapi.go similarity index 94% rename from backend/crossdomain/contract/plugin/dto/openapi.go rename to backend/crossdomain/contract/plugin/model/openapi.go index c3ec14ca3..20c314d2a 100644 --- a/backend/crossdomain/contract/plugin/dto/openapi.go +++ b/backend/crossdomain/contract/plugin/model/openapi.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package dto +package model import ( "context" @@ -23,15 +23,15 @@ import ( "strconv" "strings" + "github.com/cloudwego/eino/schema" "github.com/getkin/kin-openapi/openapi3" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/errno" - - "github.com/cloudwego/eino/schema" ) type Openapi3T openapi3.T @@ -336,7 +336,7 @@ func validateOpenapi3Parameters(params openapi3.Parameters) (err error) { return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "parameter location is required")) } - if paramVal.In == string(ParamInBody) { + if paramVal.In == string(consts.ParamInBody) { return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey, "the location of parameter '%s' cannot be 'body'", paramVal.Name)) } @@ -360,20 +360,13 @@ func validateOpenapi3Parameters(params openapi3.Parameters) (err error) { } // MIME Type -const ( - MediaTypeJson = "application/json" - MediaTypeProblemJson = "application/problem+json" - MediaTypeFormURLEncoded = "application/x-www-form-urlencoded" - MediaTypeXYaml = "application/x-yaml" - MediaTypeYaml = "application/yaml" -) var mediaTypeArray = []string{ - MediaTypeJson, - MediaTypeProblemJson, - MediaTypeFormURLEncoded, - MediaTypeXYaml, - MediaTypeYaml, + consts.MediaTypeJson, + consts.MediaTypeProblemJson, + consts.MediaTypeFormURLEncoded, + consts.MediaTypeXYaml, + consts.MediaTypeYaml, } func validateOpenapi3Responses(responses openapi3.Responses) (err error) { @@ -406,7 +399,7 @@ func validateOpenapi3Responses(responses openapi3.Responses) (err error) { return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "response only supports 'application/json' media type")) } - mType, ok := resp.Value.Content[MediaTypeJson] + mType, ok := resp.Value.Content[consts.MediaTypeJson] if !ok || mType == nil { return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "response only supports 'application/json' media type")) @@ -436,11 +429,11 @@ func disabledParam(schemaVal *openapi3.Schema) bool { } globalDisable, localDisable := false, false - if v, ok := schemaVal.Extensions[APISchemaExtendLocalDisable]; ok { + if v, ok := schemaVal.Extensions[consts.APISchemaExtendLocalDisable]; ok { localDisable = v.(bool) } - if v, ok := schemaVal.Extensions[APISchemaExtendGlobalDisable]; ok { + if v, ok := schemaVal.Extensions[consts.APISchemaExtendGlobalDisable]; ok { globalDisable = v.(bool) } @@ -453,11 +446,11 @@ func (op *Openapi3Operation) GetReqBodySchema() (string, *openapi3.SchemaRef) { } var contentTypeArray = []string{ - MediaTypeJson, - MediaTypeProblemJson, - MediaTypeFormURLEncoded, - MediaTypeXYaml, - MediaTypeYaml, + consts.MediaTypeJson, + consts.MediaTypeProblemJson, + consts.MediaTypeFormURLEncoded, + consts.MediaTypeXYaml, + consts.MediaTypeYaml, } for _, ct := range contentTypeArray { diff --git a/backend/crossdomain/contract/plugin/dto/option.go b/backend/crossdomain/contract/plugin/model/option.go similarity index 78% rename from backend/crossdomain/contract/plugin/dto/option.go rename to backend/crossdomain/contract/plugin/model/option.go index 0d55372d0..2d031c4a4 100644 --- a/backend/crossdomain/contract/plugin/dto/option.go +++ b/backend/crossdomain/contract/plugin/model/option.go @@ -14,7 +14,9 @@ * limitations under the License. */ -package dto +package model + +import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" type ExecuteToolOption struct { ProjectInfo *ProjectInfo @@ -23,7 +25,7 @@ type ExecuteToolOption struct { ToolVersion string Operation *Openapi3Operation - InvalidRespProcessStrategy InvalidResponseProcessStrategy + InvalidRespProcessStrategy consts.InvalidResponseProcessStrategy ConversationID int64 } @@ -31,9 +33,9 @@ type ExecuteToolOption struct { type ExecuteToolOpt func(o *ExecuteToolOption) type ProjectInfo struct { - ProjectID int64 // agentID or appID - ProjectVersion *string // if version si nil, use latest version - ProjectType ProjectType // agent or app + ProjectID int64 // agentID or appID + ProjectVersion *string // if version is nil, use latest version + ProjectType consts.ProjectType // agent or app ConnectorID int64 } @@ -56,7 +58,7 @@ func WithOpenapiOperation(op *Openapi3Operation) ExecuteToolOpt { } } -func WithInvalidRespProcessStrategy(strategy InvalidResponseProcessStrategy) ExecuteToolOpt { +func WithInvalidRespProcessStrategy(strategy consts.InvalidResponseProcessStrategy) ExecuteToolOpt { return func(o *ExecuteToolOption) { o.InvalidRespProcessStrategy = strategy } diff --git a/backend/crossdomain/contract/plugin/dto/plugin.go b/backend/crossdomain/contract/plugin/model/plugin.go similarity index 50% rename from backend/crossdomain/contract/plugin/dto/plugin.go rename to backend/crossdomain/contract/plugin/model/plugin.go index 8a12c4f27..d9ee5dc06 100644 --- a/backend/crossdomain/contract/plugin/dto/plugin.go +++ b/backend/crossdomain/contract/plugin/model/plugin.go @@ -14,11 +14,9 @@ * limitations under the License. */ -package dto +package model import ( - "github.com/getkin/kin-openapi/openapi3" - api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" ) @@ -27,11 +25,6 @@ type VersionPlugin struct { Version string } -type VersionTool struct { - ToolID int64 - Version string -} - type MGetPluginLatestVersionResponse struct { Versions map[int64]string // pluginID vs version } @@ -55,83 +48,9 @@ type PluginInfo struct { OpenapiDoc *Openapi3T } -func (p PluginInfo) SetName(name string) { - if p.Manifest == nil || p.OpenapiDoc == nil { - return - } - p.Manifest.NameForModel = name - p.Manifest.NameForHuman = name - p.OpenapiDoc.Info.Title = name -} - -func (p PluginInfo) GetName() string { - if p.Manifest == nil { - return "" - } - return p.Manifest.NameForHuman -} - -func (p PluginInfo) GetDesc() string { - if p.Manifest == nil { - return "" - } - return p.Manifest.DescriptionForHuman -} - -func (p PluginInfo) GetAuthInfo() *AuthV2 { - if p.Manifest == nil { - return nil - } - return p.Manifest.Auth -} - -func (p PluginInfo) IsOfficial() bool { - return p.RefProductID != nil -} - -func (p PluginInfo) GetIconURI() string { - if p.IconURI == nil { - return "" - } - return *p.IconURI -} - -func (p PluginInfo) Published() bool { - return p.Version != nil -} - -type VersionAgentTool struct { - ToolName *string - ToolID int64 - - AgentVersion *string -} - -type MGetAgentToolsRequest struct { - AgentID int64 - SpaceID int64 - IsDraft bool - - VersionAgentTools []VersionAgentTool -} - -type ExecuteToolRequest struct { - UserID string - PluginID int64 - ToolID int64 - ExecDraftTool bool // if true, execute draft tool - ExecScene ExecuteScene - - ArgumentsInJson string -} - -type ExecuteToolResponse struct { - Tool *ToolInfo - Request string - TrimmedResp string - RawResp string - - RespSchema openapi3.Responses +type ToolExample struct { + RequestExample string + ResponseExample string } type PublishPluginRequest struct { @@ -158,12 +77,3 @@ type CheckCanPublishPluginsRequest struct { type CheckCanPublishPluginsResponse struct { InvalidPlugins []*PluginInfo } - -type ToolInterruptEvent struct { - Event InterruptEventType - ToolNeedOAuth *ToolNeedOAuthInterruptEvent -} - -type ToolNeedOAuthInterruptEvent struct { - Message string -} diff --git a/backend/crossdomain/contract/plugin/dto/plugin_manifest.go b/backend/crossdomain/contract/plugin/model/plugin_manifest.go similarity index 86% rename from backend/crossdomain/contract/plugin/dto/plugin_manifest.go rename to backend/crossdomain/contract/plugin/model/plugin_manifest.go index 1b960d21e..b45f4f56a 100644 --- a/backend/crossdomain/contract/plugin/dto/plugin_manifest.go +++ b/backend/crossdomain/contract/plugin/model/plugin_manifest.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package dto +package model import ( "encoding/json" @@ -23,6 +23,7 @@ import ( "strings" api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/types/errno" @@ -31,15 +32,15 @@ import ( ) type PluginManifest struct { - SchemaVersion string `json:"schema_version" yaml:"schema_version"` - NameForModel string `json:"name_for_model" yaml:"name_for_model"` - NameForHuman string `json:"name_for_human" yaml:"name_for_human"` - DescriptionForModel string `json:"description_for_model" yaml:"description_for_model"` - DescriptionForHuman string `json:"description_for_human" yaml:"description_for_human"` - Auth *AuthV2 `json:"auth" yaml:"auth"` - LogoURL string `json:"logo_url" yaml:"logo_url"` - API APIDesc `json:"api" yaml:"api"` - CommonParams map[HTTPParamLocation][]*api.CommonParamSchema `json:"common_params" yaml:"common_params"` + SchemaVersion string `json:"schema_version" yaml:"schema_version"` + NameForModel string `json:"name_for_model" yaml:"name_for_model"` + NameForHuman string `json:"name_for_human" yaml:"name_for_human"` + DescriptionForModel string `json:"description_for_model" yaml:"description_for_model"` + DescriptionForHuman string `json:"description_for_human" yaml:"description_for_human"` + Auth *AuthV2 `json:"auth" yaml:"auth"` + LogoURL string `json:"logo_url" yaml:"logo_url"` + API APIDesc `json:"api" yaml:"api"` + CommonParams map[consts.HTTPParamLocation][]*api.CommonParamSchema `json:"common_params" yaml:"common_params"` } func (mf *PluginManifest) Copy() (*PluginManifest, error) { @@ -111,7 +112,7 @@ func (mf *PluginManifest) Validate(skipAuthPayload bool) (err error) { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "description for human is required")) } - if mf.API.Type != PluginTypeOfCloud { + if mf.API.Type != consts.PluginTypeOfCloud { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey, "invalid api type '%s'", mf.API.Type)) } @@ -122,10 +123,10 @@ func (mf *PluginManifest) Validate(skipAuthPayload bool) (err error) { } for loc := range mf.CommonParams { - if loc != ParamInBody && - loc != ParamInHeader && - loc != ParamInQuery && - loc != ParamInPath { + if loc != consts.ParamInBody && + loc != consts.ParamInHeader && + loc != consts.ParamInQuery && + loc != consts.ParamInPath { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey, "invalid location '%s' in common params", loc)) } @@ -154,14 +155,14 @@ func (mf *PluginManifest) validateAuthInfo(skipAuthPayload bool) (err error) { "auth type is required")) } - if mf.Auth.Type != AuthzTypeOfNone && - mf.Auth.Type != AuthzTypeOfOAuth && - mf.Auth.Type != AuthzTypeOfService { + if mf.Auth.Type != consts.AuthzTypeOfNone && + mf.Auth.Type != consts.AuthzTypeOfOAuth && + mf.Auth.Type != consts.AuthzTypeOfService { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey, "invalid auth type '%s'", mf.Auth.Type)) } - if mf.Auth.Type == AuthzTypeOfNone { + if mf.Auth.Type == consts.AuthzTypeOfNone { return nil } @@ -171,11 +172,11 @@ func (mf *PluginManifest) validateAuthInfo(skipAuthPayload bool) (err error) { } switch mf.Auth.SubType { - case AuthzSubTypeOfServiceAPIToken: + case consts.AuthzSubTypeOfServiceAPIToken: err = mf.validateServiceToken(skipAuthPayload) //case AuthzSubTypeOfOAuthClientCredentials: // err = mf.validateClientCredentials() - case AuthzSubTypeOfOAuthAuthorizationCode: + case consts.AuthzSubTypeOfOAuthAuthorizationCode: err = mf.validateAuthCode(skipAuthPayload) default: return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey, @@ -212,8 +213,8 @@ func (mf *PluginManifest) validateServiceToken(skipAuthPayload bool) (err error) "key is required")) } - loc := HTTPParamLocation(strings.ToLower(string(apiToken.Location))) - if loc != ParamInHeader && loc != ParamInQuery { + loc := consts.HTTPParamLocation(strings.ToLower(string(apiToken.Location))) + if loc != consts.ParamInHeader && loc != consts.ParamInQuery { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey, "invalid location '%s'", apiToken.Location)) } @@ -281,7 +282,7 @@ func (mf *PluginManifest) validateAuthCode(skipAuthPayload bool) (err error) { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "client secret is required")) } - if authCode.AuthorizationContentType != MediaTypeJson { + if authCode.AuthorizationContentType != consts.MediaTypeJson { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "authorization content type must be 'application/json'")) } @@ -335,9 +336,9 @@ type Auth struct { } type AuthV2 struct { - Type AuthzType `json:"type" yaml:"type"` - SubType AuthzSubType `json:"sub_type" yaml:"sub_type"` - Payload string `json:"payload" yaml:"payload"` + Type consts.AuthzType `json:"type" yaml:"type"` + SubType consts.AuthzSubType `json:"sub_type" yaml:"sub_type"` + Payload string `json:"payload" yaml:"payload"` // service AuthOfAPIToken *AuthOfAPIToken `json:"-"` @@ -354,8 +355,8 @@ func (au *AuthV2) UnmarshalJSON(data []byte) error { "invalid plugin manifest json")) } - au.Type = AuthzType(auth.Type) - au.SubType = AuthzSubType(auth.SubType) + au.Type = consts.AuthzType(auth.Type) + au.SubType = consts.AuthzSubType(auth.SubType) if au.Type == "" { return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, @@ -368,17 +369,17 @@ func (au *AuthV2) UnmarshalJSON(data []byte) error { secret = encrypt.DefaultAuthSecret } - payload_, err := encrypt.DecryptByAES(auth.Payload, secret) - if err == nil { + payload_, eErr := encrypt.DecryptByAES(auth.Payload, secret) + if eErr == nil { auth.Payload = string(payload_) } } switch au.Type { - case AuthzTypeOfNone: - case AuthzTypeOfOAuth: + case consts.AuthzTypeOfNone: + case consts.AuthzTypeOfOAuth: err = au.unmarshalOAuth(auth) - case AuthzTypeOfService: + case consts.AuthzTypeOfService: err = au.unmarshalService(auth) default: return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey, @@ -393,15 +394,15 @@ func (au *AuthV2) UnmarshalJSON(data []byte) error { func (au *AuthV2) unmarshalService(auth *Auth) (err error) { if au.SubType == "" && au.Payload == "" { // Compatible with old data - au.SubType = AuthzSubTypeOfServiceAPIToken + au.SubType = consts.AuthzSubTypeOfServiceAPIToken } var payload []byte - if au.SubType == AuthzSubTypeOfServiceAPIToken { + if au.SubType == consts.AuthzSubTypeOfServiceAPIToken { if len(auth.ServiceToken) > 0 { au.AuthOfAPIToken = &AuthOfAPIToken{ - Location: HTTPParamLocation(strings.ToLower(auth.Location)), + Location: consts.HTTPParamLocation(strings.ToLower(auth.Location)), Key: auth.Key, ServiceToken: auth.ServiceToken, } @@ -433,12 +434,12 @@ func (au *AuthV2) unmarshalService(auth *Auth) (err error) { func (au *AuthV2) unmarshalOAuth(auth *Auth) (err error) { if au.SubType == "" { // Compatible with old data - au.SubType = AuthzSubTypeOfOAuthAuthorizationCode + au.SubType = consts.AuthzSubTypeOfOAuthAuthorizationCode } var payload []byte - if au.SubType == AuthzSubTypeOfOAuthAuthorizationCode { + if au.SubType == consts.AuthzSubTypeOfOAuthAuthorizationCode { if len(auth.ClientSecret) > 0 { au.AuthOfOAuthAuthorizationCode = &OAuthAuthorizationCodeConfig{ ClientID: auth.ClientID, @@ -464,7 +465,7 @@ func (au *AuthV2) unmarshalOAuth(auth *Auth) (err error) { } } - if au.SubType == AuthzSubTypeOfOAuthClientCredentials { + if au.SubType == consts.AuthzSubTypeOfOAuthClientCredentials { oauth := &OAuthClientCredentialsConfig{} err = json.Unmarshal([]byte(auth.Payload), oauth) if err != nil { @@ -492,7 +493,7 @@ func (au *AuthV2) unmarshalOAuth(auth *Auth) (err error) { type AuthOfAPIToken struct { // Location is the location of the parameter. // It can be "header" or "query". - Location HTTPParamLocation `json:"location"` + Location consts.HTTPParamLocation `json:"location"` // Key is the name of the parameter. Key string `json:"key"` // ServiceToken is the simple authorization information for the service. @@ -520,5 +521,5 @@ type OAuthClientCredentialsConfig struct { } type APIDesc struct { - Type PluginType `json:"type" validate:"required"` + Type consts.PluginType `json:"type" validate:"required"` } diff --git a/backend/crossdomain/contract/plugin/model/plugin_method.go b/backend/crossdomain/contract/plugin/model/plugin_method.go new file mode 100644 index 000000000..ff244caec --- /dev/null +++ b/backend/crossdomain/contract/plugin/model/plugin_method.go @@ -0,0 +1,113 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package model + +import ( + "context" + + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/pkg/sonic" +) + +func (p PluginInfo) GetToolExample(ctx context.Context, toolName string) *ToolExample { + if p.OpenapiDoc == nil || + p.OpenapiDoc.Components == nil || + len(p.OpenapiDoc.Components.Examples) == 0 { + return nil + } + example, ok := p.OpenapiDoc.Components.Examples[toolName] + if !ok { + return nil + } + if example.Value == nil || example.Value.Value == nil { + return nil + } + + val, ok := example.Value.Value.(map[string]any) + if !ok { + return nil + } + + reqExample, ok := val["ReqExample"] + if !ok { + return nil + } + reqExampleStr, err := sonic.MarshalString(reqExample) + if err != nil { + logs.CtxErrorf(ctx, "marshal request example failed, err=%v", err) + return nil + } + + respExample, ok := val["RespExample"] + if !ok { + return nil + } + respExampleStr, err := sonic.MarshalString(respExample) + if err != nil { + logs.CtxErrorf(ctx, "marshal response example failed, err=%v", err) + return nil + } + + return &ToolExample{ + RequestExample: reqExampleStr, + ResponseExample: respExampleStr, + } +} + +func (p PluginInfo) GetName() string { + if p.Manifest == nil { + return "" + } + return p.Manifest.NameForHuman +} + +func (p PluginInfo) GetVersion() string { + return ptr.FromOrDefault(p.Version, "") +} + +func (p PluginInfo) GetAPPID() int64 { + return ptr.FromOrDefault(p.APPID, 0) +} + +func (p PluginInfo) GetDesc() string { + if p.Manifest == nil { + return "" + } + return p.Manifest.DescriptionForHuman +} + +func (p PluginInfo) GetAuthInfo() *AuthV2 { + if p.Manifest == nil { + return nil + } + return p.Manifest.Auth +} + +func (p PluginInfo) IsOfficial() bool { + return p.RefProductID != nil +} + +func (p PluginInfo) GetIconURI() string { + if p.IconURI == nil { + return "" + } + return *p.IconURI +} + +func (p PluginInfo) Published() bool { + return p.Version != nil +} diff --git a/backend/crossdomain/contract/plugin/dto/toolinfo.go b/backend/crossdomain/contract/plugin/model/toolinfo.go similarity index 85% rename from backend/crossdomain/contract/plugin/dto/toolinfo.go rename to backend/crossdomain/contract/plugin/model/toolinfo.go index f673e8656..7d512e498 100644 --- a/backend/crossdomain/contract/plugin/dto/toolinfo.go +++ b/backend/crossdomain/contract/plugin/model/toolinfo.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package dto +package model import ( "fmt" @@ -27,6 +27,8 @@ import ( productAPI "github.com/coze-dev/coze-studio/backend/api/model/marketplace/product_public_api" "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" ) @@ -38,7 +40,7 @@ type ToolInfo struct { UpdatedAt int64 Version *string - ActivatedStatus *ActivatedStatus + ActivatedStatus *consts.ActivatedStatus DebugStatus *common.APIDebugStatus Method *string @@ -64,8 +66,16 @@ func (t ToolInfo) GetVersion() string { return ptr.FromOrDefault(t.Version, "") } -func (t ToolInfo) GetActivatedStatus() ActivatedStatus { - return ptr.FromOrDefault(t.ActivatedStatus, ActivateTool) +func (t ToolInfo) GetActivatedStatus() consts.ActivatedStatus { + return ptr.FromOrDefault(t.ActivatedStatus, consts.ActivateTool) +} + +func (t *ToolInfo) IsDeactivated() bool { + return t.GetActivatedStatus() == consts.DeactivateTool +} + +func (t ToolInfo) IsDebugging() bool { + return t.GetDebugStatus() == common.APIDebugStatus_DebugWaiting } func (t ToolInfo) GetSubURL() string { @@ -91,9 +101,9 @@ func (t ToolInfo) GetResponseOpenapiSchema() (*openapi3.Schema, error) { return nil, fmt.Errorf("response status '200' not found") } - mType, ok := resp.Value.Content[MediaTypeJson] // only support application/json + mType, ok := resp.Value.Content[consts.MediaTypeJson] // only support application/json if !ok || mType == nil || mType.Schema == nil || mType.Schema.Value == nil { - return nil, fmt.Errorf("media type '%s' not found in response", MediaTypeJson) + return nil, fmt.Errorf("media type '%s' not found in response", consts.MediaTypeJson) } return mType.Schema.Value, nil @@ -134,7 +144,7 @@ func (t ToolInfo) ToRespAPIParameter() ([]*common.APIParameter, error) { paramMeta := paramMetaInfo{ name: subParamName, desc: prop.Value.Description, - location: string(ParamInBody), + location: string(consts.ParamInBody), required: required[subParamName], } apiParam, err := toAPIParameter(paramMeta, prop.Value) @@ -220,7 +230,7 @@ func (t ToolInfo) ToReqAPIParameter() ([]*common.APIParameter, error) { paramMeta := paramMetaInfo{ name: subParamName, desc: prop.Value.Description, - location: string(ParamInBody), + location: string(consts.ParamInBody), required: required[subParamName], } apiParam, err := toAPIParameter(paramMeta, prop.Value) @@ -239,14 +249,14 @@ func (t ToolInfo) ToReqAPIParameter() ([]*common.APIParameter, error) { func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIParameter, error) { if sc == nil { - return nil, fmt.Errorf("schema is requred") + return nil, fmt.Errorf("schema is required") } - apiType, ok := ToThriftParamType(strings.ToLower(sc.Type)) + apiType, ok := convert.ToThriftParamType(strings.ToLower(sc.Type)) if !ok { return nil, fmt.Errorf("the type '%s' of filed '%s' is invalid", sc.Type, paramMeta.name) } - location, ok := ToThriftHTTPParamLocation(HTTPParamLocation(paramMeta.location)) + location, ok := convert.ToThriftHTTPParamLocation(consts.HTTPParamLocation(paramMeta.location)) if !ok { return nil, fmt.Errorf("the location '%s' of field '%s' is invalid", paramMeta.location, paramMeta.name) } @@ -267,28 +277,28 @@ func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIPa } if sc.Format != "" { - aType, ok := FormatToAssistType(sc.Format) + aType, ok := convert.FormatToAssistType(sc.Format) if !ok { return nil, fmt.Errorf("the format '%s' of field '%s' is invalid", sc.Format, paramMeta.name) } - _aType, ok := ToThriftAPIAssistType(aType) + _aType, ok := convert.ToThriftAPIAssistType(aType) if !ok { return nil, fmt.Errorf("assist type '%s' of field '%s' is invalid", aType, paramMeta.name) } apiParam.AssistType = ptr.Of(_aType) } - if v, ok := sc.Extensions[APISchemaExtendGlobalDisable]; ok { + if v, ok := sc.Extensions[consts.APISchemaExtendGlobalDisable]; ok { if disable, ok := v.(bool); ok { apiParam.GlobalDisable = disable } } - if v, ok := sc.Extensions[APISchemaExtendLocalDisable]; ok { + if v, ok := sc.Extensions[consts.APISchemaExtendLocalDisable]; ok { if disable, ok := v.(bool); ok { apiParam.LocalDisable = disable } } - if v, ok := sc.Extensions[APISchemaExtendVariableRef]; ok { + if v, ok := sc.Extensions[consts.APISchemaExtendVariableRef]; ok { if ref, ok := v.(string); ok { apiParam.VariableRef = ptr.Of(ref) apiParam.DefaultParamSource = ptr.Of(common.DefaultParamSource_Variable) @@ -392,12 +402,12 @@ func (t ToolInfo) ToPluginParameters() ([]*common.PluginParameter, error) { } var assistType *common.PluginParamTypeFormat - if v, ok := schemaVal.Extensions[APISchemaExtendAssistType]; ok { + if v, ok := schemaVal.Extensions[consts.APISchemaExtendAssistType]; ok { _v, ok := v.(string) if !ok { continue } - f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v)) + f, ok := convert.AssistTypeToThriftFormat(consts.APIFileAssistType(_v)) if !ok { return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramVal.Name) } @@ -468,9 +478,9 @@ func toPluginParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.Pl } var assistType *common.PluginParamTypeFormat - if v, ok := sc.Extensions[APISchemaExtendAssistType]; ok { + if v, ok := sc.Extensions[consts.APISchemaExtendAssistType]; ok { if _v, ok := v.(string); ok { - f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v)) + f, ok := convert.AssistTypeToThriftFormat(consts.APIFileAssistType(_v)) if !ok { return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramMeta.name) } @@ -555,7 +565,7 @@ func (t ToolInfo) ToToolParameters() ([]*productAPI.ToolParameter, error) { toToolParams = func(apiParams []*common.APIParameter) ([]*productAPI.ToolParameter, error) { params := make([]*productAPI.ToolParameter, 0, len(apiParams)) for _, apiParam := range apiParams { - typ, _ := ToOpenapiParamType(apiParam.Type) + typ, _ := convert.ToOpenapiParamType(apiParam.Type) toolParam := &productAPI.ToolParameter{ Name: apiParam.Name, Description: apiParam.Desc, @@ -580,3 +590,51 @@ func (t ToolInfo) ToToolParameters() ([]*productAPI.ToolParameter, error) { return toToolParams(apiParams) } + +type VersionTool struct { + ToolID int64 + Version string +} + +type VersionAgentTool struct { + ToolName *string + ToolID int64 + + AgentVersion *string +} + +type MGetAgentToolsRequest struct { + AgentID int64 + SpaceID int64 + IsDraft bool + + VersionAgentTools []VersionAgentTool +} + +type ExecuteToolRequest struct { + UserID string + PluginID int64 + ToolID int64 + ExecDraftTool bool // if true, execute draft tool + ExecScene consts.ExecuteScene + + ArgumentsInJson string +} + +type ExecuteToolResponse struct { + Tool *ToolInfo + Request string + TrimmedResp string + RawResp string + + RespSchema openapi3.Responses +} + +type ToolInterruptEvent struct { + Event consts.InterruptEventType + ToolNeedOAuth *ToolNeedOAuthInterruptEvent +} + +type ToolNeedOAuthInterruptEvent struct { + Message string +} diff --git a/backend/crossdomain/contract/plugin/plugin.go b/backend/crossdomain/contract/plugin/plugin.go index aacc42055..3522ee065 100644 --- a/backend/crossdomain/contract/plugin/plugin.go +++ b/backend/crossdomain/contract/plugin/plugin.go @@ -22,26 +22,22 @@ import ( "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" ) //go:generate mockgen -destination pluginmock/plugin_mock.go --package pluginmock -source plugin.go type PluginService interface { - MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (plugins []*model.PluginInfo, err error) - MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (resp *model.MGetPluginLatestVersionResponse, err error) BindAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) - DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) MGetAgentTools(ctx context.Context, req *model.MGetAgentToolsRequest) (tools []*model.ToolInfo, err error) ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, opts ...model.ExecuteToolOpt) (resp *model.ExecuteToolResponse, err error) - PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) (err error) - DeleteDraftPlugin(ctx context.Context, PluginID int64) (err error) - PublishPlugin(ctx context.Context, req *model.PublishPluginRequest) (err error) PublishAPPPlugins(ctx context.Context, req *model.PublishAPPPluginsRequest) (resp *model.PublishAPPPluginsResponse, err error) GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*model.PluginInfo, err error) + MGetDraftPlugins(ctx context.Context, pluginIDs []int64) (plugins []*model.PluginInfo, err error) + MGetOnlinePlugins(ctx context.Context, pluginIDs []int64) (plugins []*model.PluginInfo, err error) + MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (plugins []*model.PluginInfo, err error) + MGetDraftTools(ctx context.Context, pluginIDs []int64) (tools []*model.ToolInfo, err error) + MGetOnlineTools(ctx context.Context, pluginIDs []int64) (tools []*model.ToolInfo, err error) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*model.ToolInfo, err error) - GetPluginToolsInfo(ctx context.Context, req *model.ToolsInfoRequest) (*model.ToolsInfoResponse, error) - GetPluginInvokableTools(ctx context.Context, req *model.ToolsInvokableRequest) (map[int64]InvokableTool, error) - ExecutePlugin(ctx context.Context, input map[string]any, pe *model.PluginEntity, toolID int64, cfg workflow.ExecuteConfig) (map[string]any, error) } type InvokableTool interface { diff --git a/backend/crossdomain/contract/plugin/pluginmock/plugin_mock.go b/backend/crossdomain/contract/plugin/pluginmock/plugin_mock.go index 981a6fd5a..f1f2d291b 100644 --- a/backend/crossdomain/contract/plugin/pluginmock/plugin_mock.go +++ b/backend/crossdomain/contract/plugin/pluginmock/plugin_mock.go @@ -15,8 +15,7 @@ import ( schema "github.com/cloudwego/eino/schema" workflow "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" - plugin0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" gomock "go.uber.org/mock/gomock" ) @@ -58,58 +57,15 @@ func (mr *MockPluginServiceMockRecorder) BindAgentTools(ctx, agentID, toolIDs an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindAgentTools", reflect.TypeOf((*MockPluginService)(nil).BindAgentTools), ctx, agentID, toolIDs) } -// DeleteDraftPlugin mocks base method. -func (m *MockPluginService) DeleteDraftPlugin(ctx context.Context, PluginID int64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteDraftPlugin", ctx, PluginID) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteDraftPlugin indicates an expected call of DeleteDraftPlugin. -func (mr *MockPluginServiceMockRecorder) DeleteDraftPlugin(ctx, PluginID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDraftPlugin", reflect.TypeOf((*MockPluginService)(nil).DeleteDraftPlugin), ctx, PluginID) -} - -// DuplicateDraftAgentTools mocks base method. -func (m *MockPluginService) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DuplicateDraftAgentTools", ctx, fromAgentID, toAgentID) - ret0, _ := ret[0].(error) - return ret0 -} - -// DuplicateDraftAgentTools indicates an expected call of DuplicateDraftAgentTools. -func (mr *MockPluginServiceMockRecorder) DuplicateDraftAgentTools(ctx, fromAgentID, toAgentID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DuplicateDraftAgentTools", reflect.TypeOf((*MockPluginService)(nil).DuplicateDraftAgentTools), ctx, fromAgentID, toAgentID) -} - -// ExecutePlugin mocks base method. -func (m *MockPluginService) ExecutePlugin(ctx context.Context, input map[string]any, pe *plugin.PluginEntity, toolID int64, cfg workflow.ExecuteConfig) (map[string]any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExecutePlugin", ctx, input, pe, toolID, cfg) - ret0, _ := ret[0].(map[string]any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ExecutePlugin indicates an expected call of ExecutePlugin. -func (mr *MockPluginServiceMockRecorder) ExecutePlugin(ctx, input, pe, toolID, cfg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecutePlugin", reflect.TypeOf((*MockPluginService)(nil).ExecutePlugin), ctx, input, pe, toolID, cfg) -} - // ExecuteTool mocks base method. -func (m *MockPluginService) ExecuteTool(ctx context.Context, req *plugin.ExecuteToolRequest, opts ...plugin.ExecuteToolOpt) (*plugin.ExecuteToolResponse, error) { +func (m *MockPluginService) ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, opts ...model.ExecuteToolOpt) (*model.ExecuteToolResponse, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ExecuteTool", varargs...) - ret0, _ := ret[0].(*plugin.ExecuteToolResponse) + ret0, _ := ret[0].(*model.ExecuteToolResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -122,10 +78,10 @@ func (mr *MockPluginServiceMockRecorder) ExecuteTool(ctx, req any, opts ...any) } // GetAPPAllPlugins mocks base method. -func (m *MockPluginService) GetAPPAllPlugins(ctx context.Context, appID int64) ([]*plugin.PluginInfo, error) { +func (m *MockPluginService) GetAPPAllPlugins(ctx context.Context, appID int64) ([]*model.PluginInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAPPAllPlugins", ctx, appID) - ret0, _ := ret[0].([]*plugin.PluginInfo) + ret0, _ := ret[0].([]*model.PluginInfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -136,41 +92,11 @@ func (mr *MockPluginServiceMockRecorder) GetAPPAllPlugins(ctx, appID any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPPAllPlugins", reflect.TypeOf((*MockPluginService)(nil).GetAPPAllPlugins), ctx, appID) } -// GetPluginInvokableTools mocks base method. -func (m *MockPluginService) GetPluginInvokableTools(ctx context.Context, req *plugin.ToolsInvokableRequest) (map[int64]plugin0.InvokableTool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPluginInvokableTools", ctx, req) - ret0, _ := ret[0].(map[int64]plugin0.InvokableTool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetPluginInvokableTools indicates an expected call of GetPluginInvokableTools. -func (mr *MockPluginServiceMockRecorder) GetPluginInvokableTools(ctx, req any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginInvokableTools", reflect.TypeOf((*MockPluginService)(nil).GetPluginInvokableTools), ctx, req) -} - -// GetPluginToolsInfo mocks base method. -func (m *MockPluginService) GetPluginToolsInfo(ctx context.Context, req *plugin.ToolsInfoRequest) (*plugin.ToolsInfoResponse, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPluginToolsInfo", ctx, req) - ret0, _ := ret[0].(*plugin.ToolsInfoResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetPluginToolsInfo indicates an expected call of GetPluginToolsInfo. -func (mr *MockPluginServiceMockRecorder) GetPluginToolsInfo(ctx, req any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginToolsInfo", reflect.TypeOf((*MockPluginService)(nil).GetPluginToolsInfo), ctx, req) -} - // MGetAgentTools mocks base method. -func (m *MockPluginService) MGetAgentTools(ctx context.Context, req *plugin.MGetAgentToolsRequest) ([]*plugin.ToolInfo, error) { +func (m *MockPluginService) MGetAgentTools(ctx context.Context, req *model.MGetAgentToolsRequest) ([]*model.ToolInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MGetAgentTools", ctx, req) - ret0, _ := ret[0].([]*plugin.ToolInfo) + ret0, _ := ret[0].([]*model.ToolInfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -181,26 +107,71 @@ func (mr *MockPluginServiceMockRecorder) MGetAgentTools(ctx, req any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetAgentTools", reflect.TypeOf((*MockPluginService)(nil).MGetAgentTools), ctx, req) } -// MGetPluginLatestVersion mocks base method. -func (m *MockPluginService) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (*plugin.MGetPluginLatestVersionResponse, error) { +// MGetDraftPlugins mocks base method. +func (m *MockPluginService) MGetDraftPlugins(ctx context.Context, pluginIDs []int64) ([]*model.PluginInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MGetPluginLatestVersion", ctx, pluginIDs) - ret0, _ := ret[0].(*plugin.MGetPluginLatestVersionResponse) + ret := m.ctrl.Call(m, "MGetDraftPlugins", ctx, pluginIDs) + ret0, _ := ret[0].([]*model.PluginInfo) ret1, _ := ret[1].(error) return ret0, ret1 } -// MGetPluginLatestVersion indicates an expected call of MGetPluginLatestVersion. -func (mr *MockPluginServiceMockRecorder) MGetPluginLatestVersion(ctx, pluginIDs any) *gomock.Call { +// MGetDraftPlugins indicates an expected call of MGetDraftPlugins. +func (mr *MockPluginServiceMockRecorder) MGetDraftPlugins(ctx, pluginIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetPluginLatestVersion", reflect.TypeOf((*MockPluginService)(nil).MGetPluginLatestVersion), ctx, pluginIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetDraftPlugins", reflect.TypeOf((*MockPluginService)(nil).MGetDraftPlugins), ctx, pluginIDs) +} + +// MGetDraftTools mocks base method. +func (m *MockPluginService) MGetDraftTools(ctx context.Context, pluginIDs []int64) ([]*model.ToolInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MGetDraftTools", ctx, pluginIDs) + ret0, _ := ret[0].([]*model.ToolInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MGetDraftTools indicates an expected call of MGetDraftTools. +func (mr *MockPluginServiceMockRecorder) MGetDraftTools(ctx, pluginIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetDraftTools", reflect.TypeOf((*MockPluginService)(nil).MGetDraftTools), ctx, pluginIDs) +} + +// MGetOnlinePlugins mocks base method. +func (m *MockPluginService) MGetOnlinePlugins(ctx context.Context, pluginIDs []int64) ([]*model.PluginInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MGetOnlinePlugins", ctx, pluginIDs) + ret0, _ := ret[0].([]*model.PluginInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MGetOnlinePlugins indicates an expected call of MGetOnlinePlugins. +func (mr *MockPluginServiceMockRecorder) MGetOnlinePlugins(ctx, pluginIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetOnlinePlugins", reflect.TypeOf((*MockPluginService)(nil).MGetOnlinePlugins), ctx, pluginIDs) +} + +// MGetOnlineTools mocks base method. +func (m *MockPluginService) MGetOnlineTools(ctx context.Context, pluginIDs []int64) ([]*model.ToolInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MGetOnlineTools", ctx, pluginIDs) + ret0, _ := ret[0].([]*model.ToolInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MGetOnlineTools indicates an expected call of MGetOnlineTools. +func (mr *MockPluginServiceMockRecorder) MGetOnlineTools(ctx, pluginIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetOnlineTools", reflect.TypeOf((*MockPluginService)(nil).MGetOnlineTools), ctx, pluginIDs) } // MGetVersionPlugins mocks base method. -func (m *MockPluginService) MGetVersionPlugins(ctx context.Context, versionPlugins []plugin.VersionPlugin) ([]*plugin.PluginInfo, error) { +func (m *MockPluginService) MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) ([]*model.PluginInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MGetVersionPlugins", ctx, versionPlugins) - ret0, _ := ret[0].([]*plugin.PluginInfo) + ret0, _ := ret[0].([]*model.PluginInfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -212,10 +183,10 @@ func (mr *MockPluginServiceMockRecorder) MGetVersionPlugins(ctx, versionPlugins } // MGetVersionTools mocks base method. -func (m *MockPluginService) MGetVersionTools(ctx context.Context, versionTools []plugin.VersionTool) ([]*plugin.ToolInfo, error) { +func (m *MockPluginService) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) ([]*model.ToolInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MGetVersionTools", ctx, versionTools) - ret0, _ := ret[0].([]*plugin.ToolInfo) + ret0, _ := ret[0].([]*model.ToolInfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -227,10 +198,10 @@ func (mr *MockPluginServiceMockRecorder) MGetVersionTools(ctx, versionTools any) } // PublishAPPPlugins mocks base method. -func (m *MockPluginService) PublishAPPPlugins(ctx context.Context, req *plugin.PublishAPPPluginsRequest) (*plugin.PublishAPPPluginsResponse, error) { +func (m *MockPluginService) PublishAPPPlugins(ctx context.Context, req *model.PublishAPPPluginsRequest) (*model.PublishAPPPluginsResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PublishAPPPlugins", ctx, req) - ret0, _ := ret[0].(*plugin.PublishAPPPluginsResponse) + ret0, _ := ret[0].(*model.PublishAPPPluginsResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -241,34 +212,6 @@ func (mr *MockPluginServiceMockRecorder) PublishAPPPlugins(ctx, req any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishAPPPlugins", reflect.TypeOf((*MockPluginService)(nil).PublishAPPPlugins), ctx, req) } -// PublishAgentTools mocks base method. -func (m *MockPluginService) PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishAgentTools", ctx, agentID, agentVersion) - ret0, _ := ret[0].(error) - return ret0 -} - -// PublishAgentTools indicates an expected call of PublishAgentTools. -func (mr *MockPluginServiceMockRecorder) PublishAgentTools(ctx, agentID, agentVersion any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishAgentTools", reflect.TypeOf((*MockPluginService)(nil).PublishAgentTools), ctx, agentID, agentVersion) -} - -// PublishPlugin mocks base method. -func (m *MockPluginService) PublishPlugin(ctx context.Context, req *plugin.PublishPluginRequest) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishPlugin", ctx, req) - ret0, _ := ret[0].(error) - return ret0 -} - -// PublishPlugin indicates an expected call of PublishPlugin. -func (mr *MockPluginServiceMockRecorder) PublishPlugin(ctx, req any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishPlugin", reflect.TypeOf((*MockPluginService)(nil).PublishPlugin), ctx, req) -} - // MockInvokableTool is a mock of InvokableTool interface. type MockInvokableTool struct { ctrl *gomock.Controller diff --git a/backend/crossdomain/impl/plugin/plugin.go b/backend/crossdomain/impl/plugin/plugin.go index 4f37ce77e..c25b18db6 100644 --- a/backend/crossdomain/impl/plugin/plugin.go +++ b/backend/crossdomain/impl/plugin/plugin.go @@ -18,33 +18,13 @@ package plugin import ( "context" - "fmt" - "strconv" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/schema" - "github.com/getkin/kin-openapi/openapi3" - "golang.org/x/exp/maps" - - workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" - "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow" - "github.com/coze-dev/coze-studio/backend/application/base/pluginutil" crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" - "github.com/coze-dev/coze-studio/backend/domain/plugin/service" plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service" - "github.com/coze-dev/coze-studio/backend/domain/workflow" - entity2 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" - "github.com/coze-dev/coze-studio/backend/pkg/errorx" - "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" - "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" - "github.com/coze-dev/coze-studio/backend/pkg/sonic" - "github.com/coze-dev/coze-studio/backend/types/errno" ) var defaultSVC crossplugin.PluginService @@ -63,27 +43,10 @@ func InitDomainService(c plugin.PluginService, tos storage.Storage) crossplugin. return defaultSVC } -func (s *impl) MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (mPlugins []*model.PluginInfo, err error) { - plugins, err := s.DomainSVC.MGetVersionPlugins(ctx, versionPlugins) - if err != nil { - return nil, err - } - - mPlugins = slices.Transform(plugins, func(e *entity.PluginInfo) *model.PluginInfo { - return e.PluginInfo - }) - - return mPlugins, nil -} - func (s *impl) BindAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) { return s.DomainSVC.BindAgentTools(ctx, agentID, toolIDs) } -func (s *impl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) { - return s.DomainSVC.DuplicateDraftAgentTools(ctx, fromAgentID, toAgentID) -} - func (s *impl) MGetAgentTools(ctx context.Context, req *model.MGetAgentToolsRequest) (tools []*model.ToolInfo, err error) { return s.DomainSVC.MGetAgentTools(ctx, req) } @@ -92,30 +55,10 @@ func (s *impl) ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, o return s.DomainSVC.ExecuteTool(ctx, req, opts...) } -func (s *impl) PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) (err error) { - return s.DomainSVC.PublishAgentTools(ctx, agentID, agentVersion) -} - -func (s *impl) DeleteDraftPlugin(ctx context.Context, pluginID int64) (err error) { - return s.DomainSVC.DeleteDraftPlugin(ctx, pluginID) -} - -func (s *impl) PublishPlugin(ctx context.Context, req *model.PublishPluginRequest) (err error) { - return s.DomainSVC.PublishPlugin(ctx, req) -} - func (s *impl) PublishAPPPlugins(ctx context.Context, req *model.PublishAPPPluginsRequest) (resp *model.PublishAPPPluginsResponse, err error) { return s.DomainSVC.PublishAPPPlugins(ctx, req) } -func (s *impl) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (resp *model.MGetPluginLatestVersionResponse, err error) { - return s.DomainSVC.MGetPluginLatestVersion(ctx, pluginIDs) -} - -func (s *impl) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*model.ToolInfo, err error) { - return s.DomainSVC.MGetVersionTools(ctx, versionTools) -} - func (s *impl) GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*model.PluginInfo, err error) { _plugins, err := s.DomainSVC.GetAPPAllPlugins(ctx, appID) if err != nil { @@ -129,476 +72,53 @@ func (s *impl) GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*mo return plugins, nil } -type pluginInfo struct { - *entity.PluginInfo - LatestVersion *string -} - -func (s *impl) getPluginsWithTools(ctx context.Context, pluginEntity *model.PluginEntity, toolIDs []int64, isDraft bool) ( - _ *pluginInfo, toolsInfo []*entity.ToolInfo, err error) { - defer func() { - if err != nil { - err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) - } - }() - - var pluginsInfo []*entity.PluginInfo - var latestPluginInfo *entity.PluginInfo - pluginID := pluginEntity.PluginID - if isDraft { - plugins, err := s.DomainSVC.MGetDraftPlugins(ctx, []int64{pluginID}) - if err != nil { - return nil, nil, err - } - pluginsInfo = plugins - } else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") { - plugins, err := s.DomainSVC.MGetOnlinePlugins(ctx, []int64{pluginID}) - if err != nil { - return nil, nil, err - } - pluginsInfo = plugins - - } else { - plugins, err := s.DomainSVC.MGetVersionPlugins(ctx, []entity.VersionPlugin{ - {PluginID: pluginID, Version: *pluginEntity.PluginVersion}, - }) - if err != nil { - return nil, nil, err - } - pluginsInfo = plugins - - onlinePlugins, err := s.DomainSVC.MGetOnlinePlugins(ctx, []int64{pluginID}) - if err != nil { - return nil, nil, err - } - for _, pi := range onlinePlugins { - if pi.ID == pluginID { - latestPluginInfo = pi - break - } - } - } - - var pInfo *entity.PluginInfo - for _, p := range pluginsInfo { - if p.ID == pluginID { - pInfo = p - break - } - } - if pInfo == nil { - return nil, nil, vo.NewError(errno.ErrPluginIDNotFound, errorx.KV("id", strconv.FormatInt(pluginID, 10))) - } - - if isDraft { - tools, err := s.DomainSVC.MGetDraftTools(ctx, toolIDs) - if err != nil { - return nil, nil, err - } - toolsInfo = tools - } else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") { - tools, err := s.DomainSVC.MGetOnlineTools(ctx, toolIDs) - if err != nil { - return nil, nil, err - } - toolsInfo = tools - } else { - eVersionTools := slices.Transform(toolIDs, func(tid int64) entity.VersionTool { - return entity.VersionTool{ - ToolID: tid, - Version: *pluginEntity.PluginVersion, - } - }) - tools, err := s.DomainSVC.MGetVersionTools(ctx, eVersionTools) - if err != nil { - return nil, nil, err - } - toolsInfo = tools - } - - if latestPluginInfo != nil { - return &pluginInfo{PluginInfo: pInfo, LatestVersion: latestPluginInfo.Version}, toolsInfo, nil - } - - return &pluginInfo{PluginInfo: pInfo}, toolsInfo, nil -} - -func (s *impl) GetPluginToolsInfo(ctx context.Context, req *model.ToolsInfoRequest) ( - _ *model.ToolsInfoResponse, err error) { - defer func() { - if err != nil { - err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) - } - }() - - var toolsInfo []*entity.ToolInfo - isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0") - pInfo, toolsInfo, err := s.getPluginsWithTools(ctx, &model.PluginEntity{PluginID: req.PluginEntity.PluginID, PluginVersion: req.PluginEntity.PluginVersion}, req.ToolIDs, isDraft) +func (s *impl) MGetDraftPlugins(ctx context.Context, pluginIDs []int64) (plugins []*model.PluginInfo, err error) { + ePlugins, err := s.DomainSVC.MGetDraftPlugins(ctx, pluginIDs) if err != nil { return nil, err } - url, err := s.tos.GetObjectUrl(ctx, pInfo.GetIconURI()) - if err != nil { - return nil, vo.WrapIfNeeded(errno.ErrTOSError, err) - } + plugins = slices.Transform(ePlugins, func(e *entity.PluginInfo) *model.PluginInfo { + return e.PluginInfo + }) - response := &model.ToolsInfoResponse{ - PluginID: pInfo.ID, - SpaceID: pInfo.SpaceID, - Version: pInfo.GetVersion(), - PluginName: pInfo.GetName(), - Description: pInfo.GetDesc(), - IconURL: url, - PluginType: int64(pInfo.PluginType), - ToolInfoList: make(map[int64]model.ToolInfoW), - LatestVersion: pInfo.LatestVersion, - IsOfficial: pInfo.IsOfficial(), - AppID: pInfo.GetAPPID(), - } - - for _, tf := range toolsInfo { - inputs, err := tf.ToReqAPIParameter() - if err != nil { - return nil, err - } - outputs, err := tf.ToRespAPIParameter() - if err != nil { - return nil, err - } - toolExample := pInfo.GetToolExample(ctx, tf.GetName()) - - var ( - requestExample string - responseExample string - ) - if toolExample != nil { - requestExample = toolExample.RequestExample - responseExample = toolExample.ResponseExample - } - - response.ToolInfoList[tf.ID] = model.ToolInfoW{ - ToolID: tf.ID, - ToolName: tf.GetName(), - Inputs: slices.Transform(inputs, toWorkflowAPIParameter), - Outputs: slices.Transform(outputs, toWorkflowAPIParameter), - Description: tf.GetDesc(), - DebugExample: &model.DebugExample{ - ReqExample: requestExample, - RespExample: responseExample, - }, - } - - } - return response, nil + return plugins, nil } -func (s *impl) GetPluginInvokableTools(ctx context.Context, req *model.ToolsInvokableRequest) ( - _ map[int64]crossplugin.InvokableTool, err error) { - defer func() { - if err != nil { - err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) - } - }() - - var toolsInfo []*entity.ToolInfo - isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0") - pInfo, toolsInfo, err := s.getPluginsWithTools(ctx, &model.PluginEntity{ - PluginID: req.PluginEntity.PluginID, - PluginVersion: req.PluginEntity.PluginVersion, - }, maps.Keys(req.ToolsInvokableInfo), isDraft) +func (s *impl) MGetOnlinePlugins(ctx context.Context, pluginIDs []int64) (plugins []*model.PluginInfo, err error) { + ePlugins, err := s.DomainSVC.MGetOnlinePlugins(ctx, pluginIDs) if err != nil { return nil, err } - result := map[int64]crossplugin.InvokableTool{} - for _, tf := range toolsInfo { - tl := &pluginInvokeTool{ - pluginEntity: model.PluginEntity{ - PluginID: pInfo.ID, - PluginVersion: pInfo.Version, - }, - client: s.DomainSVC, - toolInfo: tf, - IsDraft: isDraft, - } + plugins = slices.Transform(ePlugins, func(e *entity.PluginInfo) *model.PluginInfo { + return e.PluginInfo + }) - if r, ok := req.ToolsInvokableInfo[tf.ID]; ok && (r.RequestAPIParametersConfig != nil && r.ResponseAPIParametersConfig != nil) { - reqPluginCommonAPIParameters := slices.Transform(r.RequestAPIParametersConfig, toPluginCommonAPIParameter) - respPluginCommonAPIParameters := slices.Transform(r.ResponseAPIParametersConfig, toPluginCommonAPIParameter) - - tl.toolOperation, err = pluginutil.APIParamsToOpenapiOperation(reqPluginCommonAPIParameters, respPluginCommonAPIParameters) - if err != nil { - return nil, err - } - - tl.toolOperation.OperationID = tf.Operation.OperationID - tl.toolOperation.Summary = tf.Operation.Summary - } - - result[tf.ID] = tl - } - return result, nil + return plugins, nil } -func (s *impl) ExecutePlugin(ctx context.Context, input map[string]any, pe *model.PluginEntity, - toolID int64, cfg workflowModel.ExecuteConfig) (map[string]any, error) { - args, err := sonic.MarshalString(input) - if err != nil { - return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err) - } - - var uID string - if cfg.AgentID != nil { - uID = cfg.ConnectorUID - } else { - uID = conv.Int64ToStr(cfg.Operator) - } - - req := &model.ExecuteToolRequest{ - UserID: uID, - PluginID: pe.PluginID, - ToolID: toolID, - ExecScene: model.ExecSceneOfWorkflow, - ArgumentsInJson: args, - ExecDraftTool: pe.PluginVersion == nil || *pe.PluginVersion == "0", - } - execOpts := []entity.ExecuteToolOpt{ - model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault), - } - - if pe.PluginVersion != nil { - execOpts = append(execOpts, model.WithToolVersion(*pe.PluginVersion)) - } - - r, err := s.DomainSVC.ExecuteTool(ctx, req, execOpts...) - if err != nil { - if extra, ok := compose.IsInterruptRerunError(err); ok { - pluginTIE, ok := extra.(*model.ToolInterruptEvent) - if !ok { - return nil, vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra)) - } - - var eventType workflow3.EventType - switch pluginTIE.Event { - case model.InterruptEventTypeOfToolNeedOAuth: - eventType = workflow3.EventType_WorkflowOauthPlugin - default: - return nil, vo.WrapError(errno.ErrPluginAPIErr, - fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event)) - } - - id, err := workflow.GetRepository().GenID(ctx) - if err != nil { - return nil, vo.WrapError(errno.ErrIDGenError, err) - } - - ie := &entity2.InterruptEvent{ - ID: id, - InterruptData: pluginTIE.ToolNeedOAuth.Message, - EventType: eventType, - } - - // temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt - interruptData := ie.InterruptData - return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData)) - } - return nil, err - } - - var output map[string]any - err = sonic.UnmarshalString(r.TrimmedResp, &output) - if err != nil { - return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err) - } - - return output, nil -} - -type pluginInvokeTool struct { - pluginEntity model.PluginEntity - client service.PluginService - toolInfo *entity.ToolInfo - toolOperation *openapi3.Operation - IsDraft bool -} - -func (p *pluginInvokeTool) Info(ctx context.Context) (_ *schema.ToolInfo, err error) { - defer func() { - if err != nil { - err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) - } - }() - - var parameterInfo map[string]*schema.ParameterInfo - if p.toolOperation != nil { - parameterInfo, err = model.NewOpenapi3Operation(p.toolOperation).ToEinoSchemaParameterInfo(ctx) - } else { - parameterInfo, err = p.toolInfo.Operation.ToEinoSchemaParameterInfo(ctx) - } - +func (s *impl) MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (plugins []*model.PluginInfo, err error) { + ePlugins, err := s.DomainSVC.MGetVersionPlugins(ctx, versionPlugins) if err != nil { return nil, err } - return &schema.ToolInfo{ - Name: p.toolInfo.GetName(), - Desc: p.toolInfo.GetDesc(), - ParamsOneOf: schema.NewParamsOneOfByParams(parameterInfo), - }, nil + plugins = slices.Transform(ePlugins, func(e *entity.PluginInfo) *model.PluginInfo { + return e.PluginInfo + }) + + return plugins, nil } -func (p *pluginInvokeTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflowModel.ExecuteConfig) (string, error) { - req := &model.ExecuteToolRequest{ - UserID: conv.Int64ToStr(cfg.Operator), - PluginID: p.pluginEntity.PluginID, - ToolID: p.toolInfo.ID, - ExecScene: model.ExecSceneOfWorkflow, - ArgumentsInJson: argumentsInJSON, - ExecDraftTool: p.IsDraft, - } - execOpts := []entity.ExecuteToolOpt{ - model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault), - } - - if p.pluginEntity.PluginVersion != nil { - execOpts = append(execOpts, model.WithToolVersion(*p.pluginEntity.PluginVersion)) - } - - if p.toolOperation != nil { - execOpts = append(execOpts, model.WithOpenapiOperation(model.NewOpenapi3Operation(p.toolOperation))) - } - - r, err := p.client.ExecuteTool(ctx, req, execOpts...) - if err != nil { - if extra, ok := compose.IsInterruptRerunError(err); ok { - pluginTIE, ok := extra.(*model.ToolInterruptEvent) - if !ok { - return "", vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra)) - } - - var eventType workflow3.EventType - switch pluginTIE.Event { - case model.InterruptEventTypeOfToolNeedOAuth: - eventType = workflow3.EventType_WorkflowOauthPlugin - default: - return "", vo.WrapError(errno.ErrPluginAPIErr, - fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event)) - } - - id, err := workflow.GetRepository().GenID(ctx) - if err != nil { - return "", vo.WrapError(errno.ErrIDGenError, err) - } - - ie := &entity2.InterruptEvent{ - ID: id, - InterruptData: pluginTIE.ToolNeedOAuth.Message, - EventType: eventType, - } - - tie := &entity2.ToolInterruptEvent{ - ToolCallID: compose.GetToolCallID(ctx), - ToolName: p.toolInfo.GetName(), - InterruptEvent: ie, - } - - // temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt - _ = tie - interruptData := ie.InterruptData - return "", vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData)) - } - return "", err - } - return r.TrimmedResp, nil +func (s *impl) MGetDraftTools(ctx context.Context, pluginIDs []int64) (tools []*model.ToolInfo, err error) { + return s.DomainSVC.MGetDraftTools(ctx, pluginIDs) } -func toPluginCommonAPIParameter(parameter *workflow3.APIParameter) *common.APIParameter { - if parameter == nil { - return nil - } - p := &common.APIParameter{ - ID: parameter.ID, - Name: parameter.Name, - Desc: parameter.Desc, - Type: common.ParameterType(parameter.Type), - Location: common.ParameterLocation(parameter.Location), - IsRequired: parameter.IsRequired, - GlobalDefault: parameter.GlobalDefault, - GlobalDisable: parameter.GlobalDisable, - LocalDefault: parameter.LocalDefault, - LocalDisable: parameter.LocalDisable, - VariableRef: parameter.VariableRef, - } - if parameter.SubType != nil { - p.SubType = ptr.Of(common.ParameterType(*parameter.SubType)) - } - - if parameter.DefaultParamSource != nil { - p.DefaultParamSource = ptr.Of(common.DefaultParamSource(*parameter.DefaultParamSource)) - } - if parameter.AssistType != nil { - p.AssistType = ptr.Of(common.AssistParameterType(*parameter.AssistType)) - } - - if len(parameter.SubParameters) > 0 { - p.SubParameters = make([]*common.APIParameter, 0, len(parameter.SubParameters)) - for _, subParam := range parameter.SubParameters { - p.SubParameters = append(p.SubParameters, toPluginCommonAPIParameter(subParam)) - } - } - - return p +func (s *impl) MGetOnlineTools(ctx context.Context, pluginIDs []int64) (tools []*model.ToolInfo, err error) { + return s.DomainSVC.MGetOnlineTools(ctx, pluginIDs) } -func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParameter { - if parameter == nil { - return nil - } - p := &workflow3.APIParameter{ - ID: parameter.ID, - Name: parameter.Name, - Desc: parameter.Desc, - Type: workflow3.ParameterType(parameter.Type), - Location: workflow3.ParameterLocation(parameter.Location), - IsRequired: parameter.IsRequired, - GlobalDefault: parameter.GlobalDefault, - GlobalDisable: parameter.GlobalDisable, - LocalDefault: parameter.LocalDefault, - LocalDisable: parameter.LocalDisable, - VariableRef: parameter.VariableRef, - } - if parameter.SubType != nil { - p.SubType = ptr.Of(workflow3.ParameterType(*parameter.SubType)) - } - if parameter.DefaultParamSource != nil { - p.DefaultParamSource = ptr.Of(workflow3.DefaultParamSource(*parameter.DefaultParamSource)) - } - if parameter.AssistType != nil { - p.AssistType = ptr.Of(workflow3.AssistParameterType(*parameter.AssistType)) - } - - // Check if it's a specially wrapped array that needs unwrapping. - if parameter.Type == common.ParameterType_Array && len(parameter.SubParameters) == 1 && parameter.SubParameters[0].Name == "[Array Item]" { - arrayItem := parameter.SubParameters[0] - // The actual type of array elements is the type of the "[Array Item]". - p.SubType = ptr.Of(workflow3.ParameterType(arrayItem.Type)) - // If the array elements are objects, their sub-parameters (fields) are lifted up. - if arrayItem.Type == common.ParameterType_Object { - p.SubParameters = make([]*workflow3.APIParameter, 0, len(arrayItem.SubParameters)) - for _, subParam := range arrayItem.SubParameters { - p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam)) - } - } else { - p.SubParameters = make([]*workflow3.APIParameter, 0, 1) - p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(arrayItem)) - } - } else if len(parameter.SubParameters) > 0 { - p.SubParameters = make([]*workflow3.APIParameter, 0, len(parameter.SubParameters)) - for _, subParam := range parameter.SubParameters { - p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam)) - } - } - - return p +func (s *impl) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*model.ToolInfo, err error) { + return s.DomainSVC.MGetVersionTools(ctx, versionTools) } diff --git a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner_test.go b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner_test.go index c74938390..32667ae7d 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner_test.go +++ b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner_test.go @@ -101,7 +101,7 @@ func TestAgentRunner_preHandlerInput(t *testing.T) { }, }, expectedResult: &schema.Message{ - Role: schema.User, + Role: schema.User, Content: "", MultiContent: []schema.ChatMessagePart{ { @@ -350,16 +350,16 @@ func TestAgentRunner_preHandlerInput(t *testing.T) { func TestAgentRunner_concatContentString(t *testing.T) { tests := []struct { - name string - textContent string - unSupportTypeURL []schema.ChatMessagePart - expectedResult string + name string + textContent string + unSupportTypeURL []schema.ChatMessagePart + expectedResult string }{ { - name: "empty unsupported types should return original text", - textContent: "original text", - unSupportTypeURL: []schema.ChatMessagePart{}, - expectedResult: "original text", + name: "empty unsupported types should return original text", + textContent: "original text", + unSupportTypeURL: []schema.ChatMessagePart{}, + expectedResult: "original text", }, { name: "single image URL should be appended", @@ -424,4 +424,4 @@ func TestAgentRunner_concatContentString(t *testing.T) { assert.Equal(t, tt.expectedResult, result) }) } -} \ No newline at end of file +} diff --git a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go index d0ea1d597..575bf7077 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go +++ b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go @@ -30,7 +30,8 @@ import ( "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" @@ -259,7 +260,7 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType { case *crossworkflow.ToolInterruptEvent: interruptEventType = singleagent.InterruptEventType(int64(t.EventType)) case *plugin.ToolInterruptEvent: - if t.Event == plugin.InterruptEventTypeOfToolNeedOAuth { + if t.Event == consts.InterruptEventTypeOfToolNeedOAuth { interruptEventType = singleagent.InterruptEventType_OauthPlugin } } diff --git a/backend/domain/agent/singleagent/internal/agentflow/node_tool_plugin.go b/backend/domain/agent/singleagent/internal/agentflow/node_tool_plugin.go index 1d99351b0..70741f9c2 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/node_tool_plugin.go +++ b/backend/domain/agent/singleagent/internal/agentflow/node_tool_plugin.go @@ -25,7 +25,8 @@ import ( "github.com/coze-dev/coze-studio/backend/api/model/app/bot_common" crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" @@ -46,8 +47,8 @@ func newPluginTools(ctx context.Context, conf *toolConfig) ([]tool.InvokableTool SpaceID: conf.spaceID, AgentID: conf.agentIdentity.AgentID, IsDraft: conf.agentIdentity.IsDraft, - VersionAgentTools: slices.Transform(conf.toolConf, func(a *bot_common.PluginInfo) pluginEntity.VersionAgentTool { - return pluginEntity.VersionAgentTool{ + VersionAgentTools: slices.Transform(conf.toolConf, func(a *bot_common.PluginInfo) model.VersionAgentTool { + return model.VersionAgentTool{ ToolID: a.GetApiId(), AgentVersion: ptr.Of(conf.agentIdentity.Version), } @@ -60,7 +61,7 @@ func newPluginTools(ctx context.Context, conf *toolConfig) ([]tool.InvokableTool projectInfo := &model.ProjectInfo{ ProjectID: conf.agentIdentity.AgentID, - ProjectType: model.ProjectTypeOfAgent, + ProjectType: consts.ProjectTypeOfAgent, ProjectVersion: ptr.Of(conf.agentIdentity.Version), ConnectorID: conf.agentIdentity.ConnectorID, } @@ -117,16 +118,16 @@ func (p *pluginInvokableTool) InvokableRun(ctx context.Context, argumentsInJSON ToolID: p.toolInfo.ID, ExecDraftTool: false, ArgumentsInJson: argumentsInJSON, - ExecScene: func() model.ExecuteScene { + ExecScene: func() consts.ExecuteScene { if p.isDraft { - return model.ExecSceneOfDraftAgent + return consts.ExecSceneOfDraftAgent } - return model.ExecSceneOfOnlineAgent + return consts.ExecSceneOfOnlineAgent }(), } - opts := []pluginEntity.ExecuteToolOpt{ - model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault), + opts := []model.ExecuteToolOpt{ + model.WithInvalidRespProcessStrategy(consts.InvalidResponseProcessStrategyOfReturnDefault), model.WithToolVersion(p.toolInfo.GetVersion()), model.WithProjectInfo(p.projectInfo), model.WithPluginHTTPHeader(p.conversationID), diff --git a/backend/domain/agent/singleagent/internal/agentflow/node_tool_pre_retriever.go b/backend/domain/agent/singleagent/internal/agentflow/node_tool_pre_retriever.go index 61f723f97..8fb55cdac 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/node_tool_pre_retriever.go +++ b/backend/domain/agent/singleagent/internal/agentflow/node_tool_pre_retriever.go @@ -26,9 +26,9 @@ import ( "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun" workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow" - pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/logs" ) @@ -57,20 +57,20 @@ func (pr *toolPreCallConf) toolPreRetrieve(ctx context.Context, ar *AgentRequest PluginID: item.PluginID, ToolID: item.ToolID, ArgumentsInJson: item.Arguments, - ExecScene: func(isDraft bool) model.ExecuteScene { + ExecScene: func(isDraft bool) consts.ExecuteScene { if isDraft { - return model.ExecSceneOfDraftAgent + return consts.ExecSceneOfDraftAgent } else { - return model.ExecSceneOfOnlineAgent + return consts.ExecSceneOfOnlineAgent } }(ar.Identity.IsDraft), } - opts := []pluginEntity.ExecuteToolOpt{ - model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault), + opts := []model.ExecuteToolOpt{ + model.WithInvalidRespProcessStrategy(consts.InvalidResponseProcessStrategyOfReturnDefault), model.WithProjectInfo(&model.ProjectInfo{ ProjectID: ar.Identity.AgentID, - ProjectType: model.ProjectTypeOfAgent, + ProjectType: consts.ProjectTypeOfAgent, ProjectVersion: ptr.Of(ar.Identity.Version), }), } diff --git a/backend/domain/agent/singleagent/service/publish.go b/backend/domain/agent/singleagent/service/publish.go index be4f4e216..91464a81c 100644 --- a/backend/domain/agent/singleagent/service/publish.go +++ b/backend/domain/agent/singleagent/service/publish.go @@ -125,9 +125,6 @@ func (s *singleAgentImpl) GetPublishConnectorList(ctx context.Context, agentID i c.BindType = developer_api.BindType_WebSDKBind } else if v.ID == consts.APIConnectorID { c.BindType = developer_api.BindType_ApiBind - // c.BindInfo = map[string]string{ - // "sdk_version": "1.2.0 -beta.6",//TODO (@fanlv): Where to check the version? - // } c.AuthLoginInfo = &developer_api.AuthLoginInfo{} } diff --git a/backend/domain/app/service/publish_app.go b/backend/domain/app/service/publish_app.go index c746e7cbb..71e725985 100644 --- a/backend/domain/app/service/publish_app.go +++ b/backend/domain/app/service/publish_app.go @@ -22,7 +22,7 @@ import ( resourceCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow" "github.com/coze-dev/coze-studio/backend/domain/app/entity" "github.com/coze-dev/coze-studio/backend/domain/app/repository" diff --git a/backend/domain/plugin/conf/load_plugin.go b/backend/domain/plugin/conf/load_plugin.go index 4559c5ea9..63306a263 100644 --- a/backend/domain/plugin/conf/load_plugin.go +++ b/backend/domain/plugin/conf/load_plugin.go @@ -29,21 +29,23 @@ import ( "gopkg.in/yaml.v3" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/logs" ) type pluginProductMeta struct { - PluginID int64 `yaml:"plugin_id" validate:"required"` - ProductID int64 `yaml:"product_id" validate:"required"` - Deprecated bool `yaml:"deprecated"` - Version string `yaml:"version" validate:"required"` - PluginType common.PluginType `yaml:"plugin_type" validate:"required"` - OpenapiDocFile string `yaml:"openapi_doc_file" validate:"required"` - Manifest *entity.PluginManifest `yaml:"manifest" validate:"required"` - Tools []*toolProductMeta `yaml:"tools" validate:"required"` + PluginID int64 `yaml:"plugin_id" validate:"required"` + ProductID int64 `yaml:"product_id" validate:"required"` + Deprecated bool `yaml:"deprecated"` + Version string `yaml:"version" validate:"required"` + PluginType common.PluginType `yaml:"plugin_type" validate:"required"` + OpenapiDocFile string `yaml:"openapi_doc_file" validate:"required"` + Manifest *model.PluginManifest `yaml:"manifest" validate:"required"` + Tools []*toolProductMeta `yaml:"tools" validate:"required"` } type toolProductMeta struct { @@ -195,10 +197,10 @@ func loadPluginProductMeta(ctx context.Context, basePath string) (err error) { pluginProducts[m.PluginID] = pi - apis := make(map[entity.UniqueToolAPI]*model.Openapi3Operation, len(doc.Paths)) + apis := make(map[dto.UniqueToolAPI]*model.Openapi3Operation, len(doc.Paths)) for subURL, pathItem := range doc.Paths { for method, op := range pathItem.Operations() { - api := entity.UniqueToolAPI{ + api := dto.UniqueToolAPI{ SubURL: subURL, Method: strings.ToUpper(method), } @@ -217,7 +219,7 @@ func loadPluginProductMeta(ctx context.Context, basePath string) (err error) { continue } - api := entity.UniqueToolAPI{ + api := dto.UniqueToolAPI{ SubURL: t.SubURL, Method: strings.ToUpper(t.Method), } @@ -242,7 +244,7 @@ func loadPluginProductMeta(ctx context.Context, basePath string) (err error) { Method: ptr.Of(t.Method), SubURL: ptr.Of(t.SubURL), Operation: op, - ActivatedStatus: ptr.Of(model.ActivateTool), + ActivatedStatus: ptr.Of(consts.ActivateTool), DebugStatus: ptr.Of(common.APIDebugStatus_DebugPassed), }, } diff --git a/backend/domain/plugin/dto/auth.go b/backend/domain/plugin/dto/auth.go index 611d834fe..91bd8e10c 100644 --- a/backend/domain/plugin/dto/auth.go +++ b/backend/domain/plugin/dto/auth.go @@ -18,8 +18,9 @@ package dto import ( "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) type GetOAuthStatusResponse struct { @@ -38,6 +39,52 @@ type AgentPluginOAuthStatus struct { type GetAccessTokenRequest struct { UserID string PluginID *int64 - Mode model.AuthzSubType - OAuthInfo *entity.OAuthInfo + Mode consts.AuthzSubType + OAuthInfo *OAuthInfo +} + +type PluginAuthInfo struct { + AuthzType *consts.AuthzType + Location *consts.HTTPParamLocation + Key *string + ServiceToken *string + OAuthInfo *string + AuthzSubType *consts.AuthzSubType + AuthzPayload *string +} + +type OAuthInfo struct { + OAuthMode consts.AuthzSubType + AuthorizationCode *AuthorizationCodeInfo +} + +type OAuthState struct { + ClientName OAuthProvider `json:"client_name"` + UserID string `json:"user_id"` + PluginID int64 `json:"plugin_id"` + IsDraft bool `json:"is_draft"` +} + +type AuthorizationCodeMeta struct { + UserID string + PluginID int64 + IsDraft bool +} + +type AuthorizationCodeInfo struct { + RecordID int64 + Meta *AuthorizationCodeMeta + Config *model.OAuthAuthorizationCodeConfig + AccessToken string + RefreshToken string + TokenExpiredAtMS int64 + NextTokenRefreshAtMS *int64 + LastActiveAtMS int64 +} + +func (a *AuthorizationCodeInfo) GetNextTokenRefreshAtMS() int64 { + if a == nil { + return 0 + } + return ptr.FromOrDefault(a.NextTokenRefreshAtMS, 0) } diff --git a/backend/domain/plugin/dto/plugin.go b/backend/domain/plugin/dto/plugin.go index e30f2efcc..9f6c18cc0 100644 --- a/backend/domain/plugin/dto/plugin.go +++ b/backend/domain/plugin/dto/plugin.go @@ -18,7 +18,8 @@ package dto import ( "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" ) @@ -39,7 +40,7 @@ type UpdateDraftPluginWithCodeRequest struct { UserID int64 PluginID int64 OpenapiDoc *model.Openapi3T - Manifest *entity.PluginManifest + Manifest *model.PluginManifest } type UpdateDraftPluginRequest struct { @@ -55,9 +56,26 @@ type UpdateDraftPluginRequest struct { type ListDraftPluginsRequest struct { SpaceID int64 APPID int64 - PageInfo entity.PageInfo + PageInfo PageInfo } +type PageInfo struct { + Name *string + Page int + Size int + SortBy *SortField + OrderByACS *bool +} + +type SortField string + +const ( + SortByCreatedAt SortField = "created_at" + SortByUpdatedAt SortField = "updated_at" +) + +type OAuthProvider string + type ListDraftPluginsResponse struct { Plugins []*entity.PluginInfo Total int64 @@ -67,7 +85,7 @@ type CreateDraftPluginWithCodeRequest struct { SpaceID int64 DeveloperID int64 ProjectID *int64 - Manifest *entity.PluginManifest + Manifest *model.PluginManifest OpenapiDoc *model.Openapi3T } @@ -86,7 +104,7 @@ type ListPluginProductsResponse struct { type CopyPluginRequest struct { UserID int64 PluginID int64 - CopyScene model.CopyScene + CopyScene consts.CopyScene TargetAPPID *int64 } diff --git a/backend/domain/plugin/dto/tool.go b/backend/domain/plugin/dto/tool.go index 5447f5405..bfa86df8a 100644 --- a/backend/domain/plugin/dto/tool.go +++ b/backend/domain/plugin/dto/tool.go @@ -20,8 +20,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" ) type CreateDraftToolsWithCodeRequest struct { @@ -32,7 +31,7 @@ type CreateDraftToolsWithCodeRequest struct { } type CreateDraftToolsWithCodeResponse struct { - DuplicatedTools []entity.UniqueToolAPI + DuplicatedTools []UniqueToolAPI } type UpdateDraftToolRequest struct { @@ -58,7 +57,7 @@ type ConvertToOpenapi3DocRequest struct { type ConvertToOpenapi3DocResponse struct { OpenapiDoc *model.Openapi3T - Manifest *entity.PluginManifest + Manifest *model.PluginManifest Format common.PluginDataFormat ErrMsg string } @@ -71,3 +70,7 @@ type UpdateBotDefaultParamsRequest struct { RequestBody *openapi3.RequestBodyRef Responses openapi3.Responses } +type UniqueToolAPI struct { + SubURL string + Method string +} diff --git a/backend/domain/plugin/entity/consts.go b/backend/domain/plugin/entity/consts.go deleted file mode 100644 index c7bb92158..000000000 --- a/backend/domain/plugin/entity/consts.go +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package entity - -type SortField string - -const ( - SortByCreatedAt SortField = "created_at" - SortByUpdatedAt SortField = "updated_at" -) - -type OAuthProvider string diff --git a/backend/domain/plugin/entity/oauth.go b/backend/domain/plugin/entity/oauth.go deleted file mode 100644 index 829ab42ce..000000000 --- a/backend/domain/plugin/entity/oauth.go +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package entity - -import ( - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" - "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" -) - -type AuthorizationCodeMeta struct { - UserID string - PluginID int64 - IsDraft bool -} - -type AuthorizationCodeInfo struct { - RecordID int64 - Meta *AuthorizationCodeMeta - Config *model.OAuthAuthorizationCodeConfig - AccessToken string - RefreshToken string - TokenExpiredAtMS int64 - NextTokenRefreshAtMS *int64 - LastActiveAtMS int64 -} - -func (a *AuthorizationCodeInfo) GetNextTokenRefreshAtMS() int64 { - if a == nil { - return 0 - } - return ptr.FromOrDefault(a.NextTokenRefreshAtMS, 0) -} - -type OAuthInfo struct { - OAuthMode model.AuthzSubType - AuthorizationCode *AuthorizationCodeInfo -} - -type OAuthState struct { - ClientName OAuthProvider `json:"client_name"` - UserID string `json:"user_id"` - PluginID int64 `json:"plugin_id"` - IsDraft bool `json:"is_draft"` -} diff --git a/backend/domain/plugin/entity/plugin.go b/backend/domain/plugin/entity/plugin.go index 18f49fbc0..0757b6f29 100644 --- a/backend/domain/plugin/entity/plugin.go +++ b/backend/domain/plugin/entity/plugin.go @@ -17,17 +17,8 @@ package entity import ( - "context" - "net/http" - "strconv" - - "github.com/bytedance/sonic" - "github.com/getkin/kin-openapi/openapi3" - - "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" - "github.com/coze-dev/coze-studio/backend/pkg/logs" ) type PluginInfo struct { @@ -40,13 +31,13 @@ func NewPluginInfo(info *model.PluginInfo) *PluginInfo { } } -func NewPluginInfos(infos []*model.PluginInfo) []*PluginInfo { - res := make([]*PluginInfo, 0, len(infos)) - for _, info := range infos { - res = append(res, NewPluginInfo(info)) +func (p PluginInfo) SetName(name string) { + if p.Manifest == nil || p.OpenapiDoc == nil { + return } - - return res + p.Manifest.NameForModel = name + p.Manifest.NameForHuman = name + p.OpenapiDoc.Info.Title = name } func (p PluginInfo) GetServerURL() string { @@ -57,161 +48,6 @@ func (p PluginInfo) GetRefProductID() int64 { return ptr.FromOrDefault(p.RefProductID, 0) } -func (p PluginInfo) GetVersion() string { - return ptr.FromOrDefault(p.Version, "") -} - func (p PluginInfo) GetVersionDesc() string { return ptr.FromOrDefault(p.VersionDesc, "") } - -func (p PluginInfo) GetAPPID() int64 { - return ptr.FromOrDefault(p.APPID, 0) -} - -type ToolExample struct { - RequestExample string - ResponseExample string -} - -func (p PluginInfo) GetToolExample(ctx context.Context, toolName string) *ToolExample { - if p.OpenapiDoc == nil || - p.OpenapiDoc.Components == nil || - len(p.OpenapiDoc.Components.Examples) == 0 { - return nil - } - example, ok := p.OpenapiDoc.Components.Examples[toolName] - if !ok { - return nil - } - if example.Value == nil || example.Value.Value == nil { - return nil - } - - val, ok := example.Value.Value.(map[string]any) - if !ok { - return nil - } - - reqExample, ok := val["ReqExample"] - if !ok { - return nil - } - reqExampleStr, err := sonic.MarshalString(reqExample) - if err != nil { - logs.CtxErrorf(ctx, "marshal request example failed, err=%v", err) - return nil - } - - respExample, ok := val["RespExample"] - if !ok { - return nil - } - respExampleStr, err := sonic.MarshalString(respExample) - if err != nil { - logs.CtxErrorf(ctx, "marshal response example failed, err=%v", err) - return nil - } - - return &ToolExample{ - RequestExample: reqExampleStr, - ResponseExample: respExampleStr, - } -} - -type ToolInfo = model.ToolInfo - -type AgentToolIdentity struct { - ToolID int64 - ToolName *string - AgentID int64 - VersionMs *int64 -} - -type VersionTool = model.VersionTool - -type VersionPlugin = model.VersionPlugin - -type VersionAgentTool = model.VersionAgentTool - -type ExecuteToolOpt = model.ExecuteToolOpt - -type ProjectInfo = model.ProjectInfo - -type PluginManifest = model.PluginManifest - -// TODO API.DESC 来给不同 default 值 -func NewDefaultPluginManifest() *PluginManifest { - return &model.PluginManifest{ - SchemaVersion: "v1", - API: model.APIDesc{ - Type: model.PluginTypeOfCloud, - }, - Auth: &model.AuthV2{ - Type: model.AuthzTypeOfNone, - }, - CommonParams: map[model.HTTPParamLocation][]*common.CommonParamSchema{ - model.ParamInBody: {}, - model.ParamInHeader: { - { - Name: "User-Agent", - Value: "Coze/1.0", - }, - }, - model.ParamInQuery: {}, - }, - } -} - -func NewDefaultOpenapiDoc() *model.Openapi3T { - return &model.Openapi3T{ - OpenAPI: "3.0.1", - Info: &openapi3.Info{ - Version: "v1", - }, - Paths: openapi3.Paths{}, - Servers: openapi3.Servers{}, - } -} - -type UniqueToolAPI struct { - SubURL string - Method string -} - -func DefaultOpenapi3Responses() openapi3.Responses { - return openapi3.Responses{ - strconv.Itoa(http.StatusOK): { - Value: &openapi3.Response{ - Description: ptr.Of("description is required"), - Content: openapi3.Content{ - model.MediaTypeJson: &openapi3.MediaType{ - Schema: &openapi3.SchemaRef{ - Value: &openapi3.Schema{ - Type: openapi3.TypeObject, - Properties: map[string]*openapi3.SchemaRef{}, - }, - }, - }, - }, - }, - }, - } -} - -func DefaultOpenapi3RequestBody() *openapi3.RequestBodyRef { - return &openapi3.RequestBodyRef{ - Value: &openapi3.RequestBody{ - Content: map[string]*openapi3.MediaType{ - model.MediaTypeJson: { - Schema: &openapi3.SchemaRef{ - Value: &openapi3.Schema{ - Type: openapi3.TypeObject, - Properties: map[string]*openapi3.SchemaRef{}, - }, - }, - }, - }, - }, - } -} diff --git a/backend/domain/plugin/entity/common.go b/backend/domain/plugin/entity/tool.go similarity index 83% rename from backend/domain/plugin/entity/common.go rename to backend/domain/plugin/entity/tool.go index 51491b699..271914d2c 100644 --- a/backend/domain/plugin/entity/common.go +++ b/backend/domain/plugin/entity/tool.go @@ -16,10 +16,6 @@ package entity -type PageInfo struct { - Name *string - Page int - Size int - SortBy *SortField - OrderByACS *bool -} +import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + +type ToolInfo = model.ToolInfo diff --git a/backend/domain/plugin/internal/dal/agent_tool_draft.go b/backend/domain/plugin/internal/dal/agent_tool_draft.go index bacb21a59..71feade58 100644 --- a/backend/domain/plugin/internal/dal/agent_tool_draft.go +++ b/backend/domain/plugin/internal/dal/agent_tool_draft.go @@ -206,9 +206,9 @@ func (at *AgentToolDraftDAO) batchCreateWithTX(ctx context.Context, tx *query.Qu tls := make([]*model.AgentToolDraft, 0, len(tools)) for _, tl := range tools { - id, err := at.idGen.GenID(ctx) - if err != nil { - return err + id, gErr := at.idGen.GenID(ctx) + if gErr != nil { + return gErr } m := &model.AgentToolDraft{ ID: id, diff --git a/backend/domain/plugin/internal/dal/agent_tool_version.go b/backend/domain/plugin/internal/dal/agent_tool_version.go index 75639b9f9..84787f0ba 100644 --- a/backend/domain/plugin/internal/dal/agent_tool_version.go +++ b/backend/domain/plugin/internal/dal/agent_tool_version.go @@ -24,6 +24,7 @@ import ( "gorm.io/gen" "gorm.io/gorm" + pluginModel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -94,7 +95,7 @@ func (at *AgentToolVersionDAO) GetWithToolName(ctx context.Context, agentID int6 return tool, true, nil } -func (at *AgentToolVersionDAO) Get(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) { +func (at *AgentToolVersionDAO) Get(ctx context.Context, agentID int64, vAgentTool pluginModel.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) { table := at.query.AgentToolVersion conds := []gen.Condition{ @@ -131,12 +132,12 @@ func (at *AgentToolVersionDAO) Get(ctx context.Context, agentID int64, vAgentToo return tool, true, nil } -func (at *AgentToolVersionDAO) MGet(ctx context.Context, agentID int64, vAgentTools []entity.VersionAgentTool) (tools []*entity.ToolInfo, err error) { +func (at *AgentToolVersionDAO) MGet(ctx context.Context, agentID int64, vAgentTools []pluginModel.VersionAgentTool) (tools []*entity.ToolInfo, err error) { tools = make([]*entity.ToolInfo, 0, len(vAgentTools)) table := at.query.AgentToolVersion chunks := slices.Chunks(vAgentTools, 20) - noVersion := make([]entity.VersionAgentTool, 0, len(vAgentTools)) + noVersion := make([]pluginModel.VersionAgentTool, 0, len(vAgentTools)) for _, chunk := range chunks { var q query.IAgentToolVersionDo @@ -198,9 +199,9 @@ func (at *AgentToolVersionDAO) BatchCreate(ctx context.Context, agentID int64, a return fmt.Errorf("invalid tool version") } - id, err := at.idGen.GenID(ctx) - if err != nil { - return err + id, mErr := at.idGen.GenID(ctx) + if mErr != nil { + return mErr } tls = append(tls, &model.AgentToolVersion{ diff --git a/backend/domain/plugin/internal/dal/model/agent_tool_draft.gen.go b/backend/domain/plugin/internal/dal/model/agent_tool_draft.gen.go index 77439858a..54e5c60ad 100644 --- a/backend/domain/plugin/internal/dal/model/agent_tool_draft.gen.go +++ b/backend/domain/plugin/internal/dal/model/agent_tool_draft.gen.go @@ -4,22 +4,24 @@ package model -import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" +import ( + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" +) const TableNameAgentToolDraft = "agent_tool_draft" // AgentToolDraft Draft Agent Tool type AgentToolDraft struct { - ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID - AgentID int64 `gorm:"column:agent_id;not null;comment:Agent ID" json:"agent_id"` // Agent ID - PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID - ToolID int64 `gorm:"column:tool_id;not null;comment:Tool ID" json:"tool_id"` // Tool ID - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path - Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method - ToolName string `gorm:"column:tool_name;not null;comment:Tool Name" json:"tool_name"` // Tool Name - ToolVersion string `gorm:"column:tool_version;not null;comment:Tool Version, e.g. v1.0.0" json:"tool_version"` // Tool Version, e.g. v1.0.0 - Operation *dto.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema + ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID + AgentID int64 `gorm:"column:agent_id;not null;comment:Agent ID" json:"agent_id"` // Agent ID + PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID + ToolID int64 `gorm:"column:tool_id;not null;comment:Tool ID" json:"tool_id"` // Tool ID + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path + Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method + ToolName string `gorm:"column:tool_name;not null;comment:Tool Name" json:"tool_name"` // Tool Name + ToolVersion string `gorm:"column:tool_version;not null;comment:Tool Version, e.g. v1.0.0" json:"tool_version"` // Tool Version, e.g. v1.0.0 + Operation *model.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema } // TableName AgentToolDraft's table name diff --git a/backend/domain/plugin/internal/dal/model/agent_tool_version.gen.go b/backend/domain/plugin/internal/dal/model/agent_tool_version.gen.go index 2fef4ad4a..e47a313da 100644 --- a/backend/domain/plugin/internal/dal/model/agent_tool_version.gen.go +++ b/backend/domain/plugin/internal/dal/model/agent_tool_version.gen.go @@ -4,23 +4,23 @@ package model -import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" +import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" const TableNameAgentToolVersion = "agent_tool_version" // AgentToolVersion Agent Tool Version type AgentToolVersion struct { - ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID - AgentID int64 `gorm:"column:agent_id;not null;comment:Agent ID" json:"agent_id"` // Agent ID - PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID - ToolID int64 `gorm:"column:tool_id;not null;comment:Tool ID" json:"tool_id"` // Tool ID - AgentVersion string `gorm:"column:agent_version;not null;comment:Agent Tool Version" json:"agent_version"` // Agent Tool Version - ToolName string `gorm:"column:tool_name;not null;comment:Tool Name" json:"tool_name"` // Tool Name - ToolVersion string `gorm:"column:tool_version;not null;comment:Tool Version, e.g. v1.0.0" json:"tool_version"` // Tool Version, e.g. v1.0.0 - SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path - Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method - Operation *dto.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID + AgentID int64 `gorm:"column:agent_id;not null;comment:Agent ID" json:"agent_id"` // Agent ID + PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID + ToolID int64 `gorm:"column:tool_id;not null;comment:Tool ID" json:"tool_id"` // Tool ID + AgentVersion string `gorm:"column:agent_version;not null;comment:Agent Tool Version" json:"agent_version"` // Agent Tool Version + ToolName string `gorm:"column:tool_name;not null;comment:Tool Name" json:"tool_name"` // Tool Name + ToolVersion string `gorm:"column:tool_version;not null;comment:Tool Version, e.g. v1.0.0" json:"tool_version"` // Tool Version, e.g. v1.0.0 + SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path + Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method + Operation *model.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds } // TableName AgentToolVersion's table name diff --git a/backend/domain/plugin/internal/dal/model/plugin.gen.go b/backend/domain/plugin/internal/dal/model/plugin.gen.go index d8ebfb491..c1c1ec981 100644 --- a/backend/domain/plugin/internal/dal/model/plugin.gen.go +++ b/backend/domain/plugin/internal/dal/model/plugin.gen.go @@ -4,25 +4,25 @@ package model -import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" +import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" const TableNamePlugin = "plugin" // Plugin Latest Plugin type Plugin struct { - ID int64 `gorm:"column:id;primaryKey;comment:Plugin ID" json:"id"` // Plugin ID - SpaceID int64 `gorm:"column:space_id;not null;comment:Space ID" json:"space_id"` // Space ID - DeveloperID int64 `gorm:"column:developer_id;not null;comment:Developer ID" json:"developer_id"` // Developer ID - AppID int64 `gorm:"column:app_id;not null;comment:Application ID" json:"app_id"` // Application ID - IconURI string `gorm:"column:icon_uri;not null;comment:Icon URI" json:"icon_uri"` // Icon URI - ServerURL string `gorm:"column:server_url;not null;comment:Server URL" json:"server_url"` // Server URL - PluginType int32 `gorm:"column:plugin_type;not null;comment:Plugin Type, 1:http, 6:local" json:"plugin_type"` // Plugin Type, 1:http, 6:local - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds - Version string `gorm:"column:version;not null;comment:Plugin Version, e.g. v1.0.0" json:"version"` // Plugin Version, e.g. v1.0.0 - VersionDesc string `gorm:"column:version_desc;comment:Plugin Version Description" json:"version_desc"` // Plugin Version Description - Manifest *dto.PluginManifest `gorm:"column:manifest;comment:Plugin Manifest;serializer:json" json:"manifest"` // Plugin Manifest - OpenapiDoc *dto.Openapi3T `gorm:"column:openapi_doc;comment:OpenAPI Document, only stores the root;serializer:json" json:"openapi_doc"` // OpenAPI Document, only stores the root + ID int64 `gorm:"column:id;primaryKey;comment:Plugin ID" json:"id"` // Plugin ID + SpaceID int64 `gorm:"column:space_id;not null;comment:Space ID" json:"space_id"` // Space ID + DeveloperID int64 `gorm:"column:developer_id;not null;comment:Developer ID" json:"developer_id"` // Developer ID + AppID int64 `gorm:"column:app_id;not null;comment:Application ID" json:"app_id"` // Application ID + IconURI string `gorm:"column:icon_uri;not null;comment:Icon URI" json:"icon_uri"` // Icon URI + ServerURL string `gorm:"column:server_url;not null;comment:Server URL" json:"server_url"` // Server URL + PluginType int32 `gorm:"column:plugin_type;not null;comment:Plugin Type, 1:http, 6:local" json:"plugin_type"` // Plugin Type, 1:http, 6:local + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds + Version string `gorm:"column:version;not null;comment:Plugin Version, e.g. v1.0.0" json:"version"` // Plugin Version, e.g. v1.0.0 + VersionDesc string `gorm:"column:version_desc;comment:Plugin Version Description" json:"version_desc"` // Plugin Version Description + Manifest *model.PluginManifest `gorm:"column:manifest;comment:Plugin Manifest;serializer:json" json:"manifest"` // Plugin Manifest + OpenapiDoc *model.Openapi3T `gorm:"column:openapi_doc;comment:OpenAPI Document, only stores the root;serializer:json" json:"openapi_doc"` // OpenAPI Document, only stores the root } // TableName Plugin's table name diff --git a/backend/domain/plugin/internal/dal/model/plugin_draft.gen.go b/backend/domain/plugin/internal/dal/model/plugin_draft.gen.go index 9282cf0bb..084244d90 100644 --- a/backend/domain/plugin/internal/dal/model/plugin_draft.gen.go +++ b/backend/domain/plugin/internal/dal/model/plugin_draft.gen.go @@ -5,7 +5,7 @@ package model import ( - "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "gorm.io/gorm" ) @@ -13,18 +13,18 @@ const TableNamePluginDraft = "plugin_draft" // PluginDraft Draft Plugin type PluginDraft struct { - ID int64 `gorm:"column:id;primaryKey;comment:Plugin ID" json:"id"` // Plugin ID - SpaceID int64 `gorm:"column:space_id;not null;comment:Space ID" json:"space_id"` // Space ID - DeveloperID int64 `gorm:"column:developer_id;not null;comment:Developer ID" json:"developer_id"` // Developer ID - AppID int64 `gorm:"column:app_id;not null;comment:Application ID" json:"app_id"` // Application ID - IconURI string `gorm:"column:icon_uri;not null;comment:Icon URI" json:"icon_uri"` // Icon URI - ServerURL string `gorm:"column:server_url;not null;comment:Server URL" json:"server_url"` // Server URL - PluginType int32 `gorm:"column:plugin_type;not null;comment:Plugin Type, 1:http, 6:local" json:"plugin_type"` // Plugin Type, 1:http, 6:local - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds - DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:Delete Time" json:"deleted_at"` // Delete Time - Manifest *dto.PluginManifest `gorm:"column:manifest;comment:Plugin Manifest;serializer:json" json:"manifest"` // Plugin Manifest - OpenapiDoc *dto.Openapi3T `gorm:"column:openapi_doc;comment:OpenAPI Document, only stores the root;serializer:json" json:"openapi_doc"` // OpenAPI Document, only stores the root + ID int64 `gorm:"column:id;primaryKey;comment:Plugin ID" json:"id"` // Plugin ID + SpaceID int64 `gorm:"column:space_id;not null;comment:Space ID" json:"space_id"` // Space ID + DeveloperID int64 `gorm:"column:developer_id;not null;comment:Developer ID" json:"developer_id"` // Developer ID + AppID int64 `gorm:"column:app_id;not null;comment:Application ID" json:"app_id"` // Application ID + IconURI string `gorm:"column:icon_uri;not null;comment:Icon URI" json:"icon_uri"` // Icon URI + ServerURL string `gorm:"column:server_url;not null;comment:Server URL" json:"server_url"` // Server URL + PluginType int32 `gorm:"column:plugin_type;not null;comment:Plugin Type, 1:http, 6:local" json:"plugin_type"` // Plugin Type, 1:http, 6:local + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:Delete Time" json:"deleted_at"` // Delete Time + Manifest *model.PluginManifest `gorm:"column:manifest;comment:Plugin Manifest;serializer:json" json:"manifest"` // Plugin Manifest + OpenapiDoc *model.Openapi3T `gorm:"column:openapi_doc;comment:OpenAPI Document, only stores the root;serializer:json" json:"openapi_doc"` // OpenAPI Document, only stores the root } // TableName PluginDraft's table name diff --git a/backend/domain/plugin/internal/dal/model/plugin_oauth_auth.gen.go b/backend/domain/plugin/internal/dal/model/plugin_oauth_auth.gen.go index 0310def73..6e72c831c 100644 --- a/backend/domain/plugin/internal/dal/model/plugin_oauth_auth.gen.go +++ b/backend/domain/plugin/internal/dal/model/plugin_oauth_auth.gen.go @@ -4,24 +4,24 @@ package model -import plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" +import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" const TableNamePluginOauthAuth = "plugin_oauth_auth" // PluginOauthAuth Plugin OAuth Authorization Code Info type PluginOauthAuth struct { - ID int64 `gorm:"column:id;primaryKey;comment:Primary Key" json:"id"` // Primary Key - UserID string `gorm:"column:user_id;not null;comment:User ID" json:"user_id"` // User ID - PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID - IsDraft bool `gorm:"column:is_draft;not null;comment:Is Draft Plugin" json:"is_draft"` // Is Draft Plugin - OauthConfig *plugin.OAuthAuthorizationCodeConfig `gorm:"column:oauth_config;comment:Authorization Code OAuth Config;serializer:json" json:"oauth_config"` // Authorization Code OAuth Config - AccessToken string `gorm:"column:access_token;not null;comment:Access Token" json:"access_token"` // Access Token - RefreshToken string `gorm:"column:refresh_token;not null;comment:Refresh Token" json:"refresh_token"` // Refresh Token - TokenExpiredAt int64 `gorm:"column:token_expired_at;comment:Token Expired in Milliseconds" json:"token_expired_at"` // Token Expired in Milliseconds - NextTokenRefreshAt int64 `gorm:"column:next_token_refresh_at;comment:Next Token Refresh Time in Milliseconds" json:"next_token_refresh_at"` // Next Token Refresh Time in Milliseconds - LastActiveAt int64 `gorm:"column:last_active_at;comment:Last active time in Milliseconds" json:"last_active_at"` // Last active time in Milliseconds - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds + ID int64 `gorm:"column:id;primaryKey;comment:Primary Key" json:"id"` // Primary Key + UserID string `gorm:"column:user_id;not null;comment:User ID" json:"user_id"` // User ID + PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID + IsDraft bool `gorm:"column:is_draft;not null;comment:Is Draft Plugin" json:"is_draft"` // Is Draft Plugin + OauthConfig *model.OAuthAuthorizationCodeConfig `gorm:"column:oauth_config;comment:Authorization Code OAuth Config;serializer:json" json:"oauth_config"` // Authorization Code OAuth Config + AccessToken string `gorm:"column:access_token;not null;comment:Access Token" json:"access_token"` // Access Token + RefreshToken string `gorm:"column:refresh_token;not null;comment:Refresh Token" json:"refresh_token"` // Refresh Token + TokenExpiredAt int64 `gorm:"column:token_expired_at;comment:Token Expired in Milliseconds" json:"token_expired_at"` // Token Expired in Milliseconds + NextTokenRefreshAt int64 `gorm:"column:next_token_refresh_at;comment:Next Token Refresh Time in Milliseconds" json:"next_token_refresh_at"` // Next Token Refresh Time in Milliseconds + LastActiveAt int64 `gorm:"column:last_active_at;comment:Last active time in Milliseconds" json:"last_active_at"` // Last active time in Milliseconds + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds } // TableName PluginOauthAuth's table name diff --git a/backend/domain/plugin/internal/dal/model/plugin_version.gen.go b/backend/domain/plugin/internal/dal/model/plugin_version.gen.go index 59538fef1..0a8ee7f79 100644 --- a/backend/domain/plugin/internal/dal/model/plugin_version.gen.go +++ b/backend/domain/plugin/internal/dal/model/plugin_version.gen.go @@ -5,7 +5,7 @@ package model import ( - "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "gorm.io/gorm" ) @@ -13,20 +13,20 @@ const TableNamePluginVersion = "plugin_version" // PluginVersion Plugin Version type PluginVersion struct { - ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID - SpaceID int64 `gorm:"column:space_id;not null;comment:Space ID" json:"space_id"` // Space ID - DeveloperID int64 `gorm:"column:developer_id;not null;comment:Developer ID" json:"developer_id"` // Developer ID - PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID - AppID int64 `gorm:"column:app_id;not null;comment:Application ID" json:"app_id"` // Application ID - IconURI string `gorm:"column:icon_uri;not null;comment:Icon URI" json:"icon_uri"` // Icon URI - ServerURL string `gorm:"column:server_url;not null;comment:Server URL" json:"server_url"` // Server URL - PluginType int32 `gorm:"column:plugin_type;not null;comment:Plugin Type, 1:http, 6:local" json:"plugin_type"` // Plugin Type, 1:http, 6:local - Version string `gorm:"column:version;not null;comment:Plugin Version, e.g. v1.0.0" json:"version"` // Plugin Version, e.g. v1.0.0 - VersionDesc string `gorm:"column:version_desc;comment:Plugin Version Description" json:"version_desc"` // Plugin Version Description - Manifest *dto.PluginManifest `gorm:"column:manifest;comment:Plugin Manifest;serializer:json" json:"manifest"` // Plugin Manifest - OpenapiDoc *dto.Openapi3T `gorm:"column:openapi_doc;comment:OpenAPI Document, only stores the root;serializer:json" json:"openapi_doc"` // OpenAPI Document, only stores the root - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:Delete Time" json:"deleted_at"` // Delete Time + ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID + SpaceID int64 `gorm:"column:space_id;not null;comment:Space ID" json:"space_id"` // Space ID + DeveloperID int64 `gorm:"column:developer_id;not null;comment:Developer ID" json:"developer_id"` // Developer ID + PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID + AppID int64 `gorm:"column:app_id;not null;comment:Application ID" json:"app_id"` // Application ID + IconURI string `gorm:"column:icon_uri;not null;comment:Icon URI" json:"icon_uri"` // Icon URI + ServerURL string `gorm:"column:server_url;not null;comment:Server URL" json:"server_url"` // Server URL + PluginType int32 `gorm:"column:plugin_type;not null;comment:Plugin Type, 1:http, 6:local" json:"plugin_type"` // Plugin Type, 1:http, 6:local + Version string `gorm:"column:version;not null;comment:Plugin Version, e.g. v1.0.0" json:"version"` // Plugin Version, e.g. v1.0.0 + VersionDesc string `gorm:"column:version_desc;comment:Plugin Version Description" json:"version_desc"` // Plugin Version Description + Manifest *model.PluginManifest `gorm:"column:manifest;comment:Plugin Manifest;serializer:json" json:"manifest"` // Plugin Manifest + OpenapiDoc *model.Openapi3T `gorm:"column:openapi_doc;comment:OpenAPI Document, only stores the root;serializer:json" json:"openapi_doc"` // OpenAPI Document, only stores the root + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:Delete Time" json:"deleted_at"` // Delete Time } // TableName PluginVersion's table name diff --git a/backend/domain/plugin/internal/dal/model/tool.gen.go b/backend/domain/plugin/internal/dal/model/tool.gen.go index e7bd23673..8fd09051a 100644 --- a/backend/domain/plugin/internal/dal/model/tool.gen.go +++ b/backend/domain/plugin/internal/dal/model/tool.gen.go @@ -4,21 +4,21 @@ package model -import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" +import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" const TableNameTool = "tool" // Tool Latest Tool type Tool struct { - ID int64 `gorm:"column:id;primaryKey;comment:Tool ID" json:"id"` // Tool ID - PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds - Version string `gorm:"column:version;not null;comment:Tool Version, e.g. v1.0.0" json:"version"` // Tool Version, e.g. v1.0.0 - SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path - Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method - Operation *dto.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema - ActivatedStatus int32 `gorm:"column:activated_status;not null;comment:0:activated; 1:deactivated" json:"activated_status"` // 0:activated; 1:deactivated + ID int64 `gorm:"column:id;primaryKey;comment:Tool ID" json:"id"` // Tool ID + PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds + Version string `gorm:"column:version;not null;comment:Tool Version, e.g. v1.0.0" json:"version"` // Tool Version, e.g. v1.0.0 + SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path + Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method + Operation *model.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema + ActivatedStatus int32 `gorm:"column:activated_status;not null;comment:0:activated; 1:deactivated" json:"activated_status"` // 0:activated; 1:deactivated } // TableName Tool's table name diff --git a/backend/domain/plugin/internal/dal/model/tool_draft.gen.go b/backend/domain/plugin/internal/dal/model/tool_draft.gen.go index 9f0ec2473..db28732de 100644 --- a/backend/domain/plugin/internal/dal/model/tool_draft.gen.go +++ b/backend/domain/plugin/internal/dal/model/tool_draft.gen.go @@ -4,21 +4,21 @@ package model -import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" +import "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" const TableNameToolDraft = "tool_draft" // ToolDraft Draft Tool type ToolDraft struct { - ID int64 `gorm:"column:id;primaryKey;comment:Tool ID" json:"id"` // Tool ID - PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds - SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path - Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method - Operation *dto.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema - DebugStatus int32 `gorm:"column:debug_status;not null;comment:0:not pass; 1:pass" json:"debug_status"` // 0:not pass; 1:pass - ActivatedStatus int32 `gorm:"column:activated_status;not null;comment:0:activated; 1:deactivated" json:"activated_status"` // 0:activated; 1:deactivated + ID int64 `gorm:"column:id;primaryKey;comment:Tool ID" json:"id"` // Tool ID + PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:Update Time in Milliseconds" json:"updated_at"` // Update Time in Milliseconds + SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path + Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method + Operation *model.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema + DebugStatus int32 `gorm:"column:debug_status;not null;comment:0:not pass; 1:pass" json:"debug_status"` // 0:not pass; 1:pass + ActivatedStatus int32 `gorm:"column:activated_status;not null;comment:0:activated; 1:deactivated" json:"activated_status"` // 0:activated; 1:deactivated } // TableName ToolDraft's table name diff --git a/backend/domain/plugin/internal/dal/model/tool_version.gen.go b/backend/domain/plugin/internal/dal/model/tool_version.gen.go index ce5f39cfe..0a5ed9d1d 100644 --- a/backend/domain/plugin/internal/dal/model/tool_version.gen.go +++ b/backend/domain/plugin/internal/dal/model/tool_version.gen.go @@ -5,7 +5,7 @@ package model import ( - "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "gorm.io/gorm" ) @@ -13,15 +13,15 @@ const TableNameToolVersion = "tool_version" // ToolVersion Tool Version type ToolVersion struct { - ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID - ToolID int64 `gorm:"column:tool_id;not null;comment:Tool ID" json:"tool_id"` // Tool ID - PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID - Version string `gorm:"column:version;not null;comment:Tool Version, e.g. v1.0.0" json:"version"` // Tool Version, e.g. v1.0.0 - SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path - Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method - Operation *dto.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema - CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds - DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:Delete Time" json:"deleted_at"` // Delete Time + ID int64 `gorm:"column:id;primaryKey;comment:Primary Key ID" json:"id"` // Primary Key ID + ToolID int64 `gorm:"column:tool_id;not null;comment:Tool ID" json:"tool_id"` // Tool ID + PluginID int64 `gorm:"column:plugin_id;not null;comment:Plugin ID" json:"plugin_id"` // Plugin ID + Version string `gorm:"column:version;not null;comment:Tool Version, e.g. v1.0.0" json:"version"` // Tool Version, e.g. v1.0.0 + SubURL string `gorm:"column:sub_url;not null;comment:Sub URL Path" json:"sub_url"` // Sub URL Path + Method string `gorm:"column:method;not null;comment:HTTP Request Method" json:"method"` // HTTP Request Method + Operation *model.Openapi3Operation `gorm:"column:operation;comment:Tool Openapi Operation Schema;serializer:json" json:"operation"` // Tool Openapi Operation Schema + CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:Create Time in Milliseconds" json:"created_at"` // Create Time in Milliseconds + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:Delete Time" json:"deleted_at"` // Delete Time } // TableName ToolVersion's table name diff --git a/backend/domain/plugin/internal/dal/plugin.go b/backend/domain/plugin/internal/dal/plugin.go index 097126c26..ccb98e471 100644 --- a/backend/domain/plugin/internal/dal/plugin.go +++ b/backend/domain/plugin/internal/dal/plugin.go @@ -26,7 +26,8 @@ import ( "gorm.io/gorm" plugin_develop_common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - plugindto "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginModel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -49,7 +50,7 @@ type PluginDAO struct { type pluginPO model.Plugin func (p pluginPO) ToDO() *entity.PluginInfo { - return entity.NewPluginInfo(&plugindto.PluginInfo{ + return entity.NewPluginInfo(&pluginModel.PluginInfo{ ID: p.ID, SpaceID: p.SpaceID, DeveloperID: p.DeveloperID, @@ -132,7 +133,7 @@ func (p *PluginDAO) MGet(ctx context.Context, pluginIDs []int64, opt *PluginSele return plugins, nil } -func (p *PluginDAO) List(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { +func (p *PluginDAO) List(ctx context.Context, spaceID int64, pageInfo dto.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { if pageInfo.SortBy == nil || pageInfo.OrderByACS == nil { return nil, 0, fmt.Errorf("sortBy or orderByACS is empty") } @@ -141,13 +142,13 @@ func (p *PluginDAO) List(ctx context.Context, spaceID int64, pageInfo entity.Pag table := p.query.Plugin switch *pageInfo.SortBy { - case entity.SortByCreatedAt: + case dto.SortByCreatedAt: if *pageInfo.OrderByACS { orderExpr = table.CreatedAt.Asc() } else { orderExpr = table.CreatedAt.Desc() } - case entity.SortByUpdatedAt: + case dto.SortByUpdatedAt: if *pageInfo.OrderByACS { orderExpr = table.UpdatedAt.Asc() } else { @@ -216,16 +217,16 @@ func (p *PluginDAO) UpsertWithTX(ctx context.Context, tx *query.QueryTx, pluginI updateMap[table.ServerURL.ColumnName().String()] = *pluginInfo.ServerURL } if pluginInfo.Manifest != nil { - b, err := json.Marshal(pluginInfo.Manifest) - if err != nil { - return err + b, mErr := json.Marshal(pluginInfo.Manifest) + if mErr != nil { + return mErr } updateMap[table.Manifest.ColumnName().String()] = b } if pluginInfo.OpenapiDoc != nil { - b, err := json.Marshal(pluginInfo.OpenapiDoc) - if err != nil { - return err + b, mErr := json.Marshal(pluginInfo.OpenapiDoc) + if mErr != nil { + return mErr } updateMap[table.OpenapiDoc.ColumnName().String()] = b } diff --git a/backend/domain/plugin/internal/dal/plugin_draft.go b/backend/domain/plugin/internal/dal/plugin_draft.go index 76b0ddd7a..e3973aeec 100644 --- a/backend/domain/plugin/internal/dal/plugin_draft.go +++ b/backend/domain/plugin/internal/dal/plugin_draft.go @@ -26,8 +26,9 @@ import ( "gorm.io/gorm" plugin_develop_common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - pluginModel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginModel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -130,7 +131,7 @@ func (p *PluginDraftDAO) genPluginID(ctx context.Context) (id int64, err error) break } if i == retryTimes-1 { - return 0, fmt.Errorf("id %d is confilict with product plugin id.", id) + return 0, fmt.Errorf("id %d is conflict with product plugin id", id) } } @@ -211,7 +212,7 @@ func (p *PluginDraftDAO) MGet(ctx context.Context, pluginIDs []int64, opt *Plugi return plugins, nil } -func (p *PluginDraftDAO) List(ctx context.Context, spaceID, appID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { +func (p *PluginDraftDAO) List(ctx context.Context, spaceID, appID int64, pageInfo dto.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { if pageInfo.SortBy == nil || pageInfo.OrderByACS == nil { return nil, 0, fmt.Errorf("sortBy or orderByACS is empty") } @@ -220,13 +221,13 @@ func (p *PluginDraftDAO) List(ctx context.Context, spaceID, appID int64, pageInf table := p.query.PluginDraft switch *pageInfo.SortBy { - case entity.SortByCreatedAt: + case dto.SortByCreatedAt: if *pageInfo.OrderByACS { orderExpr = table.CreatedAt.Asc() } else { orderExpr = table.CreatedAt.Desc() } - case entity.SortByUpdatedAt: + case dto.SortByUpdatedAt: if *pageInfo.OrderByACS { orderExpr = table.UpdatedAt.Asc() } else { @@ -314,7 +315,7 @@ func (p *PluginDraftDAO) CreateWithTX(ctx context.Context, tx *query.QueryTx, pl return id, nil } -func (p *PluginDraftDAO) UpdateWithTX(ctx context.Context, tx *query.QueryTx, plugin *entity.PluginInfo) (err error) { +func (p *PluginDraftDAO) UpdateWithTX(ctx context.Context, tx *query.QueryTx, plugin *entity.PluginInfo) error { table := tx.PluginDraft updateMap := map[string]any{} @@ -348,7 +349,7 @@ func (p *PluginDraftDAO) UpdateWithTX(ctx context.Context, tx *query.QueryTx, pl updateMap[table.AppID.ColumnName().String()] = *plugin.APPID } - _, err = table.WithContext(ctx). + _, err := table.WithContext(ctx). Where(table.ID.Eq(plugin.ID)). UpdateColumns(updateMap) if err != nil { diff --git a/backend/domain/plugin/internal/dal/plugin_oauth_auth.go b/backend/domain/plugin/internal/dal/plugin_oauth_auth.go index 72cc78cc7..c195c54f1 100644 --- a/backend/domain/plugin/internal/dal/plugin_oauth_auth.go +++ b/backend/domain/plugin/internal/dal/plugin_oauth_auth.go @@ -25,8 +25,8 @@ import ( "gorm.io/gorm" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" @@ -42,7 +42,7 @@ func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAut type pluginOAuthAuthPO model.PluginOauthAuth -func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo { +func (p pluginOAuthAuthPO) ToDO() *dto.AuthorizationCodeInfo { secret := os.Getenv(encrypt.OAuthTokenSecretEnv) if secret == "" { secret = encrypt.DefaultOAuthTokenSecret @@ -61,9 +61,9 @@ func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo { } } - return &entity.AuthorizationCodeInfo{ + return &dto.AuthorizationCodeInfo{ RecordID: p.ID, - Meta: &entity.AuthorizationCodeMeta{ + Meta: &dto.AuthorizationCodeMeta{ UserID: p.UserID, PluginID: p.PluginID, IsDraft: p.IsDraft, @@ -82,7 +82,7 @@ type PluginOAuthAuthDAO struct { query *query.Query } -func (p *PluginOAuthAuthDAO) Get(ctx context.Context, meta *entity.AuthorizationCodeMeta) (info *entity.AuthorizationCodeInfo, exist bool, err error) { +func (p *PluginOAuthAuthDAO) Get(ctx context.Context, meta *dto.AuthorizationCodeMeta) (info *dto.AuthorizationCodeInfo, exist bool, err error) { table := p.query.PluginOauthAuth res, err := table.WithContext(ctx). Where( @@ -103,7 +103,7 @@ func (p *PluginOAuthAuthDAO) Get(ctx context.Context, meta *entity.Authorization return info, true, nil } -func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) { +func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *dto.AuthorizationCodeInfo) (err error) { if info.Meta == nil || info.Meta.UserID == "" || info.Meta.PluginID <= 0 { return fmt.Errorf("meta info is required") } @@ -141,7 +141,8 @@ func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.Authorizat return err } - id, err := p.idGen.GenID(ctx) + var id int64 + id, err = p.idGen.GenID(ctx) if err != nil { return err } @@ -179,9 +180,9 @@ func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.Authorizat updateMap[table.LastActiveAt.ColumnName().String()] = info.LastActiveAtMS } if info.Config != nil { - b, err := json.Marshal(info.Config) - if err != nil { - return err + b, mErr := json.Marshal(info.Config) + if mErr != nil { + return mErr } updateMap[table.OauthConfig.ColumnName().String()] = b } @@ -197,7 +198,7 @@ func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.Authorizat return err } -func (p *PluginOAuthAuthDAO) UpdateLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) { +func (p *PluginOAuthAuthDAO) UpdateLastActiveAt(ctx context.Context, meta *dto.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) { po := &model.PluginOauthAuth{ LastActiveAt: lastActiveAtMs, } @@ -214,11 +215,11 @@ func (p *PluginOAuthAuthDAO) UpdateLastActiveAt(ctx context.Context, meta *entit return err } -func (p *PluginOAuthAuthDAO) GetRefreshTokenList(ctx context.Context, nextRefreshAt int64, limit int) (infos []*entity.AuthorizationCodeInfo, err error) { +func (p *PluginOAuthAuthDAO) GetRefreshTokenList(ctx context.Context, nextRefreshAt int64, limit int) (infos []*dto.AuthorizationCodeInfo, err error) { const size = 50 table := p.query.PluginOauthAuth - infos = make([]*entity.AuthorizationCodeInfo, 0, limit) + infos = make([]*dto.AuthorizationCodeInfo, 0, limit) for limit > 0 { res, err := table.WithContext(ctx). @@ -233,7 +234,7 @@ func (p *PluginOAuthAuthDAO) GetRefreshTokenList(ctx context.Context, nextRefres return nil, err } - infos = make([]*entity.AuthorizationCodeInfo, 0, len(res)) + infos = make([]*dto.AuthorizationCodeInfo, 0, len(res)) for _, v := range res { infos = append(infos, pluginOAuthAuthPO(*v).ToDO()) } @@ -265,7 +266,7 @@ func (p *PluginOAuthAuthDAO) BatchDeleteByIDs(ctx context.Context, ids []int64) return nil } -func (p *PluginOAuthAuthDAO) Delete(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) { +func (p *PluginOAuthAuthDAO) Delete(ctx context.Context, meta *dto.AuthorizationCodeMeta) (err error) { table := p.query.PluginOauthAuth _, err = table.WithContext(ctx). Where( diff --git a/backend/domain/plugin/internal/dal/plugin_version.go b/backend/domain/plugin/internal/dal/plugin_version.go index e875e080c..eebdf44d3 100644 --- a/backend/domain/plugin/internal/dal/plugin_version.go +++ b/backend/domain/plugin/internal/dal/plugin_version.go @@ -25,7 +25,8 @@ import ( "gorm.io/gorm" plugin_develop_common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginModel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -48,7 +49,7 @@ type PluginVersionDAO struct { type pluginVersionPO model.PluginVersion func (p pluginVersionPO) ToDO() *entity.PluginInfo { - return entity.NewPluginInfo(&dto.PluginInfo{ + return entity.NewPluginInfo(&pluginModel.PluginInfo{ ID: p.PluginID, SpaceID: p.SpaceID, APPID: &p.AppID, @@ -109,7 +110,7 @@ func (p *PluginVersionDAO) Get(ctx context.Context, pluginID int64, version stri return plugin, true, nil } -func (p *PluginVersionDAO) MGet(ctx context.Context, vPlugins []entity.VersionPlugin, opt *PluginSelectedOption) (plugins []*entity.PluginInfo, err error) { +func (p *PluginVersionDAO) MGet(ctx context.Context, vPlugins []pluginModel.VersionPlugin, opt *PluginSelectedOption) (plugins []*entity.PluginInfo, err error) { plugins = make([]*entity.PluginInfo, 0, len(vPlugins)) table := p.query.PluginVersion @@ -148,7 +149,7 @@ func (p *PluginVersionDAO) MGet(ctx context.Context, vPlugins []entity.VersionPl return plugins, nil } -func (p *PluginVersionDAO) ListVersions(ctx context.Context, pluginID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { +func (p *PluginVersionDAO) ListVersions(ctx context.Context, pluginID int64, pageInfo dto.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { table := p.query.PluginVersion offset := (pageInfo.Page - 1) * pageInfo.Size diff --git a/backend/domain/plugin/internal/dal/tool.go b/backend/domain/plugin/internal/dal/tool.go index cbe4354d6..2db2ca8a3 100644 --- a/backend/domain/plugin/internal/dal/tool.go +++ b/backend/domain/plugin/internal/dal/tool.go @@ -24,7 +24,7 @@ import ( "gorm.io/gen/field" "gorm.io/gorm" - plugindto "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -57,7 +57,7 @@ func (t toolPO) ToDO() *entity.ToolInfo { SubURL: &t.SubURL, Method: ptr.Of(t.Method), Operation: t.Operation, - ActivatedStatus: ptr.Of(plugindto.ActivatedStatus(t.ActivatedStatus)), + ActivatedStatus: ptr.Of(consts.ActivatedStatus(t.ActivatedStatus)), } } diff --git a/backend/domain/plugin/internal/dal/tool_draft.go b/backend/domain/plugin/internal/dal/tool_draft.go index 492563492..680e26d02 100644 --- a/backend/domain/plugin/internal/dal/tool_draft.go +++ b/backend/domain/plugin/internal/dal/tool_draft.go @@ -26,8 +26,9 @@ import ( "gorm.io/gorm" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - plugindto "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -60,7 +61,7 @@ func (t toolDraftPO) ToDO() *entity.ToolInfo { Method: ptr.Of(t.Method), Operation: t.Operation, DebugStatus: ptr.Of(common.APIDebugStatus(t.DebugStatus)), - ActivatedStatus: ptr.Of(plugindto.ActivatedStatus(t.ActivatedStatus)), + ActivatedStatus: ptr.Of(consts.ActivatedStatus(t.ActivatedStatus)), } } @@ -174,7 +175,7 @@ func (t *ToolDraftDAO) MGet(ctx context.Context, toolIDs []int64, opt *ToolSelec return tools, nil } -func (t *ToolDraftDAO) GetWithAPI(ctx context.Context, pluginID int64, api entity.UniqueToolAPI) (tool *entity.ToolInfo, exist bool, err error) { +func (t *ToolDraftDAO) GetWithAPI(ctx context.Context, pluginID int64, api dto.UniqueToolAPI) (tool *entity.ToolInfo, exist bool, err error) { table := t.query.ToolDraft tl, err := table.WithContext(ctx). Where( @@ -195,8 +196,8 @@ func (t *ToolDraftDAO) GetWithAPI(ctx context.Context, pluginID int64, api entit return tool, true, nil } -func (t *ToolDraftDAO) MGetWithAPIs(ctx context.Context, pluginID int64, apis []entity.UniqueToolAPI, opt *ToolSelectedOption) (tools map[entity.UniqueToolAPI]*entity.ToolInfo, err error) { - tools = make(map[entity.UniqueToolAPI]*entity.ToolInfo, len(apis)) +func (t *ToolDraftDAO) MGetWithAPIs(ctx context.Context, pluginID int64, apis []dto.UniqueToolAPI, opt *ToolSelectedOption) (tools map[dto.UniqueToolAPI]*entity.ToolInfo, err error) { + tools = make(map[dto.UniqueToolAPI]*entity.ToolInfo, len(apis)) table := t.query.ToolDraft chunks := slices.Chunks(apis, 10) @@ -226,7 +227,7 @@ func (t *ToolDraftDAO) MGetWithAPIs(ctx context.Context, pluginID int64, apis [] return nil, err } for _, tl := range tls { - api := entity.UniqueToolAPI{ + api := dto.UniqueToolAPI{ SubURL: tl.SubURL, Method: tl.Method, } @@ -299,12 +300,12 @@ func (t *ToolDraftDAO) Update(ctx context.Context, tool *entity.ToolInfo) (err e return nil } -func (t *ToolDraftDAO) List(ctx context.Context, pluginID int64, pageInfo entity.PageInfo) (tools []*entity.ToolInfo, total int64, err error) { +func (t *ToolDraftDAO) List(ctx context.Context, pluginID int64, pageInfo dto.PageInfo) (tools []*entity.ToolInfo, total int64, err error) { if pageInfo.SortBy == nil || pageInfo.OrderByACS == nil { return nil, 0, fmt.Errorf("sortBy or orderByACS is empty") } - if *pageInfo.SortBy != entity.SortByCreatedAt { + if *pageInfo.SortBy != dto.SortByCreatedAt { return nil, 0, fmt.Errorf("invalid sortBy '%v'", *pageInfo.SortBy) } @@ -359,9 +360,9 @@ func (t *ToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryTx, tls := make([]*model.ToolDraft, 0, len(tools)) for _, tool := range tools { - id, err := t.genToolID(ctx) - if err != nil { - return nil, err + id, mErr := t.genToolID(ctx) + if mErr != nil { + return nil, mErr } toolIDs = append(toolIDs, id) diff --git a/backend/domain/plugin/internal/dal/tool_version.go b/backend/domain/plugin/internal/dal/tool_version.go index badb0d340..27b43af21 100644 --- a/backend/domain/plugin/internal/dal/tool_version.go +++ b/backend/domain/plugin/internal/dal/tool_version.go @@ -22,6 +22,7 @@ import ( "gorm.io/gorm" + pluginModel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -56,7 +57,7 @@ func (t toolVersionPO) ToDO() *entity.ToolInfo { } } -func (t *ToolVersionDAO) Get(ctx context.Context, vTool entity.VersionTool) (tool *entity.ToolInfo, exist bool, err error) { +func (t *ToolVersionDAO) Get(ctx context.Context, vTool pluginModel.VersionTool) (tool *entity.ToolInfo, exist bool, err error) { table := t.query.ToolVersion if vTool.Version == "" { @@ -78,7 +79,7 @@ func (t *ToolVersionDAO) Get(ctx context.Context, vTool entity.VersionTool) (too return tool, true, nil } -func (t *ToolVersionDAO) MGet(ctx context.Context, vTools []entity.VersionTool) (tools []*entity.ToolInfo, err error) { +func (t *ToolVersionDAO) MGet(ctx context.Context, vTools []pluginModel.VersionTool) (tools []*entity.ToolInfo, err error) { tools = make([]*entity.ToolInfo, 0, len(vTools)) table := t.query.ToolVersion @@ -124,9 +125,9 @@ func (t *ToolVersionDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryT return fmt.Errorf("invalid tool version") } - id, err := t.idGen.GenID(ctx) - if err != nil { - return err + id, mErr := t.idGen.GenID(ctx) + if mErr != nil { + return mErr } tls = append(tls, &model.ToolVersion{ diff --git a/backend/domain/plugin/internal/encoder/req_encode.go b/backend/domain/plugin/internal/encoder/req_encode.go index 079acef59..8b4a55fda 100644 --- a/backend/domain/plugin/internal/encoder/req_encode.go +++ b/backend/domain/plugin/internal/encoder/req_encode.go @@ -27,16 +27,16 @@ import ( "github.com/shopspring/decimal" "gopkg.in/yaml.v3" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" ) func EncodeBodyWithContentType(contentType string, body map[string]any) ([]byte, error) { switch contentType { - case plugin.MediaTypeJson, plugin.MediaTypeProblemJson: + case consts.MediaTypeJson, consts.MediaTypeProblemJson: return jsonBodyEncoder(body) - case plugin.MediaTypeFormURLEncoded: + case consts.MediaTypeFormURLEncoded: return urlencodedBodyEncoder(body) - case plugin.MediaTypeYaml, plugin.MediaTypeXYaml: + case consts.MediaTypeYaml, consts.MediaTypeXYaml: return yamlBodyEncoder(body) default: return nil, fmt.Errorf("[EncodeBodyWithContentType] unsupported contentType=%s", contentType) diff --git a/backend/domain/plugin/internal/openapi/convert_protocol.go b/backend/domain/plugin/internal/openapi/convert_protocol.go index 90e2feabf..0fb03ee7f 100644 --- a/backend/domain/plugin/internal/openapi/convert_protocol.go +++ b/backend/domain/plugin/internal/openapi/convert_protocol.go @@ -34,15 +34,15 @@ import ( postman "github.com/rbretecher/go-postman-collection" "gopkg.in/yaml.v3" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/errno" ) -func CurlToOpenapi3Doc(ctx context.Context, rawCURL string) (doc *model.Openapi3T, mf *entity.PluginManifest, err error) { +func CurlToOpenapi3Doc(ctx context.Context, rawCURL string) (doc *model.Openapi3T, mf *model.PluginManifest, err error) { curlReq, err := parseCURL(ctx, rawCURL) if err != nil { return nil, nil, err @@ -55,7 +55,7 @@ func CurlToOpenapi3Doc(ctx context.Context, rawCURL string) (doc *model.Openapi3 return nil, nil, err } - doc = entity.NewDefaultOpenapiDoc() + doc = model.NewDefaultOpenapiDoc() doc.Servers = append(doc.Servers, &openapi3.Server{ URL: urlSchema.Scheme + "://" + urlSchema.Host, }) @@ -68,7 +68,7 @@ func CurlToOpenapi3Doc(ctx context.Context, rawCURL string) (doc *model.Openapi3 OperationID: operationID, Summary: curlReq.Method + ":" + urlSchema.Path, Parameters: openapi3.Parameters{}, - Responses: entity.DefaultOpenapi3Responses(), + Responses: model.DefaultOpenapi3Responses(), } if len(curlReq.Header) > 0 { @@ -99,7 +99,7 @@ func CurlToOpenapi3Doc(ctx context.Context, rawCURL string) (doc *model.Openapi3 fillNecessaryInfoForOpenapi3Doc(doc) - mf = entity.NewDefaultPluginManifest() + mf = model.NewDefaultPluginManifest() fillManifestWithOpenapiDoc(mf, doc) return doc, mf, nil @@ -278,26 +278,26 @@ func (c *curlRequest) parseCURLData(curIdx int, lines []string) (nxtIdx int, err if len(ct) > 0 { mediaType = ct[0] } else { - mediaType = model.MediaTypeFormURLEncoded + mediaType = consts.MediaTypeFormURLEncoded c.Header["Content-Type"] = append(c.Header["Content-Type"], mediaType) } data := lines[curIdx+1] switch mediaType { - case model.MediaTypeFormURLEncoded: + case consts.MediaTypeFormURLEncoded: err = c.decodeFormUrlEncodedDataBody(data) if err != nil { return 0, err } - case model.MediaTypeJson, model.MediaTypeProblemJson: + case consts.MediaTypeJson, consts.MediaTypeProblemJson: err = c.decodeJsonDataBody(data) if err != nil { return 0, err } - case model.MediaTypeYaml, model.MediaTypeXYaml: + case consts.MediaTypeYaml, consts.MediaTypeXYaml: err = c.decodeYamlDataBody(data) if err != nil { return 0, err @@ -574,7 +574,7 @@ func curlBodyToOpenAPI(ctx context.Context, mediaType string, bodyValue any, op } if mediaType == "" { - mediaType = model.MediaTypeJson + mediaType = consts.MediaTypeJson } op.RequestBody = &openapi3.RequestBodyRef{ @@ -590,7 +590,7 @@ func curlBodyToOpenAPI(ctx context.Context, mediaType string, bodyValue any, op return op, nil } -func PostmanToOpenapi3Doc(ctx context.Context, rawPostman string) (doc *model.Openapi3T, mf *entity.PluginManifest, err error) { +func PostmanToOpenapi3Doc(ctx context.Context, rawPostman string) (doc *model.Openapi3T, mf *model.PluginManifest, err error) { collection, err := postman.ParseCollection(bytes.NewBufferString(rawPostman)) if err != nil { return nil, nil, errorx.New(errno.ErrPluginConvertProtocolFailed, @@ -616,7 +616,7 @@ func PostmanToOpenapi3Doc(ctx context.Context, rawPostman string) (doc *model.Op "invalid request url '%s', url must start with 'http://' or 'https://'", rawURL)) } - doc = entity.NewDefaultOpenapiDoc() + doc = model.NewDefaultOpenapiDoc() doc.Servers = append(doc.Servers, &openapi3.Server{ URL: urlSchema.Scheme + "://" + urlSchema.Host, }) @@ -636,7 +636,7 @@ func PostmanToOpenapi3Doc(ctx context.Context, rawPostman string) (doc *model.Op OperationID: item.Name, Summary: item.Description, Parameters: openapi3.Parameters{}, - Responses: entity.DefaultOpenapi3Responses(), + Responses: model.DefaultOpenapi3Responses(), } var mediaType string @@ -683,7 +683,7 @@ func PostmanToOpenapi3Doc(ctx context.Context, rawPostman string) (doc *model.Op fillNecessaryInfoForOpenapi3Doc(doc) - mf = entity.NewDefaultPluginManifest() + mf = model.NewDefaultPluginManifest() fillManifestWithOpenapiDoc(mf, doc) return doc, mf, nil @@ -772,27 +772,27 @@ func postmanBodyToOpenAPI(ctx context.Context, mediaType string, body *postman.B } if mediaType == "" { - mediaType = model.MediaTypeJson + mediaType = consts.MediaTypeJson if body.Mode == "urlencoded" { - mediaType = model.MediaTypeFormURLEncoded + mediaType = consts.MediaTypeFormURLEncoded } } var valMap map[string]any switch mediaType { - case model.MediaTypeJson, model.MediaTypeProblemJson: + case consts.MediaTypeJson, consts.MediaTypeProblemJson: valMap, err = decodeRequestJsonBody(body.Raw) if err != nil { return nil, err } - case model.MediaTypeYaml, model.MediaTypeXYaml: + case consts.MediaTypeYaml, consts.MediaTypeXYaml: valMap, err = decodeRequestYamlBody(body.Raw) if err != nil { return nil, err } - case model.MediaTypeFormURLEncoded: + case consts.MediaTypeFormURLEncoded: valMap, err = decodePostmanRequestFormURLEncodedBody(body.URLEncoded) if err != nil { return nil, err @@ -891,7 +891,7 @@ func decodePostmanRequestFormURLEncodedBody(rawBody any) (body map[string]any, e return body, nil } -func SwaggerToOpenapi3Doc(_ context.Context, rawSwagger string) (doc *model.Openapi3T, mf *entity.PluginManifest, err error) { +func SwaggerToOpenapi3Doc(_ context.Context, rawSwagger string) (doc *model.Openapi3T, mf *model.PluginManifest, err error) { doc2 := &openapi2.T{} if err = json.Unmarshal([]byte(rawSwagger), doc2); err != nil { err = yaml.Unmarshal([]byte(rawSwagger), doc2) @@ -909,13 +909,13 @@ func SwaggerToOpenapi3Doc(_ context.Context, rawSwagger string) (doc *model.Open doc = ptr.Of(model.Openapi3T(*doc3)) fillNecessaryInfoForOpenapi3Doc(doc) - mf = entity.NewDefaultPluginManifest() + mf = model.NewDefaultPluginManifest() fillManifestWithOpenapiDoc(mf, doc) return doc, mf, nil } -func ToOpenapi3Doc(_ context.Context, rawOpenAPI string) (doc *model.Openapi3T, mf *entity.PluginManifest, err error) { +func ToOpenapi3Doc(_ context.Context, rawOpenAPI string) (doc *model.Openapi3T, mf *model.PluginManifest, err error) { loader := openapi3.NewLoader() doc3, err := loader.LoadFromData([]byte(rawOpenAPI)) if err != nil { @@ -926,13 +926,13 @@ func ToOpenapi3Doc(_ context.Context, rawOpenAPI string) (doc *model.Openapi3T, doc = ptr.Of(model.Openapi3T(*doc3)) fillNecessaryInfoForOpenapi3Doc(doc) - mf = entity.NewDefaultPluginManifest() + mf = model.NewDefaultPluginManifest() fillManifestWithOpenapiDoc(mf, doc) return doc, mf, nil } -func fillManifestWithOpenapiDoc(mf *entity.PluginManifest, doc *model.Openapi3T) { +func fillManifestWithOpenapiDoc(mf *model.PluginManifest, doc *model.Openapi3T) { if doc.Info == nil { return } @@ -981,14 +981,14 @@ func fillNecessaryInfoForOpenapi3Doc(doc *model.Openapi3T) { } if op.Responses != nil { - defaultResp := entity.DefaultOpenapi3Responses() + defaultResp := model.DefaultOpenapi3Responses() respRef := op.Responses[strconv.Itoa(http.StatusOK)] if respRef == nil || respRef.Value == nil || respRef.Value.Content == nil { op.Responses = defaultResp respRef = op.Responses[strconv.Itoa(http.StatusOK)] } - if respRef.Value.Content[model.MediaTypeJson] == nil { - respRef.Value.Content[model.MediaTypeJson] = defaultResp[strconv.Itoa(http.StatusOK)].Value.Content[model.MediaTypeJson] + if respRef.Value.Content[consts.MediaTypeJson] == nil { + respRef.Value.Content[consts.MediaTypeJson] = defaultResp[strconv.Itoa(http.StatusOK)].Value.Content[consts.MediaTypeJson] } } } diff --git a/backend/domain/plugin/repository/mock/mock_oauth_repository.go b/backend/domain/plugin/repository/mock/mock_oauth_repository.go index 3366039b2..acc957668 100644 --- a/backend/domain/plugin/repository/mock/mock_oauth_repository.go +++ b/backend/domain/plugin/repository/mock/mock_oauth_repository.go @@ -13,7 +13,7 @@ import ( context "context" reflect "reflect" - entity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + dto "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" gomock "go.uber.org/mock/gomock" ) @@ -21,6 +21,7 @@ import ( type MockOAuthRepository struct { ctrl *gomock.Controller recorder *MockOAuthRepositoryMockRecorder + isgomock struct{} } // MockOAuthRepositoryMockRecorder is the mock recorder for MockOAuthRepository. @@ -55,7 +56,7 @@ func (mr *MockOAuthRepositoryMockRecorder) BatchDeleteAuthorizationCodeByIDs(ctx } // DeleteAuthorizationCode mocks base method. -func (m *MockOAuthRepository) DeleteAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) error { +func (m *MockOAuthRepository) DeleteAuthorizationCode(ctx context.Context, meta *dto.AuthorizationCodeMeta) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeleteAuthorizationCode", ctx, meta) ret0, _ := ret[0].(error) @@ -97,10 +98,10 @@ func (mr *MockOAuthRepositoryMockRecorder) DeleteInactiveAuthorizationCodeTokens } // GetAuthorizationCode mocks base method. -func (m *MockOAuthRepository) GetAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) (*entity.AuthorizationCodeInfo, bool, error) { +func (m *MockOAuthRepository) GetAuthorizationCode(ctx context.Context, meta *dto.AuthorizationCodeMeta) (*dto.AuthorizationCodeInfo, bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAuthorizationCode", ctx, meta) - ret0, _ := ret[0].(*entity.AuthorizationCodeInfo) + ret0, _ := ret[0].(*dto.AuthorizationCodeInfo) ret1, _ := ret[1].(bool) ret2, _ := ret[2].(error) return ret0, ret1, ret2 @@ -113,10 +114,10 @@ func (mr *MockOAuthRepositoryMockRecorder) GetAuthorizationCode(ctx, meta any) * } // GetAuthorizationCodeRefreshTokens mocks base method. -func (m *MockOAuthRepository) GetAuthorizationCodeRefreshTokens(ctx context.Context, nextRefreshAt int64, limit int) ([]*entity.AuthorizationCodeInfo, error) { +func (m *MockOAuthRepository) GetAuthorizationCodeRefreshTokens(ctx context.Context, nextRefreshAt int64, limit int) ([]*dto.AuthorizationCodeInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAuthorizationCodeRefreshTokens", ctx, nextRefreshAt, limit) - ret0, _ := ret[0].([]*entity.AuthorizationCodeInfo) + ret0, _ := ret[0].([]*dto.AuthorizationCodeInfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -128,7 +129,7 @@ func (mr *MockOAuthRepositoryMockRecorder) GetAuthorizationCodeRefreshTokens(ctx } // UpdateAuthorizationCodeLastActiveAt mocks base method. -func (m *MockOAuthRepository) UpdateAuthorizationCodeLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) error { +func (m *MockOAuthRepository) UpdateAuthorizationCodeLastActiveAt(ctx context.Context, meta *dto.AuthorizationCodeMeta, lastActiveAtMs int64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateAuthorizationCodeLastActiveAt", ctx, meta, lastActiveAtMs) ret0, _ := ret[0].(error) @@ -142,7 +143,7 @@ func (mr *MockOAuthRepositoryMockRecorder) UpdateAuthorizationCodeLastActiveAt(c } // UpsertAuthorizationCode mocks base method. -func (m *MockOAuthRepository) UpsertAuthorizationCode(ctx context.Context, info *entity.AuthorizationCodeInfo) error { +func (m *MockOAuthRepository) UpsertAuthorizationCode(ctx context.Context, info *dto.AuthorizationCodeInfo) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpsertAuthorizationCode", ctx, info) ret0, _ := ret[0].(error) diff --git a/backend/domain/plugin/repository/oauth_impl.go b/backend/domain/plugin/repository/oauth_impl.go index 92fe6578c..41b785e53 100644 --- a/backend/domain/plugin/repository/oauth_impl.go +++ b/backend/domain/plugin/repository/oauth_impl.go @@ -21,7 +21,7 @@ import ( "gorm.io/gorm" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" ) @@ -41,15 +41,15 @@ type oauthRepoImpl struct { oauthAuth *dal.PluginOAuthAuthDAO } -func (o *oauthRepoImpl) GetAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) (info *entity.AuthorizationCodeInfo, exist bool, err error) { +func (o *oauthRepoImpl) GetAuthorizationCode(ctx context.Context, meta *dto.AuthorizationCodeMeta) (info *dto.AuthorizationCodeInfo, exist bool, err error) { return o.oauthAuth.Get(ctx, meta) } -func (o *oauthRepoImpl) UpsertAuthorizationCode(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) { +func (o *oauthRepoImpl) UpsertAuthorizationCode(ctx context.Context, info *dto.AuthorizationCodeInfo) (err error) { return o.oauthAuth.Upsert(ctx, info) } -func (o *oauthRepoImpl) UpdateAuthorizationCodeLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) { +func (o *oauthRepoImpl) UpdateAuthorizationCodeLastActiveAt(ctx context.Context, meta *dto.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) { return o.oauthAuth.UpdateLastActiveAt(ctx, meta, lastActiveAtMs) } @@ -57,11 +57,11 @@ func (o *oauthRepoImpl) BatchDeleteAuthorizationCodeByIDs(ctx context.Context, i return o.oauthAuth.BatchDeleteByIDs(ctx, ids) } -func (o *oauthRepoImpl) DeleteAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) { +func (o *oauthRepoImpl) DeleteAuthorizationCode(ctx context.Context, meta *dto.AuthorizationCodeMeta) (err error) { return o.oauthAuth.Delete(ctx, meta) } -func (o *oauthRepoImpl) GetAuthorizationCodeRefreshTokens(ctx context.Context, nextRefreshAt int64, limit int) (infos []*entity.AuthorizationCodeInfo, err error) { +func (o *oauthRepoImpl) GetAuthorizationCodeRefreshTokens(ctx context.Context, nextRefreshAt int64, limit int) (infos []*dto.AuthorizationCodeInfo, err error) { return o.oauthAuth.GetRefreshTokenList(ctx, nextRefreshAt, limit) } diff --git a/backend/domain/plugin/repository/oauth_repository.go b/backend/domain/plugin/repository/oauth_repository.go index 5495adfa5..9140c61fa 100644 --- a/backend/domain/plugin/repository/oauth_repository.go +++ b/backend/domain/plugin/repository/oauth_repository.go @@ -19,17 +19,17 @@ package repository import ( "context" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" ) //go:generate mockgen -source=./oauth_repository.go -package=mock_plugin_oauth -destination=./mock/mock_oauth_repository.go type OAuthRepository interface { - GetAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) (info *entity.AuthorizationCodeInfo, exist bool, err error) - UpsertAuthorizationCode(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) - UpdateAuthorizationCodeLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) + GetAuthorizationCode(ctx context.Context, meta *dto.AuthorizationCodeMeta) (info *dto.AuthorizationCodeInfo, exist bool, err error) + UpsertAuthorizationCode(ctx context.Context, info *dto.AuthorizationCodeInfo) (err error) + UpdateAuthorizationCodeLastActiveAt(ctx context.Context, meta *dto.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) BatchDeleteAuthorizationCodeByIDs(ctx context.Context, ids []int64) (err error) - DeleteAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) - GetAuthorizationCodeRefreshTokens(ctx context.Context, nextRefreshAt int64, limit int) (infos []*entity.AuthorizationCodeInfo, err error) + DeleteAuthorizationCode(ctx context.Context, meta *dto.AuthorizationCodeMeta) (err error) + GetAuthorizationCodeRefreshTokens(ctx context.Context, nextRefreshAt int64, limit int) (infos []*dto.AuthorizationCodeInfo, err error) DeleteExpiredAuthorizationCodeTokens(ctx context.Context, expireAt int64, limit int) (err error) DeleteInactiveAuthorizationCodeTokens(ctx context.Context, lastActiveAt int64, limit int) (err error) } diff --git a/backend/domain/plugin/repository/plugin_impl.go b/backend/domain/plugin/repository/plugin_impl.go index 1ff36b2d5..6b47cd6b0 100644 --- a/backend/domain/plugin/repository/plugin_impl.go +++ b/backend/domain/plugin/repository/plugin_impl.go @@ -26,8 +26,11 @@ import ( "gorm.io/gorm" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -223,11 +226,11 @@ func (p *pluginRepoImpl) MGetOnlinePlugins(ctx context.Context, pluginIDs []int6 return plugins, nil } -func (p *pluginRepoImpl) ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { +func (p *pluginRepoImpl) ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo dto.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { return p.pluginDAO.List(ctx, spaceID, pageInfo) } -func (p *pluginRepoImpl) GetVersionPlugin(ctx context.Context, vPlugin entity.VersionPlugin) (plugin *entity.PluginInfo, exist bool, err error) { +func (p *pluginRepoImpl) GetVersionPlugin(ctx context.Context, vPlugin model.VersionPlugin) (plugin *entity.PluginInfo, exist bool, err error) { pi, exist := pluginConf.GetPluginProduct(vPlugin.PluginID) if exist { return entity.NewPluginInfo(pi.Info), true, nil @@ -236,7 +239,7 @@ func (p *pluginRepoImpl) GetVersionPlugin(ctx context.Context, vPlugin entity.Ve return p.pluginVersionDAO.Get(ctx, vPlugin.PluginID, vPlugin.Version) } -func (p *pluginRepoImpl) MGetVersionPlugins(ctx context.Context, vPlugins []entity.VersionPlugin, opts ...PluginSelectedOptions) (plugins []*entity.PluginInfo, err error) { +func (p *pluginRepoImpl) MGetVersionPlugins(ctx context.Context, vPlugins []model.VersionPlugin, opts ...PluginSelectedOptions) (plugins []*entity.PluginInfo, err error) { pluginIDs := make([]int64, 0, len(vPlugins)) for _, vPlugin := range vPlugins { pluginIDs = append(pluginIDs, vPlugin.PluginID) @@ -250,7 +253,7 @@ func (p *pluginRepoImpl) MGetVersionPlugins(ctx context.Context, vPlugins []enti return plugin.Info.ID, true }) - vCustomPlugins := make([]entity.VersionPlugin, 0, len(pluginIDs)) + vCustomPlugins := make([]model.VersionPlugin, 0, len(pluginIDs)) for _, v := range vPlugins { _, ok := productPluginIDs[v.PluginID] if ok { @@ -285,7 +288,7 @@ func (p *pluginRepoImpl) PublishPlugin(ctx context.Context, draftPlugin *entity. activatedTools := make([]*entity.ToolInfo, 0, len(draftTools)) for _, tool := range draftTools { - if tool.GetActivatedStatus() == model.DeactivateTool { + if tool.IsDeactivated() { continue } @@ -354,14 +357,14 @@ func (p *pluginRepoImpl) PublishPlugins(ctx context.Context, draftPlugins []*ent pluginTools := make(map[int64][]*entity.ToolInfo, len(draftPlugins)) for _, draftPlugin := range draftPlugins { - draftTools, err := p.toolDraftDAO.GetAll(ctx, draftPlugin.ID, nil) - if err != nil { - return err + draftTools, mErr := p.toolDraftDAO.GetAll(ctx, draftPlugin.ID, nil) + if mErr != nil { + return mErr } activatedTools := make([]*entity.ToolInfo, 0, len(draftTools)) for _, tool := range draftTools { - if tool.GetActivatedStatus() == model.DeactivateTool { + if tool.IsDeactivated() { continue } @@ -619,7 +622,7 @@ func (p *pluginRepoImpl) CreateDraftPluginWithCode(ctx context.Context, req *Cre doc := req.OpenapiDoc mf := req.Manifest - pluginType, _ := model.ToThriftPluginType(mf.API.Type) + pluginType, _ := convert.ToThriftPluginType(mf.API.Type) pl := entity.NewPluginInfo(&model.PluginInfo{ PluginType: pluginType, @@ -636,7 +639,7 @@ func (p *pluginRepoImpl) CreateDraftPluginWithCode(ctx context.Context, req *Cre for subURL, pathItem := range doc.Paths { for method, op := range pathItem.Operations() { tools = append(tools, &entity.ToolInfo{ - ActivatedStatus: ptr.Of(model.ActivateTool), + ActivatedStatus: ptr.Of(consts.ActivateTool), DebugStatus: ptr.Of(common.APIDebugStatus_DebugWaiting), SubURL: ptr.Of(subURL), Method: ptr.Of(method), @@ -746,8 +749,8 @@ func (p *pluginRepoImpl) CopyPlugin(ctx context.Context, req *CopyPluginRequest) // publish plugin filteredTools := make([]*entity.ToolInfo, 0, len(tools)) for _, tool := range tools { - if tool.GetActivatedStatus() == model.DeactivateTool || - tool.GetDebugStatus() == common.APIDebugStatus_DebugWaiting { + if tool.IsDeactivated() || + tool.IsDebugging() { continue } filteredTools = append(filteredTools, tool) diff --git a/backend/domain/plugin/repository/plugin_repository.go b/backend/domain/plugin/repository/plugin_repository.go index 0b27cff37..be88c12d9 100644 --- a/backend/domain/plugin/repository/plugin_repository.go +++ b/backend/domain/plugin/repository/plugin_repository.go @@ -19,7 +19,8 @@ package repository import ( "context" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" ) @@ -35,14 +36,14 @@ type PluginRepository interface { UpdateDraftPluginWithCode(ctx context.Context, req *UpdatePluginDraftWithCode) (err error) DeleteDraftPlugin(ctx context.Context, pluginID int64) (err error) DeleteAPPAllPlugins(ctx context.Context, appID int64) (pluginIDs []int64, err error) - UpdateDebugExample(ctx context.Context, pluginID int64, openapiDoc *plugin.Openapi3T) (err error) + UpdateDebugExample(ctx context.Context, pluginID int64, openapiDoc *model.Openapi3T) (err error) GetOnlinePlugin(ctx context.Context, pluginID int64, opts ...PluginSelectedOptions) (plugin *entity.PluginInfo, exist bool, err error) MGetOnlinePlugins(ctx context.Context, pluginIDs []int64, opts ...PluginSelectedOptions) (plugins []*entity.PluginInfo, err error) - ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) + ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo dto.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) - GetVersionPlugin(ctx context.Context, vPlugin entity.VersionPlugin) (plugin *entity.PluginInfo, exist bool, err error) - MGetVersionPlugins(ctx context.Context, vPlugins []entity.VersionPlugin, opts ...PluginSelectedOptions) (plugin []*entity.PluginInfo, err error) + GetVersionPlugin(ctx context.Context, vPlugin model.VersionPlugin) (plugin *entity.PluginInfo, exist bool, err error) + MGetVersionPlugins(ctx context.Context, vPlugins []model.VersionPlugin, opts ...PluginSelectedOptions) (plugin []*entity.PluginInfo, err error) PublishPlugin(ctx context.Context, draftPlugin *entity.PluginInfo) (err error) PublishPlugins(ctx context.Context, draftPlugins []*entity.PluginInfo) (err error) @@ -53,8 +54,8 @@ type PluginRepository interface { type UpdatePluginDraftWithCode struct { PluginID int64 - OpenapiDoc *plugin.Openapi3T - Manifest *entity.PluginManifest + OpenapiDoc *model.Openapi3T + Manifest *model.PluginManifest UpdatedTools []*entity.ToolInfo NewDraftTools []*entity.ToolInfo @@ -64,8 +65,8 @@ type CreateDraftPluginWithCodeRequest struct { SpaceID int64 DeveloperID int64 ProjectID *int64 - Manifest *entity.PluginManifest - OpenapiDoc *plugin.Openapi3T + Manifest *model.PluginManifest + OpenapiDoc *model.Openapi3T } type CreateDraftPluginWithCodeResponse struct { @@ -76,7 +77,7 @@ type CreateDraftPluginWithCodeResponse struct { type ListDraftPluginsRequest struct { SpaceID int64 APPID int64 - PageInfo entity.PageInfo + PageInfo dto.PageInfo } type ListDraftPluginsResponse struct { diff --git a/backend/domain/plugin/repository/tool_impl.go b/backend/domain/plugin/repository/tool_impl.go index b5baad464..f149c3a08 100644 --- a/backend/domain/plugin/repository/tool_impl.go +++ b/backend/domain/plugin/repository/tool_impl.go @@ -23,7 +23,9 @@ import ( "gorm.io/gorm" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" @@ -66,8 +68,8 @@ func (t *toolRepoImpl) CreateDraftTool(ctx context.Context, tool *entity.ToolInf } func (t *toolRepoImpl) UpsertDraftTools(ctx context.Context, pluginID int64, tools []*entity.ToolInfo) (err error) { - apis := slices.Transform(tools, func(tool *entity.ToolInfo) entity.UniqueToolAPI { - return entity.UniqueToolAPI{ + apis := slices.Transform(tools, func(tool *entity.ToolInfo) dto.UniqueToolAPI { + return dto.UniqueToolAPI{ SubURL: tool.GetSubURL(), Method: tool.GetMethod(), } @@ -102,7 +104,7 @@ func (t *toolRepoImpl) UpsertDraftTools(ctx context.Context, pluginID int64, too updatedTools := make([]*entity.ToolInfo, 0, len(existTools)) for _, tool := range tools { - existTool, exist := existTools[entity.UniqueToolAPI{ + existTool, exist := existTools[dto.UniqueToolAPI{ SubURL: tool.GetSubURL(), Method: tool.GetMethod(), }] @@ -182,15 +184,15 @@ func (t *toolRepoImpl) GetPluginAllOnlineTools(ctx context.Context, pluginID int return tools, nil } -func (t *toolRepoImpl) ListPluginDraftTools(ctx context.Context, pluginID int64, pageInfo entity.PageInfo) (tools []*entity.ToolInfo, total int64, err error) { +func (t *toolRepoImpl) ListPluginDraftTools(ctx context.Context, pluginID int64, pageInfo dto.PageInfo) (tools []*entity.ToolInfo, total int64, err error) { return t.toolDraftDAO.List(ctx, pluginID, pageInfo) } -func (t *toolRepoImpl) GetDraftToolWithAPI(ctx context.Context, pluginID int64, api entity.UniqueToolAPI) (tool *entity.ToolInfo, exist bool, err error) { +func (t *toolRepoImpl) GetDraftToolWithAPI(ctx context.Context, pluginID int64, api dto.UniqueToolAPI) (tool *entity.ToolInfo, exist bool, err error) { return t.toolDraftDAO.GetWithAPI(ctx, pluginID, api) } -func (t *toolRepoImpl) MGetDraftToolWithAPI(ctx context.Context, pluginID int64, apis []entity.UniqueToolAPI, opts ...ToolSelectedOptions) (tools map[entity.UniqueToolAPI]*entity.ToolInfo, err error) { +func (t *toolRepoImpl) MGetDraftToolWithAPI(ctx context.Context, pluginID int64, apis []dto.UniqueToolAPI, opts ...ToolSelectedOptions) (tools map[dto.UniqueToolAPI]*entity.ToolInfo, err error) { var opt *dal.ToolSelectedOption if len(opts) > 0 { opt = &dal.ToolSelectedOption{} @@ -251,7 +253,7 @@ func (t *toolRepoImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64, opt return tools, nil } -func (t *toolRepoImpl) GetVersionTool(ctx context.Context, vTool entity.VersionTool) (tool *entity.ToolInfo, exist bool, err error) { +func (t *toolRepoImpl) GetVersionTool(ctx context.Context, vTool model.VersionTool) (tool *entity.ToolInfo, exist bool, err error) { ti, exist := pluginConf.GetToolProduct(vTool.ToolID) if exist { return ti.Info, true, nil @@ -260,7 +262,7 @@ func (t *toolRepoImpl) GetVersionTool(ctx context.Context, vTool entity.VersionT return t.toolVersionDAO.Get(ctx, vTool) } -func (t *toolRepoImpl) MGetVersionTools(ctx context.Context, versionTools []entity.VersionTool) (tools []*entity.ToolInfo, err error) { +func (t *toolRepoImpl) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*entity.ToolInfo, err error) { tools, err = t.toolVersionDAO.MGet(ctx, versionTools) if err != nil { return nil, err @@ -411,7 +413,7 @@ func (t *toolRepoImpl) GetSpaceAllDraftAgentTools(ctx context.Context, agentID i return t.agentToolDraftDAO.GetAll(ctx, agentID, nil) } -func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) { +func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool model.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) { return t.agentToolVersionDAO.Get(ctx, agentID, vAgentTool) } @@ -419,7 +421,7 @@ func (t *toolRepoImpl) GetVersionAgentToolWithToolName(ctx context.Context, req return t.agentToolVersionDAO.GetWithToolName(ctx, req.AgentID, req.ToolName, req.AgentVersion) } -func (t *toolRepoImpl) MGetVersionAgentTool(ctx context.Context, agentID int64, vAgentTools []entity.VersionAgentTool) (tools []*entity.ToolInfo, err error) { +func (t *toolRepoImpl) MGetVersionAgentTool(ctx context.Context, agentID int64, vAgentTools []model.VersionAgentTool) (tools []*entity.ToolInfo, err error) { return t.agentToolVersionDAO.MGet(ctx, agentID, vAgentTools) } diff --git a/backend/domain/plugin/repository/tool_repository.go b/backend/domain/plugin/repository/tool_repository.go index f6cab51d8..71d69665a 100644 --- a/backend/domain/plugin/repository/tool_repository.go +++ b/backend/domain/plugin/repository/tool_repository.go @@ -19,6 +19,8 @@ package repository import ( "context" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" ) @@ -29,15 +31,15 @@ type ToolRepository interface { GetDraftTool(ctx context.Context, toolID int64) (tool *entity.ToolInfo, exist bool, err error) MGetDraftTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) - GetDraftToolWithAPI(ctx context.Context, pluginID int64, api entity.UniqueToolAPI) (tool *entity.ToolInfo, exist bool, err error) - MGetDraftToolWithAPI(ctx context.Context, pluginID int64, apis []entity.UniqueToolAPI, opts ...ToolSelectedOptions) (tools map[entity.UniqueToolAPI]*entity.ToolInfo, err error) + GetDraftToolWithAPI(ctx context.Context, pluginID int64, api dto.UniqueToolAPI) (tool *entity.ToolInfo, exist bool, err error) + MGetDraftToolWithAPI(ctx context.Context, pluginID int64, apis []dto.UniqueToolAPI, opts ...ToolSelectedOptions) (tools map[dto.UniqueToolAPI]*entity.ToolInfo, err error) DeleteDraftTool(ctx context.Context, toolID int64) (err error) GetOnlineTool(ctx context.Context, toolID int64) (tool *entity.ToolInfo, exist bool, err error) MGetOnlineTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) - GetVersionTool(ctx context.Context, vTool entity.VersionTool) (tool *entity.ToolInfo, exist bool, err error) - MGetVersionTools(ctx context.Context, vTools []entity.VersionTool) (tools []*entity.ToolInfo, err error) + GetVersionTool(ctx context.Context, vTool model.VersionTool) (tool *entity.ToolInfo, exist bool, err error) + MGetVersionTools(ctx context.Context, vTools []model.VersionTool) (tools []*entity.ToolInfo, err error) BindDraftAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) @@ -48,14 +50,14 @@ type ToolRepository interface { GetSpaceAllDraftAgentTools(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) GetAgentPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) - GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) + GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool model.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) GetVersionAgentToolWithToolName(ctx context.Context, req *GetVersionAgentToolWithToolNameRequest) (tool *entity.ToolInfo, exist bool, err error) - MGetVersionAgentTool(ctx context.Context, agentID int64, vAgentTools []entity.VersionAgentTool) (tools []*entity.ToolInfo, err error) + MGetVersionAgentTool(ctx context.Context, agentID int64, vAgentTools []model.VersionAgentTool) (tools []*entity.ToolInfo, err error) BatchCreateVersionAgentTools(ctx context.Context, agentID int64, agentVersion string, tools []*entity.ToolInfo) (err error) GetPluginAllDraftTools(ctx context.Context, pluginID int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) GetPluginAllOnlineTools(ctx context.Context, pluginID int64) (tools []*entity.ToolInfo, err error) - ListPluginDraftTools(ctx context.Context, pluginID int64, pageInfo entity.PageInfo) (tools []*entity.ToolInfo, total int64, err error) + ListPluginDraftTools(ctx context.Context, pluginID int64, pageInfo dto.PageInfo) (tools []*entity.ToolInfo, total int64, err error) } type GetVersionAgentToolWithToolNameRequest struct { diff --git a/backend/domain/plugin/service/agent_tool.go b/backend/domain/plugin/service/agent_tool.go index 45ded75f4..0ddf95fbf 100644 --- a/backend/domain/plugin/service/agent_tool.go +++ b/backend/domain/plugin/service/agent_tool.go @@ -24,7 +24,8 @@ import ( "github.com/getkin/kin-openapi/openapi3" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/repository" @@ -101,7 +102,7 @@ func (p *pluginServiceImpl) MGetAgentTools(ctx context.Context, req *model.MGetA return tools, nil } - vTools := make([]entity.VersionAgentTool, 0, len(existMap)) + vTools := make([]model.VersionAgentTool, 0, len(existMap)) for _, v := range req.VersionAgentTools { if existMap[v.ToolID] { vTools = append(vTools, v) @@ -162,17 +163,17 @@ func (p *pluginServiceImpl) UpdateBotDefaultParams(ctx context.Context, req *dto } if req.RequestBody != nil { - mType, ok := req.RequestBody.Value.Content[model.MediaTypeJson] + mType, ok := req.RequestBody.Value.Content[consts.MediaTypeJson] if !ok { - return fmt.Errorf("the '%s' media type is not defined in request body", model.MediaTypeJson) + return fmt.Errorf("the '%s' media type is not defined in request body", consts.MediaTypeJson) } if op.RequestBody == nil || op.RequestBody.Value == nil { - op.RequestBody = entity.DefaultOpenapi3RequestBody() + op.RequestBody = model.DefaultOpenapi3RequestBody() } if op.RequestBody.Value.Content == nil { op.RequestBody.Value.Content = map[string]*openapi3.MediaType{} } - op.RequestBody.Value.Content[model.MediaTypeJson] = mType + op.RequestBody.Value.Content[consts.MediaTypeJson] = mType } if req.Responses != nil { @@ -180,18 +181,18 @@ func (p *pluginServiceImpl) UpdateBotDefaultParams(ctx context.Context, req *dto if !ok { return fmt.Errorf("the '%d' status code is not defined in responses", http.StatusOK) } - newMIMEType, ok := newRespRef.Value.Content[model.MediaTypeJson] + newMIMEType, ok := newRespRef.Value.Content[consts.MediaTypeJson] if !ok { - return fmt.Errorf("the '%s' media type is not defined in responses", model.MediaTypeJson) + return fmt.Errorf("the '%s' media type is not defined in responses", consts.MediaTypeJson) } if op.Responses == nil { - op.Responses = entity.DefaultOpenapi3Responses() + op.Responses = model.DefaultOpenapi3Responses() } oldRespRef, ok := op.Responses[strconv.Itoa(http.StatusOK)] if !ok { - oldRespRef = entity.DefaultOpenapi3Responses()[strconv.Itoa(http.StatusOK)] + oldRespRef = model.DefaultOpenapi3Responses()[strconv.Itoa(http.StatusOK)] op.Responses[strconv.Itoa(http.StatusOK)] = oldRespRef } @@ -199,7 +200,7 @@ func (p *pluginServiceImpl) UpdateBotDefaultParams(ctx context.Context, req *dto oldRespRef.Value.Content = map[string]*openapi3.MediaType{} } - oldRespRef.Value.Content[model.MediaTypeJson] = newMIMEType + oldRespRef.Value.Content[consts.MediaTypeJson] = newMIMEType } updatedTool := &entity.ToolInfo{ @@ -272,12 +273,12 @@ func mergeParameters(ctx context.Context, dest, src openapi3.Parameters) (openap dv.Extensions = make(map[string]any) } - if v, ok := sv.Extensions[model.APISchemaExtendLocalDisable]; ok { - dv.Extensions[model.APISchemaExtendLocalDisable] = v + if v, ok := sv.Extensions[consts.APISchemaExtendLocalDisable]; ok { + dv.Extensions[consts.APISchemaExtendLocalDisable] = v } - if v, ok := sv.Extensions[model.APISchemaExtendVariableRef]; ok { - dv.Extensions[model.APISchemaExtendVariableRef] = v + if v, ok := sv.Extensions[consts.APISchemaExtendVariableRef]; ok { + dv.Extensions[consts.APISchemaExtendVariableRef] = v } dv.Default = sv.Default @@ -312,11 +313,11 @@ func mergeMediaSchema(ctx context.Context, dest, src *openapi3.Schema) (*openapi if dest.Extensions == nil { dest.Extensions = map[string]any{} } - if v, ok := src.Extensions[model.APISchemaExtendLocalDisable]; ok { - dest.Extensions[model.APISchemaExtendLocalDisable] = v + if v, ok := src.Extensions[consts.APISchemaExtendLocalDisable]; ok { + dest.Extensions[consts.APISchemaExtendLocalDisable] = v } - if v, ok := src.Extensions[model.APISchemaExtendVariableRef]; ok { - dest.Extensions[model.APISchemaExtendVariableRef] = v + if v, ok := src.Extensions[consts.APISchemaExtendVariableRef]; ok { + dest.Extensions[consts.APISchemaExtendVariableRef] = v } dest.Default = src.Default diff --git a/backend/domain/plugin/service/exec_tool.go b/backend/domain/plugin/service/exec_tool.go index bdbbc6280..d516b5262 100644 --- a/backend/domain/plugin/service/exec_tool.go +++ b/backend/domain/plugin/service/exec_tool.go @@ -29,7 +29,9 @@ import ( "github.com/getkin/kin-openapi/openapi3" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/service/tool" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" @@ -39,7 +41,7 @@ import ( "github.com/coze-dev/coze-studio/backend/types/errno" ) -func (p *pluginServiceImpl) ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, opts ...entity.ExecuteToolOpt) (resp *model.ExecuteToolResponse, err error) { +func (p *pluginServiceImpl) ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, opts ...model.ExecuteToolOpt) (resp *model.ExecuteToolResponse, err error) { opt := &model.ExecuteToolOption{} for _, fn := range opts { fn(opt) @@ -61,7 +63,7 @@ func (p *pluginServiceImpl) ExecuteTool(ctx context.Context, req *model.ExecuteT return nil, errorx.Wrapf(err, "execute tool failed") } - if req.ExecScene == model.ExecSceneOfToolDebug { + if req.ExecScene == consts.ExecSceneOfToolDebug { err = p.toolRepo.UpdateDraftTool(ctx, &entity.ToolInfo{ ID: req.ToolID, DebugStatus: ptr.Of(common.APIDebugStatus_DebugPassed), @@ -92,30 +94,30 @@ func (p *pluginServiceImpl) ExecuteTool(ctx context.Context, req *model.ExecuteT func (p *pluginServiceImpl) acquireAccessTokenIfNeed(ctx context.Context, req *model.ExecuteToolRequest, authInfo *model.AuthV2, schema *model.Openapi3Operation) (accessToken string, authURL string, err error) { - if authInfo.Type == model.AuthzTypeOfNone { + if authInfo.Type == consts.AuthzTypeOfNone { return "", "", nil } - authMode := model.ToolAuthModeOfRequired - if tmp, ok := schema.Extensions[model.APISchemaExtendAuthMode].(string); ok { - authMode = model.ToolAuthMode(tmp) + authMode := consts.ToolAuthModeOfRequired + if tmp, ok := schema.Extensions[consts.APISchemaExtendAuthMode].(string); ok { + authMode = consts.ToolAuthMode(tmp) } - if authMode == model.ToolAuthModeOfDisabled { + if authMode == consts.ToolAuthModeOfDisabled { return "", "", nil } - if authInfo.SubType == model.AuthzSubTypeOfOAuthAuthorizationCode { - authorizationCode := &entity.AuthorizationCodeInfo{ - Meta: &entity.AuthorizationCodeMeta{ + if authInfo.SubType == consts.AuthzSubTypeOfOAuthAuthorizationCode { + authorizationCode := &dto.AuthorizationCodeInfo{ + Meta: &dto.AuthorizationCodeMeta{ UserID: req.UserID, PluginID: req.PluginID, - IsDraft: req.ExecScene == model.ExecSceneOfToolDebug, + IsDraft: req.ExecScene == consts.ExecSceneOfToolDebug, }, Config: authInfo.AuthOfOAuthAuthorizationCode, } - accessToken, err = p.GetAccessToken(ctx, &entity.OAuthInfo{ + accessToken, err = p.GetAccessToken(ctx, &dto.OAuthInfo{ OAuthMode: authInfo.SubType, AuthorizationCode: authorizationCode, }) @@ -142,13 +144,13 @@ func (p *pluginServiceImpl) buildToolExecutor(ctx context.Context, req *model.Ex tl *entity.ToolInfo ) switch req.ExecScene { - case model.ExecSceneOfOnlineAgent: + case consts.ExecSceneOfOnlineAgent: pl, tl, err = p.getOnlineAgentPluginInfo(ctx, req, opt) - case model.ExecSceneOfDraftAgent: + case consts.ExecSceneOfDraftAgent: pl, tl, err = p.getDraftAgentPluginInfo(ctx, req, opt) - case model.ExecSceneOfToolDebug: + case consts.ExecSceneOfToolDebug: pl, tl, err = p.getToolDebugPluginInfo(ctx, req, opt) - case model.ExecSceneOfWorkflow: + case consts.ExecSceneOfWorkflow: pl, tl, err = p.getWorkflowPluginInfo(ctx, req, opt) default: return nil, fmt.Errorf("invalid execute scene '%s'", req.ExecScene) @@ -207,7 +209,7 @@ func (p *pluginServiceImpl) getDraftAgentPluginInfo(ctx context.Context, req *mo return nil, nil, errorx.New(errno.ErrPluginRecordNotFound) } } else { - onlinePlugin, exist, err = p.pluginRepo.GetVersionPlugin(ctx, entity.VersionPlugin{ + onlinePlugin, exist, err = p.pluginRepo.GetVersionPlugin(ctx, model.VersionPlugin{ PluginID: req.PluginID, Version: execOpt.ToolVersion, }) @@ -242,7 +244,7 @@ func (p *pluginServiceImpl) getOnlineAgentPluginInfo(ctx context.Context, req *m return nil, nil, errorx.New(errno.ErrPluginRecordNotFound) } - agentTool, exist, err := p.toolRepo.GetVersionAgentTool(ctx, execOpt.ProjectInfo.ProjectID, entity.VersionAgentTool{ + agentTool, exist, err := p.toolRepo.GetVersionAgentTool(ctx, execOpt.ProjectInfo.ProjectID, model.VersionAgentTool{ ToolID: req.ToolID, AgentVersion: execOpt.ProjectInfo.ProjectVersion, }) @@ -263,7 +265,7 @@ func (p *pluginServiceImpl) getOnlineAgentPluginInfo(ctx context.Context, req *m return nil, nil, errorx.New(errno.ErrPluginRecordNotFound) } } else { - onlinePlugin, exist, err = p.pluginRepo.GetVersionPlugin(ctx, entity.VersionPlugin{ + onlinePlugin, exist, err = p.pluginRepo.GetVersionPlugin(ctx, model.VersionPlugin{ PluginID: req.PluginID, Version: execOpt.ToolVersion, }) @@ -324,7 +326,7 @@ func (p *pluginServiceImpl) getWorkflowPluginInfo(ctx context.Context, req *mode } } else { - pl, exist, err = p.pluginRepo.GetVersionPlugin(ctx, entity.VersionPlugin{ + pl, exist, err = p.pluginRepo.GetVersionPlugin(ctx, model.VersionPlugin{ PluginID: req.PluginID, Version: execOpt.ToolVersion, }) @@ -335,7 +337,7 @@ func (p *pluginServiceImpl) getWorkflowPluginInfo(ctx context.Context, req *mode return nil, nil, errorx.New(errno.ErrPluginRecordNotFound) } - tl, exist, err = p.toolRepo.GetVersionTool(ctx, entity.VersionTool{ + tl, exist, err = p.toolRepo.GetVersionTool(ctx, model.VersionTool{ ToolID: req.ToolID, Version: execOpt.ToolVersion, }) @@ -355,9 +357,9 @@ func (p *pluginServiceImpl) getToolDebugPluginInfo(ctx context.Context, req *mod _ *model.ExecuteToolOption) (pl *entity.PluginInfo, tl *entity.ToolInfo, err error) { if req.ExecDraftTool { - tl, exist, err := p.toolRepo.GetDraftTool(ctx, req.ToolID) - if err != nil { - return nil, nil, errorx.Wrapf(err, "GetDraftTool failed, toolID=%d", req.ToolID) + tool, exist, mErr := p.toolRepo.GetDraftTool(ctx, req.ToolID) + if mErr != nil { + return nil, nil, errorx.Wrapf(mErr, "GetDraftTool failed, toolID=%d", req.ToolID) } if !exist { return nil, nil, errorx.New(errno.ErrPluginRecordNotFound) @@ -371,11 +373,11 @@ func (p *pluginServiceImpl) getToolDebugPluginInfo(ctx context.Context, req *mod return nil, nil, errorx.New(errno.ErrPluginRecordNotFound) } - if tl.GetActivatedStatus() != model.ActivateTool { - return nil, nil, errorx.New(errno.ErrPluginDeactivatedTool, errorx.KV(errno.PluginMsgKey, tl.GetName())) + if tool.GetActivatedStatus() != consts.ActivateTool { + return nil, nil, errorx.New(errno.ErrPluginDeactivatedTool, errorx.KV(errno.PluginMsgKey, tool.GetName())) } - return pl, tl, nil + return pl, tool, nil } tl, exist, err := p.toolRepo.GetOnlineTool(ctx, req.ToolID) @@ -405,14 +407,14 @@ func (p *pluginServiceImpl) genToolResponseSchema(ctx context.Context, rawResp s "the type of response only supports json map")) } - resp := entity.DefaultOpenapi3Responses() + resp := model.DefaultOpenapi3Responses() respSchema := parseResponseToBodySchemaRef(ctx, valMap) if respSchema == nil { return resp, nil } - resp[strconv.Itoa(http.StatusOK)].Value.Content[model.MediaTypeJson].Schema = respSchema + resp[strconv.Itoa(http.StatusOK)].Value.Content[consts.MediaTypeJson].Schema = respSchema return resp, nil } @@ -495,26 +497,26 @@ type ExecuteResponse struct { } type toolExecutor struct { - execScene model.ExecuteScene + execScene consts.ExecuteScene userID string conversationID int64 plugin *entity.PluginInfo tool *entity.ToolInfo - projectInfo *entity.ProjectInfo - invalidRespProcessStrategy model.InvalidResponseProcessStrategy + projectInfo *model.ProjectInfo + invalidRespProcessStrategy consts.InvalidResponseProcessStrategy oss storage.Storage } func newToolInvocation(t *toolExecutor) tool.Invocation { switch t.plugin.Manifest.API.Type { - case model.PluginTypeOfCloud: + case consts.PluginTypeOfCloud: return tool.NewHttpCallImpl(t.conversationID) - case model.PluginTypeOfMCP: + case consts.PluginTypeOfMCP: return tool.NewMcpCallImpl() - case model.PluginTypeOfCustom: + case consts.PluginTypeOfCustom: return tool.NewCustomCallImpl() default: // default to http call return tool.NewHttpCallImpl(t.conversationID) @@ -549,7 +551,7 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson, accessToken return nil, err } - if t.execScene != model.ExecSceneOfToolDebug { // debug + if t.execScene != consts.ExecSceneOfToolDebug { // debug // only assemble file uri to url in debug scene err = invocation.AssembleFileURIToURL(ctx, t.oss) if err != nil { @@ -598,9 +600,9 @@ func (t *toolExecutor) processResponse(ctx context.Context, rawResp string) (tri if !ok { return "", fmt.Errorf("the '%d' status code is not defined in responses", http.StatusOK) } - mType, ok := resp.Value.Content[model.MediaTypeJson] // only support application/json + mType, ok := resp.Value.Content[consts.MediaTypeJson] // only support application/json if !ok { - return "", fmt.Errorf("the '%s' media type is not defined in response", model.MediaTypeJson) + return "", fmt.Errorf("the '%s' media type is not defined in response", consts.MediaTypeJson) } decoder := sonic.ConfigDefault.NewDecoder(bytes.NewBufferString(rawResp)) @@ -619,19 +621,19 @@ func (t *toolExecutor) processResponse(ctx context.Context, rawResp string) (tri var trimmedRespMap map[string]any switch t.invalidRespProcessStrategy { - case model.InvalidResponseProcessStrategyOfReturnRaw: + case consts.InvalidResponseProcessStrategyOfReturnRaw: trimmedRespMap, err = t.processWithInvalidRespProcessStrategyOfReturnRaw(ctx, respMap, schemaVal) if err != nil { return "", err } - case model.InvalidResponseProcessStrategyOfReturnDefault: + case consts.InvalidResponseProcessStrategyOfReturnDefault: trimmedRespMap, err = t.processWithInvalidRespProcessStrategyOfReturnDefault(ctx, respMap, schemaVal) if err != nil { return "", err } - case model.InvalidResponseProcessStrategyOfReturnErr: + case consts.InvalidResponseProcessStrategyOfReturnErr: trimmedRespMap, err = t.processWithInvalidRespProcessStrategyOfReturnErr(ctx, respMap, schemaVal) if err != nil { return "", err @@ -892,10 +894,10 @@ func (t *toolExecutor) disabledParam(schemaVal *openapi3.Schema) bool { return false } globalDisable, localDisable := false, false - if v, ok := schemaVal.Extensions[model.APISchemaExtendLocalDisable]; ok { + if v, ok := schemaVal.Extensions[consts.APISchemaExtendLocalDisable]; ok { localDisable = v.(bool) } - if v, ok := schemaVal.Extensions[model.APISchemaExtendGlobalDisable]; ok { + if v, ok := schemaVal.Extensions[consts.APISchemaExtendGlobalDisable]; ok { globalDisable = v.(bool) } return globalDisable || localDisable diff --git a/backend/domain/plugin/service/exec_tool_test.go b/backend/domain/plugin/service/exec_tool_test.go index ef7fb7dfa..6b8a20149 100644 --- a/backend/domain/plugin/service/exec_tool_test.go +++ b/backend/domain/plugin/service/exec_tool_test.go @@ -18,6 +18,7 @@ package service import ( "bytes" + "context" "encoding/json" "errors" "testing" @@ -27,13 +28,13 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/assert" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" "github.com/coze-dev/coze-studio/backend/pkg/errorx" ) func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnDefault(t *testing.T) { executor := &toolExecutor{ - invalidRespProcessStrategy: model.InvalidResponseProcessStrategyOfReturnDefault, + invalidRespProcessStrategy: consts.InvalidResponseProcessStrategyOfReturnDefault, } paramVal := ` @@ -108,7 +109,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnDefault(t *tes }, } - processedParamValMap, err := executor.processWithInvalidRespProcessStrategyOfReturnDefault(nil, paramValMap, paramSchema) + processedParamValMap, err := executor.processWithInvalidRespProcessStrategyOfReturnDefault(context.Background(), paramValMap, paramSchema) assert.NoError(t, err) assert.NotNil(t, processedParamValMap) assert.Equal(t, int64(1), processedParamValMap["a1"]) @@ -120,7 +121,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnDefault(t *tes func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnErr(t *testing.T) { executor := &toolExecutor{ - invalidRespProcessStrategy: model.InvalidResponseProcessStrategyOfReturnErr, + invalidRespProcessStrategy: consts.InvalidResponseProcessStrategyOfReturnErr, } mockey.PatchConvey("integer", t, func() { @@ -146,7 +147,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnErr(t *testing }, }, } - _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(nil, paramValMap, paramSchema) + _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(context.Background(), paramValMap, paramSchema) var customErr errorx.StatusError assert.True(t, errors.As(err, &customErr)) assert.Equal(t, "execute tool failed : expected 'a' to be of type 'string', but got 'json.Number'", customErr.Msg()) @@ -161,7 +162,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnErr(t *testing }, }, } - _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(nil, paramValMap, paramSchema) + _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(context.Background(), paramValMap, paramSchema) assert.NoError(t, err) }) @@ -188,7 +189,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnErr(t *testing }, }, } - _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(nil, paramValMap, paramSchema) + _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(context.Background(), paramValMap, paramSchema) var customErr errorx.StatusError assert.True(t, errors.As(err, &customErr)) assert.Equal(t, "execute tool failed : expected 'a' to be of type 'integer', but got 'string'", customErr.Msg()) @@ -203,7 +204,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnErr(t *testing }, }, } - _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(nil, paramValMap, paramSchema) + _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(context.Background(), paramValMap, paramSchema) assert.NoError(t, err) }) @@ -230,7 +231,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnErr(t *testing }, }, } - _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(nil, paramValMap, paramSchema) + _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(context.Background(), paramValMap, paramSchema) var customErr errorx.StatusError assert.True(t, errors.As(err, &customErr)) assert.Equal(t, "execute tool failed : expected 'a' to be of type 'string', but got 'bool'", customErr.Msg()) @@ -245,7 +246,7 @@ func TestToolExecutorProcessWithInvalidRespProcessStrategyOfReturnErr(t *testing }, }, } - _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(nil, paramValMap, paramSchema) + _, err = executor.processWithInvalidRespProcessStrategyOfReturnErr(context.Background(), paramValMap, paramSchema) assert.NoError(t, err) }) } diff --git a/backend/domain/plugin/dto/plugin_author.go b/backend/domain/plugin/service/plugin_auth.go similarity index 70% rename from backend/domain/plugin/dto/plugin_author.go rename to backend/domain/plugin/service/plugin_auth.go index 0c4d969f3..45c30c5eb 100644 --- a/backend/domain/plugin/dto/plugin_author.go +++ b/backend/domain/plugin/service/plugin_auth.go @@ -13,50 +13,51 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package dto +package service import ( "fmt" "strings" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/types/errno" ) -type PluginAuthInfo struct { - AuthzType *model.AuthzType - Location *model.HTTPParamLocation - Key *string - ServiceToken *string - OAuthInfo *string - AuthzSubType *model.AuthzSubType - AuthzPayload *string +type pluginAuthConverter struct { + PluginAuthInfo *dto.PluginAuthInfo } -// TODO(@fanlv): change to DTO + Service -func (p PluginAuthInfo) ToAuthV2() (*model.AuthV2, error) { +func newPluginAuthConverter(pluginAuthInfo *dto.PluginAuthInfo) *pluginAuthConverter { + return &pluginAuthConverter{ + PluginAuthInfo: pluginAuthInfo, + } +} + +func (s *pluginAuthConverter) ToAuthV2() (*model.AuthV2, error) { + p := s.PluginAuthInfo if p.AuthzType == nil { return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "auth type is required")) } switch *p.AuthzType { - case model.AuthzTypeOfNone: + case consts.AuthzTypeOfNone: return &model.AuthV2{ - Type: model.AuthzTypeOfNone, + Type: consts.AuthzTypeOfNone, }, nil - case model.AuthzTypeOfOAuth: - m, err := p.authOfOAuthToAuthV2() + case consts.AuthzTypeOfOAuth: + m, err := s.authOfOAuthToAuthV2() if err != nil { return nil, err } return m, nil - case model.AuthzTypeOfService: - m, err := p.authOfServiceToAuthV2() + case consts.AuthzTypeOfService: + m, err := s.authOfServiceToAuthV2() if err != nil { return nil, err } @@ -68,7 +69,8 @@ func (p PluginAuthInfo) ToAuthV2() (*model.AuthV2, error) { } } -func (p PluginAuthInfo) authOfOAuthToAuthV2() (*model.AuthV2, error) { +func (s *pluginAuthConverter) authOfOAuthToAuthV2() (*model.AuthV2, error) { + p := s.PluginAuthInfo if p.AuthzSubType == nil { return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "sub-auth type is required")) } @@ -83,7 +85,7 @@ func (p PluginAuthInfo) authOfOAuthToAuthV2() (*model.AuthV2, error) { return nil, errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "invalid oauth info")) } - if *p.AuthzSubType == model.AuthzSubTypeOfOAuthClientCredentials { + if *p.AuthzSubType == consts.AuthzSubTypeOfOAuthClientCredentials { _oauthInfo := &model.OAuthClientCredentialsConfig{ ClientID: oauthInfo["client_id"], ClientSecret: oauthInfo["client_secret"], @@ -96,16 +98,16 @@ func (p PluginAuthInfo) authOfOAuthToAuthV2() (*model.AuthV2, error) { } return &model.AuthV2{ - Type: model.AuthzTypeOfOAuth, - SubType: model.AuthzSubTypeOfOAuthClientCredentials, + Type: consts.AuthzTypeOfOAuth, + SubType: consts.AuthzSubTypeOfOAuthClientCredentials, Payload: str, AuthOfOAuthClientCredentials: _oauthInfo, }, nil } - if *p.AuthzSubType == model.AuthzSubTypeOfOAuthAuthorizationCode { + if *p.AuthzSubType == consts.AuthzSubTypeOfOAuthAuthorizationCode { contentType := oauthInfo["authorization_content_type"] - if contentType != model.MediaTypeJson { // only support application/json + if contentType != consts.MediaTypeJson { // only support application/json return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey, "the type '%s' of authorization content is invalid", contentType)) } @@ -125,8 +127,8 @@ func (p PluginAuthInfo) authOfOAuthToAuthV2() (*model.AuthV2, error) { } return &model.AuthV2{ - Type: model.AuthzTypeOfOAuth, - SubType: model.AuthzSubTypeOfOAuthAuthorizationCode, + Type: consts.AuthzTypeOfOAuth, + SubType: consts.AuthzSubTypeOfOAuthAuthorizationCode, Payload: str, AuthOfOAuthAuthorizationCode: _oauthInfo, }, nil @@ -136,12 +138,13 @@ func (p PluginAuthInfo) authOfOAuthToAuthV2() (*model.AuthV2, error) { "the type '%s' of sub-auth is invalid", *p.AuthzSubType)) } -func (p PluginAuthInfo) authOfServiceToAuthV2() (*model.AuthV2, error) { +func (s *pluginAuthConverter) authOfServiceToAuthV2() (*model.AuthV2, error) { + p := s.PluginAuthInfo if p.AuthzSubType == nil { return nil, fmt.Errorf("sub-auth type is required") } - if *p.AuthzSubType == model.AuthzSubTypeOfServiceAPIToken { + if *p.AuthzSubType == consts.AuthzSubTypeOfServiceAPIToken { if p.Location == nil { return nil, fmt.Errorf("'Location' of sub-auth is required") } @@ -154,7 +157,7 @@ func (p PluginAuthInfo) authOfServiceToAuthV2() (*model.AuthV2, error) { tokenAuth := &model.AuthOfAPIToken{ ServiceToken: *p.ServiceToken, - Location: model.HTTPParamLocation(strings.ToLower(string(*p.Location))), + Location: consts.HTTPParamLocation(strings.ToLower(string(*p.Location))), Key: *p.Key, } @@ -164,8 +167,8 @@ func (p PluginAuthInfo) authOfServiceToAuthV2() (*model.AuthV2, error) { } return &model.AuthV2{ - Type: model.AuthzTypeOfService, - SubType: model.AuthzSubTypeOfServiceAPIToken, + Type: consts.AuthzTypeOfService, + SubType: consts.AuthzSubTypeOfServiceAPIToken, Payload: str, AuthOfAPIToken: tokenAuth, }, nil diff --git a/backend/domain/plugin/service/plugin_draft.go b/backend/domain/plugin/service/plugin_draft.go index 5627247d6..ca5c69277 100644 --- a/backend/domain/plugin/service/plugin_draft.go +++ b/backend/domain/plugin/service/plugin_draft.go @@ -29,10 +29,11 @@ import ( "gopkg.in/yaml.v3" searchModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search" - common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - plugin_develop_common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search" "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" @@ -46,36 +47,36 @@ import ( ) func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *dto.CreateDraftPluginRequest) (pluginID int64, err error) { - mf := entity.NewDefaultPluginManifest() - mf.CommonParams = map[model.HTTPParamLocation][]*plugin_develop_common.CommonParamSchema{} + mf := model.NewDefaultPluginManifest() + mf.CommonParams = map[consts.HTTPParamLocation][]*common.CommonParamSchema{} mf.NameForHuman = req.Name mf.NameForModel = req.Name mf.DescriptionForHuman = req.Desc mf.DescriptionForModel = req.Desc - mf.API.Type, _ = model.ToPluginType(req.PluginType) + mf.API.Type, _ = convert.ToPluginType(req.PluginType) mf.LogoURL = req.IconURI - authV2, err := req.AuthInfo.ToAuthV2() + authV2, err := newPluginAuthConverter(req.AuthInfo).ToAuthV2() if err != nil { return 0, err } mf.Auth = authV2 for loc, params := range req.CommonParams { - location, ok := model.ToHTTPParamLocation(loc) + location, ok := convert.ToHTTPParamLocation(loc) if !ok { return 0, fmt.Errorf("invalid location '%s'", loc.String()) } for _, param := range params { mf.CommonParams[location] = append(mf.CommonParams[location], - &plugin_develop_common.CommonParamSchema{ + &common.CommonParamSchema{ Name: param.Name, Value: param.Value, }) } } - doc := entity.NewDefaultOpenapiDoc() + doc := model.NewDefaultOpenapiDoc() doc.Servers = append(doc.Servers, &openapi3.Server{ URL: req.ServerURL, }) @@ -133,13 +134,13 @@ func (p *pluginServiceImpl) MGetDraftPlugins(ctx context.Context, pluginIDs []in func (p *pluginServiceImpl) ListDraftPlugins(ctx context.Context, req *dto.ListDraftPluginsRequest) (resp *dto.ListDraftPluginsResponse, err error) { if req.PageInfo.Name == nil || *req.PageInfo.Name == "" { - res, err := p.pluginRepo.ListDraftPlugins(ctx, &repository.ListDraftPluginsRequest{ + res, mErr := p.pluginRepo.ListDraftPlugins(ctx, &repository.ListDraftPluginsRequest{ SpaceID: req.SpaceID, APPID: req.APPID, PageInfo: req.PageInfo, }) - if err != nil { - return nil, errorx.Wrapf(err, "ListDraftPlugins failed, spaceID=%d, appID=%d", req.SpaceID, req.APPID) + if mErr != nil { + return nil, errorx.Wrapf(mErr, "ListDraftPlugins failed, spaceID=%d, appID=%d", req.SpaceID, req.APPID) } return &dto.ListDraftPluginsResponse{ @@ -157,7 +158,7 @@ func (p *pluginServiceImpl) ListDraftPlugins(ctx context.Context, req *dto.ListD resCommon.ResType_Plugin, }, OrderFiledName: func() string { - if req.PageInfo.SortBy == nil || *req.PageInfo.SortBy != entity.SortByCreatedAt { + if req.PageInfo.SortBy == nil || *req.PageInfo.SortBy != dto.SortByCreatedAt { return searchModel.FieldOfUpdateTime } return searchModel.FieldOfCreateTime @@ -235,12 +236,12 @@ func (p *pluginServiceImpl) UpdateDraftPluginWithCode(ctx context.Context, req * return err } - apiSchemas := make(map[entity.UniqueToolAPI]*model.Openapi3Operation, len(doc.Paths)) - apis := make([]entity.UniqueToolAPI, 0, len(doc.Paths)) + apiSchemas := make(map[dto.UniqueToolAPI]*model.Openapi3Operation, len(doc.Paths)) + apis := make([]dto.UniqueToolAPI, 0, len(doc.Paths)) for subURL, pathItem := range doc.Paths { for method, op := range pathItem.Operations() { - api := entity.UniqueToolAPI{ + api := dto.UniqueToolAPI{ SubURL: subURL, Method: method, } @@ -268,8 +269,8 @@ func (p *pluginServiceImpl) UpdateDraftPluginWithCode(ctx context.Context, req * } } - oldDraftToolsMap := slices.ToMap(oldDraftTools, func(e *entity.ToolInfo) (entity.UniqueToolAPI, *entity.ToolInfo) { - return entity.UniqueToolAPI{ + oldDraftToolsMap := slices.ToMap(oldDraftTools, func(e *entity.ToolInfo) (dto.UniqueToolAPI, *entity.ToolInfo) { + return dto.UniqueToolAPI{ SubURL: e.GetSubURL(), Method: e.GetMethod(), }, e @@ -280,7 +281,7 @@ func (p *pluginServiceImpl) UpdateDraftPluginWithCode(ctx context.Context, req * _, ok := apiSchemas[api] if !ok { oldTool.DebugStatus = ptr.Of(common.APIDebugStatus_DebugWaiting) - oldTool.ActivatedStatus = ptr.Of(model.DeactivateTool) + oldTool.ActivatedStatus = ptr.Of(consts.DeactivateTool) } } @@ -288,7 +289,7 @@ func (p *pluginServiceImpl) UpdateDraftPluginWithCode(ctx context.Context, req * for api, newOp := range apiSchemas { oldTool, ok := oldDraftToolsMap[api] if ok { // 2. Update tool - > Overlay - oldTool.ActivatedStatus = ptr.Of(model.ActivateTool) + oldTool.ActivatedStatus = ptr.Of(consts.ActivateTool) oldTool.Operation = newOp if needResetDebugStatusTool(ctx, newOp, oldTool.Operation) { oldTool.DebugStatus = ptr.Of(common.APIDebugStatus_DebugWaiting) @@ -299,7 +300,7 @@ func (p *pluginServiceImpl) UpdateDraftPluginWithCode(ctx context.Context, req * // 3. New tools newDraftTools = append(newDraftTools, &entity.ToolInfo{ PluginID: req.PluginID, - ActivatedStatus: ptr.Of(model.ActivateTool), + ActivatedStatus: ptr.Of(consts.ActivateTool), DebugStatus: ptr.Of(common.APIDebugStatus_DebugWaiting), SubURL: ptr.Of(api.SubURL), Method: ptr.Of(api.Method), @@ -402,10 +403,10 @@ func isJsonSchemaEqual(nsc, osc *openapi3.Schema) bool { if nsc.Default != osc.Default { return false } - if nsc.Extensions[model.APISchemaExtendAssistType] != osc.Extensions[model.APISchemaExtendAssistType] { + if nsc.Extensions[consts.APISchemaExtendAssistType] != osc.Extensions[consts.APISchemaExtendAssistType] { return false } - if nsc.Extensions[model.APISchemaExtendGlobalDisable] != osc.Extensions[model.APISchemaExtendGlobalDisable] { + if nsc.Extensions[consts.APISchemaExtendGlobalDisable] != osc.Extensions[consts.APISchemaExtendGlobalDisable] { return false } @@ -519,7 +520,7 @@ func updatePluginOpenapiDoc(_ context.Context, doc *model.Openapi3T, req *dto.Up return doc, nil } -func updatePluginManifest(_ context.Context, mf *entity.PluginManifest, req *dto.UpdateDraftPluginRequest) (*entity.PluginManifest, error) { +func updatePluginManifest(_ context.Context, mf *model.PluginManifest, req *dto.UpdateDraftPluginRequest) (*model.PluginManifest, error) { if req.Name != nil { mf.NameForHuman = *req.Name mf.NameForModel = *req.Name @@ -536,16 +537,16 @@ func updatePluginManifest(_ context.Context, mf *entity.PluginManifest, req *dto if len(req.CommonParams) > 0 { if mf.CommonParams == nil { - mf.CommonParams = make(map[model.HTTPParamLocation][]*plugin_develop_common.CommonParamSchema, len(req.CommonParams)) + mf.CommonParams = make(map[consts.HTTPParamLocation][]*common.CommonParamSchema, len(req.CommonParams)) } for loc, params := range req.CommonParams { - location, ok := model.ToHTTPParamLocation(loc) + location, ok := convert.ToHTTPParamLocation(loc) if !ok { return nil, fmt.Errorf("invalid location '%s'", loc.String()) } - commonParams := make([]*plugin_develop_common.CommonParamSchema, 0, len(params)) + commonParams := make([]*common.CommonParamSchema, 0, len(params)) for _, param := range params { - commonParams = append(commonParams, &plugin_develop_common.CommonParamSchema{ + commonParams = append(commonParams, &common.CommonParamSchema{ Name: param.Name, Value: param.Value, }) @@ -555,7 +556,7 @@ func updatePluginManifest(_ context.Context, mf *entity.PluginManifest, req *dto } if req.AuthInfo != nil { - authV2, err := req.AuthInfo.ToAuthV2() + authV2, err := newPluginAuthConverter(req.AuthInfo).ToAuthV2() if err != nil { return nil, err } @@ -605,25 +606,25 @@ func (p *pluginServiceImpl) UpdateDraftTool(ctx context.Context, req *dto.Update func (p *pluginServiceImpl) updateDraftTool(ctx context.Context, req *dto.UpdateDraftToolRequest, draftTool *entity.ToolInfo) (err error) { if req.Method != nil && req.SubURL != nil { - api := entity.UniqueToolAPI{ + api := dto.UniqueToolAPI{ SubURL: ptr.FromOrDefault(req.SubURL, ""), Method: ptr.FromOrDefault(req.Method, ""), } - existTool, exist, err := p.toolRepo.GetDraftToolWithAPI(ctx, draftTool.PluginID, api) - if err != nil { - return errorx.Wrapf(err, "GetDraftToolWithAPI failed, pluginID=%d, api=%v", draftTool.PluginID, api) + existTool, exist, mErr := p.toolRepo.GetDraftToolWithAPI(ctx, draftTool.PluginID, api) + if mErr != nil { + return errorx.Wrapf(mErr, "GetDraftToolWithAPI failed, pluginID=%d, api=%v", draftTool.PluginID, api) } if exist && draftTool.ID != existTool.ID { return errorx.New(errno.ErrPluginDuplicatedTool, errorx.KVf(errno.PluginMsgKey, "[%s]:%s", api.Method, api.SubURL)) } } - var activatedStatus *model.ActivatedStatus + var activatedStatus *consts.ActivatedStatus if req.Disabled != nil { if *req.Disabled { - activatedStatus = ptr.Of(model.DeactivateTool) + activatedStatus = ptr.Of(consts.DeactivateTool) } else { - activatedStatus = ptr.Of(model.ActivateTool) + activatedStatus = ptr.Of(consts.ActivateTool) } } @@ -647,9 +648,9 @@ func (p *pluginServiceImpl) updateDraftTool(ctx context.Context, req *dto.Update if op.Extensions == nil { op.Extensions = map[string]any{} } - authMode, ok := model.ToAPIAuthMode(req.APIExtend.AuthMode) + authMode, ok := convert.ToAPIAuthMode(req.APIExtend.AuthMode) if ok { - op.Extensions[model.APISchemaExtendAuthMode] = authMode + op.Extensions[consts.APISchemaExtendAuthMode] = authMode } } @@ -662,9 +663,9 @@ func (p *pluginServiceImpl) updateDraftTool(ctx context.Context, req *dto.Update if req.RequestBody == nil { op.RequestBody = draftTool.Operation.RequestBody } else { - mType, ok := req.RequestBody.Value.Content[model.MediaTypeJson] + mType, ok := req.RequestBody.Value.Content[consts.MediaTypeJson] if !ok { - return fmt.Errorf("the '%s' media type is not defined in request body", model.MediaTypeJson) + return fmt.Errorf("the '%s' media type is not defined in request body", consts.MediaTypeJson) } if op.RequestBody == nil || op.RequestBody.Value == nil || op.RequestBody.Value.Content == nil { op.RequestBody = &openapi3.RequestBodyRef{ @@ -673,7 +674,7 @@ func (p *pluginServiceImpl) updateDraftTool(ctx context.Context, req *dto.Update }, } } - op.RequestBody.Value.Content[model.MediaTypeJson] = mType + op.RequestBody.Value.Content[consts.MediaTypeJson] = mType } // update responses @@ -684,9 +685,9 @@ func (p *pluginServiceImpl) updateDraftTool(ctx context.Context, req *dto.Update if !ok { return fmt.Errorf("the '%d' status code is not defined in responses", http.StatusOK) } - newMIMEType, ok := newRespRef.Value.Content[model.MediaTypeJson] + newMIMEType, ok := newRespRef.Value.Content[consts.MediaTypeJson] if !ok { - return fmt.Errorf("the '%s' media type is not defined in responses", model.MediaTypeJson) + return fmt.Errorf("the '%s' media type is not defined in responses", consts.MediaTypeJson) } if op.Responses == nil { @@ -707,7 +708,7 @@ func (p *pluginServiceImpl) updateDraftTool(ctx context.Context, req *dto.Update oldRespRef.Value.Content = map[string]*openapi3.MediaType{} } - oldRespRef.Value.Content[model.MediaTypeJson] = newMIMEType + oldRespRef.Value.Content[consts.MediaTypeJson] = newMIMEType } updatedTool := &entity.ToolInfo{ @@ -822,7 +823,7 @@ func (p *pluginServiceImpl) ConvertToOpenapi3Doc(ctx context.Context, req *dto.C } } -type convertFunc func(ctx context.Context, rawInput string) (*model.Openapi3T, *entity.PluginManifest, error) +type convertFunc func(ctx context.Context, rawInput string) (*model.Openapi3T, *model.PluginManifest, error) func getConvertFunc(ctx context.Context, rawInput string) (convertFunc, common.PluginDataFormat, error) { if strings.HasPrefix(rawInput, "curl") { @@ -857,7 +858,7 @@ func getConvertFunc(ctx context.Context, rawInput string) (convertFunc, common.P return nil, 0, fmt.Errorf("invalid schema") } -func validateConvertResult(ctx context.Context, req *dto.ConvertToOpenapi3DocRequest, doc *model.Openapi3T, mf *entity.PluginManifest) error { +func validateConvertResult(ctx context.Context, req *dto.ConvertToOpenapi3DocRequest, doc *model.Openapi3T, mf *model.PluginManifest) error { if req.PluginServerURL != nil { if doc.Servers[0].URL != *req.PluginServerURL { return errorx.New(errno.ErrPluginConvertProtocolFailed, errorx.KV(errno.PluginMsgKey, "inconsistent API URL prefix")) @@ -883,10 +884,10 @@ func (p *pluginServiceImpl) CreateDraftToolsWithCode(ctx context.Context, req *d return nil, err } - toolAPIs := make([]entity.UniqueToolAPI, 0, len(req.OpenapiDoc.Paths)) + toolAPIs := make([]dto.UniqueToolAPI, 0, len(req.OpenapiDoc.Paths)) for path, item := range req.OpenapiDoc.Paths { for method := range item.Operations() { - toolAPIs = append(toolAPIs, entity.UniqueToolAPI{ + toolAPIs = append(toolAPIs, dto.UniqueToolAPI{ SubURL: path, Method: method, }) @@ -901,7 +902,7 @@ func (p *pluginServiceImpl) CreateDraftToolsWithCode(ctx context.Context, req *d return nil, errorx.Wrapf(err, "MGetDraftToolWithAPI failed, pluginID=%d, apis=%v", req.PluginID, toolAPIs) } - duplicatedTools := make([]entity.UniqueToolAPI, 0, len(existTools)) + duplicatedTools := make([]dto.UniqueToolAPI, 0, len(existTools)) for _, api := range toolAPIs { if _, exist := existTools[api]; exist { duplicatedTools = append(duplicatedTools, api) @@ -921,7 +922,7 @@ func (p *pluginServiceImpl) CreateDraftToolsWithCode(ctx context.Context, req *d PluginID: req.PluginID, Method: ptr.Of(method), SubURL: ptr.Of(path), - ActivatedStatus: ptr.Of(model.ActivateTool), + ActivatedStatus: ptr.Of(consts.ActivateTool), DebugStatus: ptr.Of(common.APIDebugStatus_DebugWaiting), Operation: model.NewOpenapi3Operation(op), }) diff --git a/backend/domain/plugin/service/plugin_oauth.go b/backend/domain/plugin/service/plugin_oauth.go index 2325abeb9..9488f1b9d 100644 --- a/backend/domain/plugin/service/plugin_oauth.go +++ b/backend/domain/plugin/service/plugin_oauth.go @@ -28,7 +28,8 @@ import ( "golang.org/x/oauth2" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" @@ -107,7 +108,7 @@ func (p *pluginServiceImpl) processOAuthAccessToken(ctx context.Context) { } } -func (p *pluginServiceImpl) refreshToken(ctx context.Context, info *entity.AuthorizationCodeInfo) { +func (p *pluginServiceImpl) refreshToken(ctx context.Context, info *dto.AuthorizationCodeInfo) { config := oauth2.Config{ ClientID: info.Config.ClientID, ClientSecret: info.Config.ClientSecret, @@ -137,8 +138,8 @@ func (p *pluginServiceImpl) refreshToken(ctx context.Context, info *entity.Autho expiredAtMS = token.Expiry.UnixMilli() } - err = p.oauthRepo.UpsertAuthorizationCode(ctx, &entity.AuthorizationCodeInfo{ - Meta: &entity.AuthorizationCodeMeta{ + err = p.oauthRepo.UpsertAuthorizationCode(ctx, &dto.AuthorizationCodeInfo{ + Meta: &dto.AuthorizationCodeMeta{ UserID: info.Meta.UserID, PluginID: info.Meta.PluginID, IsDraft: info.Meta.IsDraft, @@ -184,13 +185,11 @@ func (p *pluginServiceImpl) refreshTokenFailedHandler(ctx context.Context, recor if err_ != nil { logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", recordID, err_) } - - return } -func (p *pluginServiceImpl) GetAccessToken(ctx context.Context, oa *entity.OAuthInfo) (accessToken string, err error) { +func (p *pluginServiceImpl) GetAccessToken(ctx context.Context, oa *dto.OAuthInfo) (accessToken string, err error) { switch oa.OAuthMode { - case model.AuthzSubTypeOfOAuthAuthorizationCode: + case consts.AuthzSubTypeOfOAuthAuthorizationCode: accessToken, err = p.getAccessTokenByAuthorizationCode(ctx, oa.AuthorizationCode) default: return "", fmt.Errorf("invalid oauth mode '%s'", oa.OAuthMode) @@ -202,7 +201,7 @@ func (p *pluginServiceImpl) GetAccessToken(ctx context.Context, oa *entity.OAuth return accessToken, nil } -func (p *pluginServiceImpl) getAccessTokenByAuthorizationCode(ctx context.Context, ci *entity.AuthorizationCodeInfo) (accessToken string, err error) { +func (p *pluginServiceImpl) getAccessTokenByAuthorizationCode(ctx context.Context, ci *dto.AuthorizationCodeInfo) (accessToken string, err error) { meta := ci.Meta info, exist, err := p.oauthRepo.GetAuthorizationCode(ctx, ci.Meta) if err != nil { @@ -275,7 +274,7 @@ func isValidAuthCodeConfig(o, n *model.OAuthAuthorizationCodeConfig, expireAt, l return true } -func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *entity.OAuthState) (err error) { +func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *dto.OAuthState) (err error) { var plugin *entity.PluginInfo if state.IsDraft { plugin, err = p.GetDraftPlugin(ctx, state.PluginID) @@ -287,7 +286,7 @@ func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *e } authInfo := plugin.GetAuthInfo() - if authInfo.SubType != model.AuthzSubTypeOfOAuthAuthorizationCode { + if authInfo.SubType != consts.AuthzSubTypeOfOAuthAuthorizationCode { return errorx.New(errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "plugin auth type is not oauth authorization code")) } if authInfo.AuthOfOAuthAuthorizationCode == nil { @@ -301,7 +300,7 @@ func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *e return errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "exchange token failed")) } - meta := &entity.AuthorizationCodeMeta{ + meta := &dto.AuthorizationCodeMeta{ UserID: state.UserID, PluginID: state.PluginID, IsDraft: state.IsDraft, @@ -312,9 +311,9 @@ func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *e expiredAtMS = token.Expiry.UnixMilli() } - err = p.saveAccessToken(ctx, &entity.OAuthInfo{ - OAuthMode: model.AuthzSubTypeOfOAuthAuthorizationCode, - AuthorizationCode: &entity.AuthorizationCodeInfo{ + err = p.saveAccessToken(ctx, &dto.OAuthInfo{ + OAuthMode: consts.AuthzSubTypeOfOAuthAuthorizationCode, + AuthorizationCode: &dto.AuthorizationCodeInfo{ Meta: meta, Config: authInfo.AuthOfOAuthAuthorizationCode, AccessToken: token.AccessToken, @@ -331,9 +330,9 @@ func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *e return nil } -func (p *pluginServiceImpl) saveAccessToken(ctx context.Context, oa *entity.OAuthInfo) (err error) { +func (p *pluginServiceImpl) saveAccessToken(ctx context.Context, oa *dto.OAuthInfo) (err error) { switch oa.OAuthMode { - case model.AuthzSubTypeOfOAuthAuthorizationCode: + case consts.AuthzSubTypeOfOAuthAuthorizationCode: err = p.saveAuthCodeAccessToken(ctx, oa.AuthorizationCode) default: return fmt.Errorf("[standardOAuth] invalid oauth mode '%s'", oa.OAuthMode) @@ -342,7 +341,7 @@ func (p *pluginServiceImpl) saveAccessToken(ctx context.Context, oa *entity.OAut return err } -func (p *pluginServiceImpl) saveAuthCodeAccessToken(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) { +func (p *pluginServiceImpl) saveAuthCodeAccessToken(ctx context.Context, info *dto.AuthorizationCodeInfo) (err error) { meta := info.Meta err = p.oauthRepo.UpsertAuthorizationCode(ctx, info) if err != nil { @@ -360,7 +359,7 @@ func getNextTokenRefreshAtMS(expiredAtMS int64) int64 { return time.Now().Add(time.Duration((expiredAtMS-time.Now().UnixMilli())/2) * time.Millisecond).UnixMilli() } -func (p *pluginServiceImpl) RevokeAccessToken(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) { +func (p *pluginServiceImpl) RevokeAccessToken(ctx context.Context, meta *dto.AuthorizationCodeMeta) (err error) { return p.oauthRepo.DeleteAuthorizationCode(ctx, meta) } @@ -374,7 +373,7 @@ func (p *pluginServiceImpl) GetOAuthStatus(ctx context.Context, userID, pluginID } authInfo := pl.GetAuthInfo() - if authInfo.Type == model.AuthzTypeOfNone || authInfo.Type == model.AuthzTypeOfService { + if authInfo.Type == consts.AuthzTypeOfNone || authInfo.Type == consts.AuthzTypeOfService { return &dto.GetOAuthStatusResponse{ IsOauth: false, }, nil @@ -402,15 +401,15 @@ func (p *pluginServiceImpl) GetOAuthStatus(ctx context.Context, userID, pluginID func (p *pluginServiceImpl) getPluginOAuthStatus(ctx context.Context, userID int64, plugin *entity.PluginInfo, isDraft bool) (needAuth bool, authURL string, err error) { authInfo := plugin.GetAuthInfo() - if authInfo.Type != model.AuthzTypeOfOAuth { + if authInfo.Type != consts.AuthzTypeOfOAuth { return false, "", fmt.Errorf("invalid auth type '%v'", authInfo.Type) } - if authInfo.SubType != model.AuthzSubTypeOfOAuthAuthorizationCode { + if authInfo.SubType != consts.AuthzSubTypeOfOAuthAuthorizationCode { return false, "", fmt.Errorf("invalid auth sub type '%v'", authInfo.SubType) } - authCode := &entity.AuthorizationCodeInfo{ - Meta: &entity.AuthorizationCodeMeta{ + authCode := &dto.AuthorizationCodeInfo{ + Meta: &dto.AuthorizationCodeMeta{ UserID: conv.Int64ToStr(userID), PluginID: plugin.ID, IsDraft: isDraft, @@ -418,8 +417,8 @@ func (p *pluginServiceImpl) getPluginOAuthStatus(ctx context.Context, userID int Config: plugin.Manifest.Auth.AuthOfOAuthAuthorizationCode, } - accessToken, err := p.GetAccessToken(ctx, &entity.OAuthInfo{ - OAuthMode: model.AuthzSubTypeOfOAuthAuthorizationCode, + accessToken, err := p.GetAccessToken(ctx, &dto.OAuthInfo{ + OAuthMode: consts.AuthzSubTypeOfOAuthAuthorizationCode, AuthorizationCode: authCode, }) if err != nil { @@ -436,10 +435,10 @@ func (p *pluginServiceImpl) getPluginOAuthStatus(ctx context.Context, userID int return needAuth, authURL, nil } -func genAuthURL(info *entity.AuthorizationCodeInfo) (string, error) { +func genAuthURL(info *dto.AuthorizationCodeInfo) (string, error) { config := getStanderOAuthConfig(info.Config) - state := &entity.OAuthState{ + state := &dto.OAuthState{ ClientName: "", UserID: info.Meta.UserID, PluginID: info.Meta.PluginID, @@ -498,7 +497,7 @@ func (p *pluginServiceImpl) GetAgentPluginsOAuthStatus(ctx context.Context, user for _, plugin := range plugins { authInfo := plugin.GetAuthInfo() - if authInfo.Type == model.AuthzTypeOfNone || authInfo.Type == model.AuthzTypeOfService { + if authInfo.Type == consts.AuthzTypeOfNone || authInfo.Type == consts.AuthzTypeOfService { continue } diff --git a/backend/domain/plugin/service/plugin_online.go b/backend/domain/plugin/service/plugin_online.go index fc6964c75..7771e36d8 100644 --- a/backend/domain/plugin/service/plugin_online.go +++ b/backend/domain/plugin/service/plugin_online.go @@ -24,7 +24,8 @@ import ( searchModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search" pluginCommon "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search" pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" @@ -55,11 +56,6 @@ func (p *pluginServiceImpl) MGetOnlinePlugins(ctx context.Context, pluginIDs []i return nil, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", pluginIDs) } - res := make([]*model.PluginInfo, 0, len(plugins)) - for _, pl := range plugins { - res = append(res, pl.PluginInfo) - } - return plugins, nil } @@ -84,7 +80,7 @@ func (p *pluginServiceImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64 return tools, nil } -func (p *pluginServiceImpl) MGetVersionTools(ctx context.Context, versionTools []entity.VersionTool) (tools []*entity.ToolInfo, err error) { +func (p *pluginServiceImpl) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*entity.ToolInfo, err error) { tools, err = p.toolRepo.MGetVersionTools(ctx, versionTools) if err != nil { return nil, errorx.Wrapf(err, "MGetVersionTools failed, versionTools=%v", versionTools) @@ -129,7 +125,7 @@ func (p *pluginServiceImpl) GetAPPAllPlugins(ctx context.Context, appID int64) ( return plugins, nil } -func (p *pluginServiceImpl) MGetVersionPlugins(ctx context.Context, versionPlugins []entity.VersionPlugin) (plugins []*entity.PluginInfo, err error) { +func (p *pluginServiceImpl) MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (plugins []*entity.PluginInfo, err error) { plugins, err = p.pluginRepo.MGetVersionPlugins(ctx, versionPlugins) if err != nil { return nil, errorx.Wrapf(err, "MGetVersionPlugins failed, versionPlugins=%v", versionPlugins) @@ -138,7 +134,7 @@ func (p *pluginServiceImpl) MGetVersionPlugins(ctx context.Context, versionPlugi return plugins, nil } -func (p *pluginServiceImpl) ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { +func (p *pluginServiceImpl) ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo dto.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) { if pageInfo.Name == nil || *pageInfo.Name == "" { plugins, total, err = p.pluginRepo.ListCustomOnlinePlugins(ctx, spaceID, pageInfo) if err != nil { @@ -155,7 +151,7 @@ func (p *pluginServiceImpl) ListCustomOnlinePlugins(ctx context.Context, spaceID resCommon.ResType_Plugin, }, OrderFiledName: func() string { - if pageInfo.SortBy == nil || *pageInfo.SortBy != entity.SortByCreatedAt { + if pageInfo.SortBy == nil || *pageInfo.SortBy != dto.SortByCreatedAt { return searchModel.FieldOfUpdateTime } return searchModel.FieldOfCreateTime @@ -225,7 +221,7 @@ func (p *pluginServiceImpl) CopyPlugin(ctx context.Context, req *dto.CopyPluginR toolMap[tool.ID] = tool } - plugin, tools, err = p.pluginRepo.CopyPlugin(ctx, &repository.CopyPluginRequest{ + plugin, _, err = p.pluginRepo.CopyPlugin(ctx, &repository.CopyPluginRequest{ Plugin: plugin, Tools: tools, }) @@ -247,11 +243,11 @@ func (p *pluginServiceImpl) changePluginAndToolsInfoForCopy(req *dto.CopyPluginR plugin.DeveloperID = req.UserID - if req.CopyScene != model.CopySceneOfAPPDuplicate { + if req.CopyScene != consts.CopySceneOfAPPDuplicate { plugin.SetName(fmt.Sprintf("%s_copy", plugin.GetName())) } - if req.CopyScene == model.CopySceneOfToLibrary { + if req.CopyScene == consts.CopySceneOfToLibrary { const ( defaultVersion = "v0.0.1" defaultVersionDesc = "copy to library" @@ -266,7 +262,7 @@ func (p *pluginServiceImpl) changePluginAndToolsInfoForCopy(req *dto.CopyPluginR } } - if req.CopyScene == model.CopySceneOfToAPP { + if req.CopyScene == consts.CopySceneOfToAPP { plugin.APPID = req.TargetAPPID for _, tool := range tools { @@ -274,27 +270,27 @@ func (p *pluginServiceImpl) changePluginAndToolsInfoForCopy(req *dto.CopyPluginR } } - if req.CopyScene == model.CopySceneOfAPPDuplicate { + if req.CopyScene == consts.CopySceneOfAPPDuplicate { plugin.APPID = req.TargetAPPID } } -func (p *pluginServiceImpl) checkCanCopyPlugin(ctx context.Context, pluginID int64, scene model.CopyScene) (err error) { +func (p *pluginServiceImpl) checkCanCopyPlugin(ctx context.Context, pluginID int64, scene consts.CopyScene) (err error) { switch scene { - case model.CopySceneOfToAPP, model.CopySceneOfDuplicate, model.CopySceneOfAPPDuplicate: + case consts.CopySceneOfToAPP, consts.CopySceneOfDuplicate, consts.CopySceneOfAPPDuplicate: return nil - case model.CopySceneOfToLibrary: + case consts.CopySceneOfToLibrary: return p.checkToolsDebugStatus(ctx, pluginID) default: return fmt.Errorf("unsupported copy scene '%s'", scene) } } -func (p *pluginServiceImpl) getCopySourcePluginAndTools(ctx context.Context, pluginID int64, scene model.CopyScene) (plugin *entity.PluginInfo, tools []*entity.ToolInfo, err error) { +func (p *pluginServiceImpl) getCopySourcePluginAndTools(ctx context.Context, pluginID int64, scene consts.CopyScene) (plugin *entity.PluginInfo, tools []*entity.ToolInfo, err error) { switch scene { - case model.CopySceneOfToAPP: + case consts.CopySceneOfToAPP: return p.getOnlinePluginAndTools(ctx, pluginID) - case model.CopySceneOfToLibrary, model.CopySceneOfDuplicate, model.CopySceneOfAPPDuplicate: + case consts.CopySceneOfToLibrary, consts.CopySceneOfDuplicate, consts.CopySceneOfAPPDuplicate: return p.getDraftPluginAndTools(ctx, pluginID) default: return nil, nil, fmt.Errorf("unsupported copy scene '%s'", scene) diff --git a/backend/domain/plugin/service/plugin_release.go b/backend/domain/plugin/service/plugin_release.go index b36c7f055..025e0e514 100644 --- a/backend/domain/plugin/service/plugin_release.go +++ b/backend/domain/plugin/service/plugin_release.go @@ -25,7 +25,7 @@ import ( "golang.org/x/mod/semver" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/repository" "github.com/coze-dev/coze-studio/backend/pkg/errorx" @@ -207,7 +207,7 @@ func (p *pluginServiceImpl) checkToolsDebugStatus(ctx context.Context, pluginID activatedTools := make([]*entity.ToolInfo, 0, len(res)) for _, tool := range res { - if tool.GetActivatedStatus() == model.DeactivateTool { + if tool.IsDeactivated() { continue } activatedTools = append(activatedTools, tool) diff --git a/backend/domain/plugin/service/service.go b/backend/domain/plugin/service/service.go index eab2bd210..a7af904c6 100644 --- a/backend/domain/plugin/service/service.go +++ b/backend/domain/plugin/service/service.go @@ -19,7 +19,7 @@ package service import ( "context" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" ) @@ -45,8 +45,8 @@ type PluginService interface { MGetOnlinePlugins(ctx context.Context, pluginIDs []int64) (plugins []*entity.PluginInfo, err error) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (resp *model.MGetPluginLatestVersionResponse, err error) GetPluginNextVersion(ctx context.Context, pluginID int64) (version string, err error) - MGetVersionPlugins(ctx context.Context, versionPlugins []entity.VersionPlugin) (plugins []*entity.PluginInfo, err error) - ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) + MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (plugins []*entity.PluginInfo, err error) + ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo dto.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) // Draft Tool MGetDraftTools(ctx context.Context, toolIDs []int64) (tools []*entity.ToolInfo, err error) @@ -58,7 +58,7 @@ type PluginService interface { // Online Tool GetOnlineTool(ctx context.Context, toolID int64) (tool *entity.ToolInfo, err error) MGetOnlineTools(ctx context.Context, toolIDs []int64) (tools []*entity.ToolInfo, err error) - MGetVersionTools(ctx context.Context, versionTools []entity.VersionTool) (tools []*entity.ToolInfo, err error) + MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*entity.ToolInfo, err error) CopyPlugin(ctx context.Context, req *dto.CopyPluginRequest) (resp *dto.CopyPluginResponse, err error) MoveAPPPluginToLibrary(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error) @@ -71,7 +71,7 @@ type PluginService interface { PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) (err error) - ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, opts ...entity.ExecuteToolOpt) (resp *model.ExecuteToolResponse, err error) + ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, opts ...model.ExecuteToolOpt) (resp *model.ExecuteToolResponse, err error) // Product ListPluginProducts(ctx context.Context, req *dto.ListPluginProductsRequest) (resp *dto.ListPluginProductsResponse, err error) @@ -79,7 +79,7 @@ type PluginService interface { GetOAuthStatus(ctx context.Context, userID, pluginID int64) (resp *dto.GetOAuthStatusResponse, err error) GetAgentPluginsOAuthStatus(ctx context.Context, userID, agentID int64) (status []*dto.AgentPluginOAuthStatus, err error) - OAuthCode(ctx context.Context, code string, state *entity.OAuthState) (err error) - GetAccessToken(ctx context.Context, oa *entity.OAuthInfo) (accessToken string, err error) - RevokeAccessToken(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) + OAuthCode(ctx context.Context, code string, state *dto.OAuthState) (err error) + GetAccessToken(ctx context.Context, oa *dto.OAuthInfo) (accessToken string, err error) + RevokeAccessToken(ctx context.Context, meta *dto.AuthorizationCodeMeta) (err error) } diff --git a/backend/domain/plugin/service/tool/invocation_args.go b/backend/domain/plugin/service/tool/invocation_args.go index eb8802393..445795da8 100644 --- a/backend/domain/plugin/service/tool/invocation_args.go +++ b/backend/domain/plugin/service/tool/invocation_args.go @@ -29,7 +29,9 @@ import ( "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables" "github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory" api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" @@ -65,7 +67,7 @@ type InvocationArgs struct { ServerURL string UserID string - ProjectInfo *entity.ProjectInfo + ProjectInfo *model.ProjectInfo AccessToken string AuthURL string @@ -78,7 +80,7 @@ type InvocationArgs struct { type InvocationArgsBuilder struct { ArgsInJson string - ProjectInfo *entity.ProjectInfo + ProjectInfo *model.ProjectInfo UserID string AccessToken string AuthURL string @@ -221,7 +223,7 @@ func (i *InvocationArgs) groupedRequestArgs(ctx context.Context, args map[string i.Body = bodyArgs } -func (i *InvocationArgs) setCommonParams(ctx context.Context, commonParams map[model.HTTPParamLocation][]*api.CommonParamSchema) { +func (i *InvocationArgs) setCommonParams(ctx context.Context, commonParams map[consts.HTTPParamLocation][]*api.CommonParamSchema) { for location, params := range commonParams { for _, param := range params { if param.Name == "" { @@ -230,13 +232,13 @@ func (i *InvocationArgs) setCommonParams(ctx context.Context, commonParams map[m var dic map[string]any switch location { - case model.ParamInHeader: + case consts.ParamInHeader: dic = i.Header - case model.ParamInPath: + case consts.ParamInPath: dic = i.Path - case model.ParamInQuery: + case consts.ParamInQuery: dic = i.Query - case model.ParamInBody: + case consts.ParamInBody: dic = i.Body default: logs.CtxWarnf(ctx, "unsupported common parameter location '%s' in api schema, name=%s", location, param.Name) @@ -250,7 +252,7 @@ func (i *InvocationArgs) setCommonParams(ctx context.Context, commonParams map[m } } -func (i *InvocationArgs) setDefaultValues(ctx context.Context, projectInfo *entity.ProjectInfo, userID string) (err error) { +func (i *InvocationArgs) setDefaultValues(ctx context.Context, projectInfo *model.ProjectInfo, userID string) (err error) { groupedKeysSchema := i.groupedKeySchema i.Header, err = setParameterDefaultValues(ctx, i.Header, groupedKeysSchema.HeaderKeys, projectInfo, userID) @@ -282,7 +284,7 @@ func (i *InvocationArgs) setDefaultValues(ctx context.Context, projectInfo *enti return nil } -func setParameterDefaultValues(ctx context.Context, dic map[string]any, paramSchema map[string]*openapi3.Parameter, projectInfo *entity.ProjectInfo, userID string) (map[string]any, error) { +func setParameterDefaultValues(ctx context.Context, dic map[string]any, paramSchema map[string]*openapi3.Parameter, projectInfo *model.ProjectInfo, userID string) (map[string]any, error) { for key, valueSchema := range paramSchema { if valueSchema.Schema == nil || valueSchema.Schema.Value == nil { logs.CtxWarnf(ctx, "[setParameterDefaultValues] parameter '%s' schema is nil", key) @@ -311,7 +313,7 @@ func setParameterDefaultValues(ctx context.Context, dic map[string]any, paramSch return dic, nil } -func setBodyDefaultValues(ctx context.Context, dic map[string]any, sc *openapi3.Schema, projectInfo *entity.ProjectInfo, userID string) (map[string]any, error) { +func setBodyDefaultValues(ctx context.Context, dic map[string]any, sc *openapi3.Schema, projectInfo *model.ProjectInfo, userID string) (map[string]any, error) { required := slices.ToMap(sc.Required, func(e string) (string, bool) { return e, true }) @@ -367,8 +369,8 @@ func setBodyDefaultValues(ctx context.Context, dic map[string]any, sc *openapi3. return newVals, nil } -func getDefaultValue(ctx context.Context, schema *openapi3.Schema, info *entity.ProjectInfo, userID string) (any, error) { - vn, exist := schema.Extensions[model.APISchemaExtendVariableRef] +func getDefaultValue(ctx context.Context, schema *openapi3.Schema, info *model.ProjectInfo, userID string) (any, error) { + vn, exist := schema.Extensions[consts.APISchemaExtendVariableRef] if !exist { return schema.Default, nil } @@ -491,7 +493,7 @@ func isFileSchema(valueSchema *openapi3.Schema) bool { } // file schema x-assist-type must not nil - assistTypeObj := valueSchema.Extensions[model.APISchemaExtendAssistType] + assistTypeObj := valueSchema.Extensions[consts.APISchemaExtendAssistType] if assistTypeObj == nil { // it is not a file value return false @@ -502,7 +504,7 @@ func isFileSchema(valueSchema *openapi3.Schema) bool { return false } - if !model.IsValidAPIAssistType(model.APIFileAssistType(assistType)) { + if !convert.IsValidAPIAssistType(consts.APIFileAssistType(assistType)) { return false } diff --git a/backend/domain/plugin/service/tool/invocation_http.go b/backend/domain/plugin/service/tool/invocation_http.go index a843bb2d0..14528731b 100644 --- a/backend/domain/plugin/service/tool/invocation_http.go +++ b/backend/domain/plugin/service/tool/invocation_http.go @@ -31,13 +31,15 @@ import ( "github.com/go-resty/resty/v2" "github.com/tidwall/sjson" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + pluginConsts "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/encoder" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/i18n" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/consts" + "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -66,7 +68,7 @@ func (h *httpCallImpl) Do(ctx context.Context, args *InvocationArgs) (request st if errMsg != "" { event := &model.ToolInterruptEvent{ - Event: model.InterruptEventTypeOfToolNeedOAuth, + Event: pluginConsts.InterruptEventTypeOfToolNeedOAuth, ToolNeedOAuth: &model.ToolNeedOAuthInterruptEvent{ Message: errMsg, }, @@ -154,15 +156,15 @@ func (h *httpCallImpl) buildHTTPRequest(ctx context.Context, args *InvocationArg func (h *httpCallImpl) injectAuthInfo(ctx context.Context, httpReq *http.Request, args *InvocationArgs) (errMsg string, err error) { - if args.AuthInfo.MetaInfo.Type == model.AuthzTypeOfNone { + if args.AuthInfo.MetaInfo.Type == pluginConsts.AuthzTypeOfNone { return "", nil } - if args.AuthInfo.MetaInfo.Type == model.AuthzTypeOfService { + if args.AuthInfo.MetaInfo.Type == pluginConsts.AuthzTypeOfService { return h.injectServiceAPIToken(ctx, httpReq, args.AuthInfo.MetaInfo) } - if args.AuthInfo.MetaInfo.Type == model.AuthzTypeOfOAuth { + if args.AuthInfo.MetaInfo.Type == pluginConsts.AuthzTypeOfOAuth { return h.injectOAuthAccessToken(ctx, httpReq, args) } @@ -278,7 +280,7 @@ func (h *httpCallImpl) buildRequestBody(ctx context.Context, op *model.Openapi3O } func (h *httpCallImpl) injectServiceAPIToken(ctx context.Context, httpReq *http.Request, authInfo *model.AuthV2) (errMsg string, err error) { - if authInfo.SubType == model.AuthzSubTypeOfServiceAPIToken { + if authInfo.SubType == pluginConsts.AuthzSubTypeOfServiceAPIToken { authOfAPIToken := authInfo.AuthOfAPIToken if authOfAPIToken == nil { return "", fmt.Errorf("auth of api token is nil") @@ -304,20 +306,20 @@ func (h *httpCallImpl) injectServiceAPIToken(ctx context.Context, httpReq *http. } func (h *httpCallImpl) injectOAuthAccessToken(ctx context.Context, httpReq *http.Request, args *InvocationArgs) (errMsg string, err error) { - authMode := model.ToolAuthModeOfRequired - if tmp, ok := args.Tool.Operation.Extensions[model.APISchemaExtendAuthMode].(string); ok { - authMode = model.ToolAuthMode(tmp) + authMode := pluginConsts.ToolAuthModeOfRequired + if tmp, ok := args.Tool.Operation.Extensions[pluginConsts.APISchemaExtendAuthMode].(string); ok { + authMode = pluginConsts.ToolAuthMode(tmp) } - if authMode == model.ToolAuthModeOfDisabled { + if authMode == pluginConsts.ToolAuthModeOfDisabled { return "", nil } accessToken := args.AccessToken authInfo := args.AuthInfo.MetaInfo - if authInfo.SubType == model.AuthzSubTypeOfOAuthAuthorizationCode && - accessToken == "" && authMode != model.ToolAuthModeOfSupported { + if authInfo.SubType == pluginConsts.AuthzSubTypeOfOAuthAuthorizationCode && + accessToken == "" && authMode != pluginConsts.ToolAuthModeOfSupported { errMsg = authCodeInvalidTokenErrMsg[i18n.GetLocale(ctx)] if errMsg == "" { errMsg = authCodeInvalidTokenErrMsg[i18n.LocaleEN] diff --git a/backend/domain/workflow/entity/vo/workflow_copy.go b/backend/domain/workflow/entity/vo/workflow_copy.go index 4f9688888..c56e73d88 100644 --- a/backend/domain/workflow/entity/vo/workflow_copy.go +++ b/backend/domain/workflow/entity/vo/workflow_copy.go @@ -16,12 +16,13 @@ package vo -import ( - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" -) +type PluginEntity struct { + PluginID int64 + PluginVersion *string // nil or "0" means draft, "" means latest/online version, otherwise is specific version +} type ExternalResourceRelated struct { - PluginMap map[int64]*plugin.PluginEntity + PluginMap map[int64]*PluginEntity PluginToolMap map[int64]int64 KnowledgeMap map[int64]int64 diff --git a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go index 19f0e47d3..569bb1978 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go +++ b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go @@ -51,6 +51,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin" "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" mockWorkflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow" mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code" @@ -787,8 +788,7 @@ func TestCodeAndPluginNodes(t *testing.T) { mockToolService := pluginmock.NewMockPluginService(ctrl) mockey.Mock(crossplugin.DefaultSVC).Return(mockToolService).Build() - mockToolService.EXPECT().ExecutePlugin(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), - gomock.Any()).Return(map[string]any{ + mockey.Mock(plugin.ExecutePlugin).Return(map[string]any{ "log_id": "20240617191637796DF3F4453E16AF3615", "msg": "success", "code": 0, @@ -796,7 +796,7 @@ func TestCodeAndPluginNodes(t *testing.T) { "image_url": "image_url", "prompt": "小狗在草地上", }, - }, nil).AnyTimes() + }, nil).Build() ctx := t.Context() ctx = ctxcache.Init(ctx) diff --git a/backend/domain/workflow/internal/compose/designate_option.go b/backend/domain/workflow/internal/compose/designate_option.go index 5845ad8c3..aa9ca5087 100644 --- a/backend/domain/workflow/internal/compose/designate_option.go +++ b/backend/domain/workflow/internal/compose/designate_option.go @@ -25,8 +25,6 @@ import ( einoCompose "github.com/cloudwego/eino/compose" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" - crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" @@ -35,6 +33,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" + wrapPlugin "github.com/coze-dev/coze-studio/backend/domain/workflow/plugin" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) @@ -297,8 +296,8 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh return nil, err } - toolInfoResponse, err := crossplugin.DefaultSVC().GetPluginToolsInfo(ctx, &plugin.ToolsInfoRequest{ - PluginEntity: plugin.PluginEntity{ + toolInfoResponse, err := wrapPlugin.GetPluginToolsInfo(ctx, &wrapPlugin.ToolsInfoRequest{ + PluginEntity: vo.PluginEntity{ PluginID: pluginID, PluginVersion: ptr.Of(p.PluginVersion), }, diff --git a/backend/domain/workflow/internal/nodes/llm/llm.go b/backend/domain/workflow/internal/nodes/llm/llm.go index f314ea7d1..5c9b86dbb 100644 --- a/backend/domain/workflow/internal/nodes/llm/llm.go +++ b/backend/domain/workflow/internal/nodes/llm/llm.go @@ -41,8 +41,6 @@ import ( crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge" crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message" crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr" - crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" @@ -50,6 +48,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" + wrapPlugin "github.com/coze-dev/coze-studio/backend/domain/workflow/plugin" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" @@ -461,7 +460,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2 } if fcParams.PluginFCParam != nil { - pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest) + pluginToolsInvokableReq := make(map[int64]*wrapPlugin.ToolsInvokableRequest) for _, p := range fcParams.PluginFCParam.PluginList { pid, err := strconv.ParseInt(p.PluginID, 10, 64) if err != nil { @@ -482,18 +481,18 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2 } if req, ok := pluginToolsInvokableReq[pid]; ok { - req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{ + req.ToolsInvokableInfo[toolID] = &wrapPlugin.ToolsInvokableInfo{ ToolID: toolID, RequestAPIParametersConfig: requestParameters, ResponseAPIParametersConfig: responseParameters, } } else { - pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{ - PluginEntity: plugin.PluginEntity{ + pluginToolsInfoRequest := &wrapPlugin.ToolsInvokableRequest{ + PluginEntity: vo.PluginEntity{ PluginID: pid, PluginVersion: ptr.Of(p.PluginVersion), }, - ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{ + ToolsInvokableInfo: map[int64]*wrapPlugin.ToolsInvokableInfo{ toolID: { ToolID: toolID, RequestAPIParametersConfig: requestParameters, @@ -507,7 +506,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2 } inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList)) for _, req := range pluginToolsInvokableReq { - toolMap, err := crossplugin.DefaultSVC().GetPluginInvokableTools(ctx, req) + toolMap, err := wrapPlugin.GetPluginInvokableTools(ctx, req) if err != nil { return nil, err } diff --git a/backend/domain/workflow/internal/nodes/llm/prompt.go b/backend/domain/workflow/internal/nodes/llm/prompt.go index ad43cdf1c..61ffd7961 100644 --- a/backend/domain/workflow/internal/nodes/llm/prompt.go +++ b/backend/domain/workflow/internal/nodes/llm/prompt.go @@ -149,17 +149,17 @@ func enableLocalFileToLLMWithBase64(minfo *modelmgr.Model) bool { } func getModelProcessingInfo(ctx context.Context, mwi ModelWithInfo) (map[modelmgr.Modal]bool, bool) { - mInfo := mwi.Info(ctx) + mInfo := mwi.Info(ctx) - supportedModal := make(map[modelmgr.Modal]bool) - if mInfo != nil { - for i := range mInfo.Meta.Capability.InputModal { - supportedModal[mInfo.Meta.Capability.InputModal[i]] = true - } - } + supportedModal := make(map[modelmgr.Modal]bool) + if mInfo != nil { + for i := range mInfo.Meta.Capability.InputModal { + supportedModal[mInfo.Meta.Capability.InputModal[i]] = true + } + } - enableTransferBase64 := enableLocalFileToLLMWithBase64(mInfo) - return supportedModal, enableTransferBase64 + enableTransferBase64 := enableLocalFileToLLMWithBase64(mInfo) + return supportedModal, enableTransferBase64 } func (pl *promptTpl) render(ctx context.Context, vs map[string]any, diff --git a/backend/domain/workflow/internal/nodes/llm/prompt_test.go b/backend/domain/workflow/internal/nodes/llm/prompt_test.go index 4931ddc98..8229d0505 100644 --- a/backend/domain/workflow/internal/nodes/llm/prompt_test.go +++ b/backend/domain/workflow/internal/nodes/llm/prompt_test.go @@ -21,9 +21,10 @@ import ( "github.com/bytedance/mockey" "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" + "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/pkg/urltobase64url" - "github.com/stretchr/testify/assert" ) func TestTransformMessagePart(t *testing.T) { diff --git a/backend/domain/workflow/internal/nodes/plugin/exec.go b/backend/domain/workflow/internal/nodes/plugin/exec.go new file mode 100644 index 000000000..837868e3b --- /dev/null +++ b/backend/domain/workflow/internal/nodes/plugin/exec.go @@ -0,0 +1,110 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package plugin + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/compose" + + workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" + workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow" + crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/workflow" + entity2 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" + "github.com/coze-dev/coze-studio/backend/pkg/sonic" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +func ExecutePlugin(ctx context.Context, input map[string]any, pe *vo.PluginEntity, + toolID int64, cfg workflowModel.ExecuteConfig) (map[string]any, error) { + args, err := sonic.MarshalString(input) + if err != nil { + return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err) + } + + var uID string + if cfg.AgentID != nil { + uID = cfg.ConnectorUID + } else { + uID = conv.Int64ToStr(cfg.Operator) + } + + req := &model.ExecuteToolRequest{ + UserID: uID, + PluginID: pe.PluginID, + ToolID: toolID, + ExecScene: consts.ExecSceneOfWorkflow, + ArgumentsInJson: args, + ExecDraftTool: pe.PluginVersion == nil || *pe.PluginVersion == "0", + } + execOpts := []model.ExecuteToolOpt{ + model.WithInvalidRespProcessStrategy(consts.InvalidResponseProcessStrategyOfReturnDefault), + } + + if pe.PluginVersion != nil { + execOpts = append(execOpts, model.WithToolVersion(*pe.PluginVersion)) + } + + r, err := crossplugin.DefaultSVC().ExecuteTool(ctx, req, execOpts...) + if err != nil { + if extra, ok := compose.IsInterruptRerunError(err); ok { + pluginTIE, ok := extra.(*model.ToolInterruptEvent) + if !ok { + return nil, vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra)) + } + + var eventType workflow3.EventType + switch pluginTIE.Event { + case consts.InterruptEventTypeOfToolNeedOAuth: + eventType = workflow3.EventType_WorkflowOauthPlugin + default: + return nil, vo.WrapError(errno.ErrPluginAPIErr, + fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event)) + } + + id, err := workflow.GetRepository().GenID(ctx) + if err != nil { + return nil, vo.WrapError(errno.ErrIDGenError, err) + } + + ie := &entity2.InterruptEvent{ + ID: id, + InterruptData: pluginTIE.ToolNeedOAuth.Message, + EventType: eventType, + } + + // temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt + interruptData := ie.InterruptData + return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData)) + } + return nil, err + } + + var output map[string]any + err = sonic.UnmarshalString(r.TrimmedResp, &output) + if err != nil { + return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err) + } + + return output, nil +} diff --git a/backend/domain/workflow/internal/nodes/plugin/plugin.go b/backend/domain/workflow/internal/nodes/plugin/plugin.go index ffa37cc16..cfa5025ac 100644 --- a/backend/domain/workflow/internal/nodes/plugin/plugin.go +++ b/backend/domain/workflow/internal/nodes/plugin/plugin.go @@ -24,8 +24,6 @@ import ( "github.com/cloudwego/eino/compose" workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" - crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" - model "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" @@ -116,7 +114,7 @@ func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil { exeCfg = ctxExeCfg.ExeCfg } - result, err := crossplugin.DefaultSVC().ExecutePlugin(ctx, parameters, &model.PluginEntity{ + result, err := ExecutePlugin(ctx, parameters, &vo.PluginEntity{ PluginID: p.pluginID, PluginVersion: ptr.Of(p.pluginVersion), }, p.toolID, exeCfg) diff --git a/backend/crossdomain/contract/plugin/dto/workflow.go b/backend/domain/workflow/plugin/model.go similarity index 87% rename from backend/crossdomain/contract/plugin/dto/workflow.go rename to backend/domain/workflow/plugin/model.go index 11020c0e7..969fc5495 100644 --- a/backend/crossdomain/contract/plugin/dto/workflow.go +++ b/backend/domain/workflow/plugin/model.go @@ -14,23 +14,19 @@ * limitations under the License. */ -package dto +package plugin import ( "github.com/coze-dev/coze-studio/backend/api/model/workflow" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" ) type ToolsInfoRequest struct { - PluginEntity PluginEntity + PluginEntity vo.PluginEntity ToolIDs []int64 IsDraft bool } -type PluginEntity struct { - PluginID int64 - PluginVersion *string // nil or "0" means draft, "" means latest/online version, otherwise is specific version -} - type ToolsInfoResponse struct { PluginID int64 SpaceID int64 @@ -61,7 +57,7 @@ type DebugExample struct { } type ToolsInvokableRequest struct { - PluginEntity PluginEntity + PluginEntity vo.PluginEntity ToolsInvokableInfo map[int64]*ToolsInvokableInfo IsDraft bool } diff --git a/backend/domain/workflow/plugin/plugin.go b/backend/domain/workflow/plugin/plugin.go new file mode 100644 index 000000000..6f78afe38 --- /dev/null +++ b/backend/domain/workflow/plugin/plugin.go @@ -0,0 +1,452 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package plugin + +import ( + "context" + "fmt" + "strconv" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "github.com/getkin/kin-openapi/openapi3" + "golang.org/x/exp/maps" + + workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" + "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" + workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow" + crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/consts" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/convert/api" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow" + entity2 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/infra/contract/storage" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +var oss storage.Storage + +func SetOSS(s storage.Storage) { + oss = s +} + +type pluginInfo struct { + *model.PluginInfo + LatestVersion *string +} + +func getPluginsWithTools(ctx context.Context, pluginEntity *vo.PluginEntity, toolIDs []int64, isDraft bool) ( + _ *pluginInfo, toolsInfo []*entity.ToolInfo, err error) { + defer func() { + if err != nil { + err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) + } + }() + + var pluginsInfo []*model.PluginInfo + var latestPluginInfo *model.PluginInfo + pluginID := pluginEntity.PluginID + if isDraft { + plugins, err := crossplugin.DefaultSVC().MGetDraftPlugins(ctx, []int64{pluginID}) + if err != nil { + return nil, nil, err + } + pluginsInfo = plugins + } else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") { + plugins, err := crossplugin.DefaultSVC().MGetOnlinePlugins(ctx, []int64{pluginID}) + if err != nil { + return nil, nil, err + } + pluginsInfo = plugins + + } else { + plugins, err := crossplugin.DefaultSVC().MGetVersionPlugins(ctx, []model.VersionPlugin{ + {PluginID: pluginID, Version: *pluginEntity.PluginVersion}, + }) + if err != nil { + return nil, nil, err + } + pluginsInfo = plugins + + onlinePlugins, err := crossplugin.DefaultSVC().MGetOnlinePlugins(ctx, []int64{pluginID}) + if err != nil { + return nil, nil, err + } + for _, pi := range onlinePlugins { + if pi.ID == pluginID { + latestPluginInfo = pi + break + } + } + } + + var pInfo *model.PluginInfo + for _, p := range pluginsInfo { + if p.ID == pluginID { + pInfo = p + break + } + } + if pInfo == nil { + return nil, nil, vo.NewError(errno.ErrPluginIDNotFound, errorx.KV("id", strconv.FormatInt(pluginID, 10))) + } + + if isDraft { + tools, err := crossplugin.DefaultSVC().MGetDraftTools(ctx, toolIDs) + if err != nil { + return nil, nil, err + } + toolsInfo = tools + } else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") { + tools, err := crossplugin.DefaultSVC().MGetOnlineTools(ctx, toolIDs) + if err != nil { + return nil, nil, err + } + toolsInfo = tools + } else { + eVersionTools := slices.Transform(toolIDs, func(tid int64) model.VersionTool { + return model.VersionTool{ + ToolID: tid, + Version: *pluginEntity.PluginVersion, + } + }) + tools, err := crossplugin.DefaultSVC().MGetVersionTools(ctx, eVersionTools) + if err != nil { + return nil, nil, err + } + toolsInfo = tools + } + + if latestPluginInfo != nil { + return &pluginInfo{PluginInfo: pInfo, LatestVersion: latestPluginInfo.Version}, toolsInfo, nil + } + + return &pluginInfo{PluginInfo: pInfo}, toolsInfo, nil +} + +func GetPluginToolsInfo(ctx context.Context, req *ToolsInfoRequest) (_ *ToolsInfoResponse, err error) { + defer func() { + if err != nil { + err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) + } + }() + + var toolsInfo []*entity.ToolInfo + isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0") + pInfo, toolsInfo, err := getPluginsWithTools(ctx, &vo.PluginEntity{PluginID: req.PluginEntity.PluginID, PluginVersion: req.PluginEntity.PluginVersion}, req.ToolIDs, isDraft) + if err != nil { + return nil, err + } + + if oss == nil { + return nil, vo.NewError(errno.ErrTOSError, errorx.KV("msg", "oss is nil")) + } + + url, err := oss.GetObjectUrl(ctx, pInfo.GetIconURI()) + if err != nil { + return nil, vo.WrapIfNeeded(errno.ErrTOSError, err) + } + + response := &ToolsInfoResponse{ + PluginID: pInfo.ID, + SpaceID: pInfo.SpaceID, + Version: pInfo.GetVersion(), + PluginName: pInfo.GetName(), + Description: pInfo.GetDesc(), + IconURL: url, + PluginType: int64(pInfo.PluginType), + ToolInfoList: make(map[int64]ToolInfoW), + LatestVersion: pInfo.LatestVersion, + IsOfficial: pInfo.IsOfficial(), + AppID: pInfo.GetAPPID(), + } + + for _, tf := range toolsInfo { + inputs, err := tf.ToReqAPIParameter() + if err != nil { + return nil, err + } + outputs, err := tf.ToRespAPIParameter() + if err != nil { + return nil, err + } + toolExample := pInfo.GetToolExample(ctx, tf.GetName()) + + var ( + requestExample string + responseExample string + ) + if toolExample != nil { + requestExample = toolExample.RequestExample + responseExample = toolExample.ResponseExample + } + + response.ToolInfoList[tf.ID] = ToolInfoW{ + ToolID: tf.ID, + ToolName: tf.GetName(), + Inputs: slices.Transform(inputs, toWorkflowAPIParameter), + Outputs: slices.Transform(outputs, toWorkflowAPIParameter), + Description: tf.GetDesc(), + DebugExample: &DebugExample{ + ReqExample: requestExample, + RespExample: responseExample, + }, + } + + } + return response, nil +} + +func GetPluginInvokableTools(ctx context.Context, req *ToolsInvokableRequest) ( + _ map[int64]crossplugin.InvokableTool, err error) { + defer func() { + if err != nil { + err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) + } + }() + + var toolsInfo []*entity.ToolInfo + isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0") + pInfo, toolsInfo, err := getPluginsWithTools(ctx, &vo.PluginEntity{ + PluginID: req.PluginEntity.PluginID, + PluginVersion: req.PluginEntity.PluginVersion, + }, maps.Keys(req.ToolsInvokableInfo), isDraft) + if err != nil { + return nil, err + } + + result := map[int64]crossplugin.InvokableTool{} + for _, tf := range toolsInfo { + tl := &pluginInvokeTool{ + pluginEntity: vo.PluginEntity{ + PluginID: pInfo.ID, + PluginVersion: pInfo.Version, + }, + toolInfo: tf, + IsDraft: isDraft, + } + + if r, ok := req.ToolsInvokableInfo[tf.ID]; ok && (r.RequestAPIParametersConfig != nil && r.ResponseAPIParametersConfig != nil) { + reqPluginCommonAPIParameters := slices.Transform(r.RequestAPIParametersConfig, toPluginCommonAPIParameter) + respPluginCommonAPIParameters := slices.Transform(r.ResponseAPIParametersConfig, toPluginCommonAPIParameter) + + tl.toolOperation, err = api.APIParamsToOpenapiOperation(reqPluginCommonAPIParameters, respPluginCommonAPIParameters) + if err != nil { + return nil, err + } + + tl.toolOperation.OperationID = tf.Operation.OperationID + tl.toolOperation.Summary = tf.Operation.Summary + } + + result[tf.ID] = tl + } + return result, nil +} + +type pluginInvokeTool struct { + pluginEntity vo.PluginEntity + toolInfo *model.ToolInfo + toolOperation *openapi3.Operation + IsDraft bool +} + +func (p *pluginInvokeTool) Info(ctx context.Context) (_ *schema.ToolInfo, err error) { + defer func() { + if err != nil { + err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err) + } + }() + + var parameterInfo map[string]*schema.ParameterInfo + if p.toolOperation != nil { + parameterInfo, err = model.NewOpenapi3Operation(p.toolOperation).ToEinoSchemaParameterInfo(ctx) + } else { + parameterInfo, err = p.toolInfo.Operation.ToEinoSchemaParameterInfo(ctx) + } + + if err != nil { + return nil, err + } + + return &schema.ToolInfo{ + Name: p.toolInfo.GetName(), + Desc: p.toolInfo.GetDesc(), + ParamsOneOf: schema.NewParamsOneOfByParams(parameterInfo), + }, nil +} + +func (p *pluginInvokeTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflowModel.ExecuteConfig) (string, error) { + req := &model.ExecuteToolRequest{ + UserID: conv.Int64ToStr(cfg.Operator), + PluginID: p.pluginEntity.PluginID, + ToolID: p.toolInfo.ID, + ExecScene: consts.ExecSceneOfWorkflow, + ArgumentsInJson: argumentsInJSON, + ExecDraftTool: p.IsDraft, + } + execOpts := []model.ExecuteToolOpt{ + model.WithInvalidRespProcessStrategy(consts.InvalidResponseProcessStrategyOfReturnDefault), + } + + if p.pluginEntity.PluginVersion != nil { + execOpts = append(execOpts, model.WithToolVersion(*p.pluginEntity.PluginVersion)) + } + + if p.toolOperation != nil { + execOpts = append(execOpts, model.WithOpenapiOperation(model.NewOpenapi3Operation(p.toolOperation))) + } + + r, err := crossplugin.DefaultSVC().ExecuteTool(ctx, req, execOpts...) + if err != nil { + if extra, ok := compose.IsInterruptRerunError(err); ok { + pluginTIE, ok := extra.(*model.ToolInterruptEvent) + if !ok { + return "", vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra)) + } + + var eventType workflow3.EventType + switch pluginTIE.Event { + case consts.InterruptEventTypeOfToolNeedOAuth: + eventType = workflow3.EventType_WorkflowOauthPlugin + default: + return "", vo.WrapError(errno.ErrPluginAPIErr, + fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event)) + } + + id, eErr := workflow.GetRepository().GenID(ctx) + if eErr != nil { + return "", vo.WrapError(errno.ErrIDGenError, eErr) + } + + ie := &entity2.InterruptEvent{ + ID: id, + InterruptData: pluginTIE.ToolNeedOAuth.Message, + EventType: eventType, + } + + tie := &entity2.ToolInterruptEvent{ + ToolCallID: compose.GetToolCallID(ctx), + ToolName: p.toolInfo.GetName(), + InterruptEvent: ie, + } + + // temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt + _ = tie + interruptData := ie.InterruptData + return "", vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData)) + } + return "", err + } + return r.TrimmedResp, nil +} + +func toPluginCommonAPIParameter(parameter *workflow3.APIParameter) *common.APIParameter { + if parameter == nil { + return nil + } + p := &common.APIParameter{ + ID: parameter.ID, + Name: parameter.Name, + Desc: parameter.Desc, + Type: common.ParameterType(parameter.Type), + Location: common.ParameterLocation(parameter.Location), + IsRequired: parameter.IsRequired, + GlobalDefault: parameter.GlobalDefault, + GlobalDisable: parameter.GlobalDisable, + LocalDefault: parameter.LocalDefault, + LocalDisable: parameter.LocalDisable, + VariableRef: parameter.VariableRef, + } + if parameter.SubType != nil { + p.SubType = ptr.Of(common.ParameterType(*parameter.SubType)) + } + + if parameter.DefaultParamSource != nil { + p.DefaultParamSource = ptr.Of(common.DefaultParamSource(*parameter.DefaultParamSource)) + } + if parameter.AssistType != nil { + p.AssistType = ptr.Of(common.AssistParameterType(*parameter.AssistType)) + } + + if len(parameter.SubParameters) > 0 { + p.SubParameters = make([]*common.APIParameter, 0, len(parameter.SubParameters)) + for _, subParam := range parameter.SubParameters { + p.SubParameters = append(p.SubParameters, toPluginCommonAPIParameter(subParam)) + } + } + + return p +} + +func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParameter { + if parameter == nil { + return nil + } + p := &workflow3.APIParameter{ + ID: parameter.ID, + Name: parameter.Name, + Desc: parameter.Desc, + Type: workflow3.ParameterType(parameter.Type), + Location: workflow3.ParameterLocation(parameter.Location), + IsRequired: parameter.IsRequired, + GlobalDefault: parameter.GlobalDefault, + GlobalDisable: parameter.GlobalDisable, + LocalDefault: parameter.LocalDefault, + LocalDisable: parameter.LocalDisable, + VariableRef: parameter.VariableRef, + } + if parameter.SubType != nil { + p.SubType = ptr.Of(workflow3.ParameterType(*parameter.SubType)) + } + if parameter.DefaultParamSource != nil { + p.DefaultParamSource = ptr.Of(workflow3.DefaultParamSource(*parameter.DefaultParamSource)) + } + if parameter.AssistType != nil { + p.AssistType = ptr.Of(workflow3.AssistParameterType(*parameter.AssistType)) + } + + // Check if it's a specially wrapped array that needs unwrapping. + if parameter.Type == common.ParameterType_Array && len(parameter.SubParameters) == 1 && parameter.SubParameters[0].Name == "[Array Item]" { + arrayItem := parameter.SubParameters[0] + // The actual type of array elements is the type of the "[Array Item]". + p.SubType = ptr.Of(workflow3.ParameterType(arrayItem.Type)) + // If the array elements are objects, their sub-parameters (fields) are lifted up. + if arrayItem.Type == common.ParameterType_Object { + p.SubParameters = make([]*workflow3.APIParameter, 0, len(arrayItem.SubParameters)) + for _, subParam := range arrayItem.SubParameters { + p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam)) + } + } else { + p.SubParameters = make([]*workflow3.APIParameter, 0, 1) + p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(arrayItem)) + } + } else if len(parameter.SubParameters) > 0 { + p.SubParameters = make([]*workflow3.APIParameter, 0, len(parameter.SubParameters)) + for _, subParam := range parameter.SubParameters { + p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam)) + } + } + + return p +} diff --git a/backend/crossdomain/impl/plugin/plugin_test.go b/backend/domain/workflow/plugin/plugin_test.go similarity index 100% rename from backend/crossdomain/impl/plugin/plugin_test.go rename to backend/domain/workflow/plugin/plugin_test.go diff --git a/backend/domain/workflow/service/service_impl.go b/backend/domain/workflow/service/service_impl.go index 8b2d5b107..5522e37f1 100644 --- a/backend/domain/workflow/service/service_impl.go +++ b/backend/domain/workflow/service/service_impl.go @@ -20,20 +20,17 @@ import ( "context" "errors" "fmt" + "strconv" + einoCompose "github.com/cloudwego/eino/compose" "github.com/spf13/cast" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" "gorm.io/gorm" - "strconv" - - einoCompose "github.com/cloudwego/eino/compose" - workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/workflow" "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" @@ -822,7 +819,7 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con return nil, err } - relatedPlugins := make(map[int64]*plugin.PluginEntity, len(config.PluginIDs)) + relatedPlugins := make(map[int64]*vo.PluginEntity, len(config.PluginIDs)) relatedWorkflow := make(map[int64]entity.IDVersionPair, len(allWorkflowsInApp)) for _, wf := range allWorkflowsInApp { @@ -833,7 +830,7 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con } for _, id := range config.PluginIDs { - relatedPlugins[id] = &plugin.PluginEntity{ + relatedPlugins[id] = &vo.PluginEntity{ PluginID: id, PluginVersion: &config.Version, } diff --git a/backend/internal/mock/domain/plugin/interface.go b/backend/internal/mock/domain/plugin/interface.go index 6876b8c4e..ca257666c 100644 --- a/backend/internal/mock/domain/plugin/interface.go +++ b/backend/internal/mock/domain/plugin/interface.go @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" - dto "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" + dto0 "github.com/coze-dev/coze-studio/backend/domain/plugin/dto" entity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" gomock "go.uber.org/mock/gomock" ) @@ -72,10 +72,10 @@ func (mr *MockPluginServiceMockRecorder) CheckPluginToolsDebugStatus(ctx, plugin } // ConvertToOpenapi3Doc mocks base method. -func (m *MockPluginService) ConvertToOpenapi3Doc(ctx context.Context, req *dto.ConvertToOpenapi3DocRequest) *dto.ConvertToOpenapi3DocResponse { +func (m *MockPluginService) ConvertToOpenapi3Doc(ctx context.Context, req *dto0.ConvertToOpenapi3DocRequest) *dto0.ConvertToOpenapi3DocResponse { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConvertToOpenapi3Doc", ctx, req) - ret0, _ := ret[0].(*dto.ConvertToOpenapi3DocResponse) + ret0, _ := ret[0].(*dto0.ConvertToOpenapi3DocResponse) return ret0 } @@ -86,10 +86,10 @@ func (mr *MockPluginServiceMockRecorder) ConvertToOpenapi3Doc(ctx, req any) *gom } // CopyPlugin mocks base method. -func (m *MockPluginService) CopyPlugin(ctx context.Context, req *dto.CopyPluginRequest) (*dto.CopyPluginResponse, error) { +func (m *MockPluginService) CopyPlugin(ctx context.Context, req *dto0.CopyPluginRequest) (*dto0.CopyPluginResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CopyPlugin", ctx, req) - ret0, _ := ret[0].(*dto.CopyPluginResponse) + ret0, _ := ret[0].(*dto0.CopyPluginResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -101,7 +101,7 @@ func (mr *MockPluginServiceMockRecorder) CopyPlugin(ctx, req any) *gomock.Call { } // CreateDraftPlugin mocks base method. -func (m *MockPluginService) CreateDraftPlugin(ctx context.Context, req *dto.CreateDraftPluginRequest) (int64, error) { +func (m *MockPluginService) CreateDraftPlugin(ctx context.Context, req *dto0.CreateDraftPluginRequest) (int64, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateDraftPlugin", ctx, req) ret0, _ := ret[0].(int64) @@ -116,10 +116,10 @@ func (mr *MockPluginServiceMockRecorder) CreateDraftPlugin(ctx, req any) *gomock } // CreateDraftPluginWithCode mocks base method. -func (m *MockPluginService) CreateDraftPluginWithCode(ctx context.Context, req *dto.CreateDraftPluginWithCodeRequest) (*dto.CreateDraftPluginWithCodeResponse, error) { +func (m *MockPluginService) CreateDraftPluginWithCode(ctx context.Context, req *dto0.CreateDraftPluginWithCodeRequest) (*dto0.CreateDraftPluginWithCodeResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateDraftPluginWithCode", ctx, req) - ret0, _ := ret[0].(*dto.CreateDraftPluginWithCodeResponse) + ret0, _ := ret[0].(*dto0.CreateDraftPluginWithCodeResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -131,10 +131,10 @@ func (mr *MockPluginServiceMockRecorder) CreateDraftPluginWithCode(ctx, req any) } // CreateDraftToolsWithCode mocks base method. -func (m *MockPluginService) CreateDraftToolsWithCode(ctx context.Context, req *dto.CreateDraftToolsWithCodeRequest) (*dto.CreateDraftToolsWithCodeResponse, error) { +func (m *MockPluginService) CreateDraftToolsWithCode(ctx context.Context, req *dto0.CreateDraftToolsWithCodeRequest) (*dto0.CreateDraftToolsWithCodeResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateDraftToolsWithCode", ctx, req) - ret0, _ := ret[0].(*dto.CreateDraftToolsWithCodeResponse) + ret0, _ := ret[0].(*dto0.CreateDraftToolsWithCodeResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -189,14 +189,14 @@ func (mr *MockPluginServiceMockRecorder) DuplicateDraftAgentTools(ctx, fromAgent } // ExecuteTool mocks base method. -func (m *MockPluginService) ExecuteTool(ctx context.Context, req *plugin.ExecuteToolRequest, opts ...entity.ExecuteToolOpt) (*plugin.ExecuteToolResponse, error) { +func (m *MockPluginService) ExecuteTool(ctx context.Context, req *model.ExecuteToolRequest, opts ...model.ExecuteToolOpt) (*model.ExecuteToolResponse, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ExecuteTool", varargs...) - ret0, _ := ret[0].(*plugin.ExecuteToolResponse) + ret0, _ := ret[0].(*model.ExecuteToolResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -224,7 +224,7 @@ func (mr *MockPluginServiceMockRecorder) GetAPPAllPlugins(ctx, appID any) *gomoc } // GetAccessToken mocks base method. -func (m *MockPluginService) GetAccessToken(ctx context.Context, oa *entity.OAuthInfo) (string, error) { +func (m *MockPluginService) GetAccessToken(ctx context.Context, oa *dto0.OAuthInfo) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAccessToken", ctx, oa) ret0, _ := ret[0].(string) @@ -239,10 +239,10 @@ func (mr *MockPluginServiceMockRecorder) GetAccessToken(ctx, oa any) *gomock.Cal } // GetAgentPluginsOAuthStatus mocks base method. -func (m *MockPluginService) GetAgentPluginsOAuthStatus(ctx context.Context, userID, agentID int64) ([]*dto.AgentPluginOAuthStatus, error) { +func (m *MockPluginService) GetAgentPluginsOAuthStatus(ctx context.Context, userID, agentID int64) ([]*dto0.AgentPluginOAuthStatus, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAgentPluginsOAuthStatus", ctx, userID, agentID) - ret0, _ := ret[0].([]*dto.AgentPluginOAuthStatus) + ret0, _ := ret[0].([]*dto0.AgentPluginOAuthStatus) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -284,10 +284,10 @@ func (mr *MockPluginServiceMockRecorder) GetDraftPlugin(ctx, pluginID any) *gomo } // GetOAuthStatus mocks base method. -func (m *MockPluginService) GetOAuthStatus(ctx context.Context, userID, pluginID int64) (*dto.GetOAuthStatusResponse, error) { +func (m *MockPluginService) GetOAuthStatus(ctx context.Context, userID, pluginID int64) (*dto0.GetOAuthStatusResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetOAuthStatus", ctx, userID, pluginID) - ret0, _ := ret[0].(*dto.GetOAuthStatusResponse) + ret0, _ := ret[0].(*dto0.GetOAuthStatusResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -359,7 +359,7 @@ func (mr *MockPluginServiceMockRecorder) GetPluginProductAllTools(ctx, pluginID } // ListCustomOnlinePlugins mocks base method. -func (m *MockPluginService) ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) ([]*entity.PluginInfo, int64, error) { +func (m *MockPluginService) ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo dto0.PageInfo) ([]*entity.PluginInfo, int64, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListCustomOnlinePlugins", ctx, spaceID, pageInfo) ret0, _ := ret[0].([]*entity.PluginInfo) @@ -375,10 +375,10 @@ func (mr *MockPluginServiceMockRecorder) ListCustomOnlinePlugins(ctx, spaceID, p } // ListDraftPlugins mocks base method. -func (m *MockPluginService) ListDraftPlugins(ctx context.Context, req *dto.ListDraftPluginsRequest) (*dto.ListDraftPluginsResponse, error) { +func (m *MockPluginService) ListDraftPlugins(ctx context.Context, req *dto0.ListDraftPluginsRequest) (*dto0.ListDraftPluginsResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListDraftPlugins", ctx, req) - ret0, _ := ret[0].(*dto.ListDraftPluginsResponse) + ret0, _ := ret[0].(*dto0.ListDraftPluginsResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -390,10 +390,10 @@ func (mr *MockPluginServiceMockRecorder) ListDraftPlugins(ctx, req any) *gomock. } // ListPluginProducts mocks base method. -func (m *MockPluginService) ListPluginProducts(ctx context.Context, req *dto.ListPluginProductsRequest) (*dto.ListPluginProductsResponse, error) { +func (m *MockPluginService) ListPluginProducts(ctx context.Context, req *dto0.ListPluginProductsRequest) (*dto0.ListPluginProductsResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListPluginProducts", ctx, req) - ret0, _ := ret[0].(*dto.ListPluginProductsResponse) + ret0, _ := ret[0].(*dto0.ListPluginProductsResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -405,7 +405,7 @@ func (mr *MockPluginServiceMockRecorder) ListPluginProducts(ctx, req any) *gomoc } // MGetAgentTools mocks base method. -func (m *MockPluginService) MGetAgentTools(ctx context.Context, req *plugin.MGetAgentToolsRequest) ([]*entity.ToolInfo, error) { +func (m *MockPluginService) MGetAgentTools(ctx context.Context, req *model.MGetAgentToolsRequest) ([]*entity.ToolInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MGetAgentTools", ctx, req) ret0, _ := ret[0].([]*entity.ToolInfo) @@ -480,10 +480,10 @@ func (mr *MockPluginServiceMockRecorder) MGetOnlineTools(ctx, toolIDs any) *gomo } // MGetPluginLatestVersion mocks base method. -func (m *MockPluginService) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (*plugin.MGetPluginLatestVersionResponse, error) { +func (m *MockPluginService) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (*model.MGetPluginLatestVersionResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MGetPluginLatestVersion", ctx, pluginIDs) - ret0, _ := ret[0].(*plugin.MGetPluginLatestVersionResponse) + ret0, _ := ret[0].(*model.MGetPluginLatestVersionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -495,7 +495,7 @@ func (mr *MockPluginServiceMockRecorder) MGetPluginLatestVersion(ctx, pluginIDs } // MGetVersionPlugins mocks base method. -func (m *MockPluginService) MGetVersionPlugins(ctx context.Context, versionPlugins []entity.VersionPlugin) ([]*entity.PluginInfo, error) { +func (m *MockPluginService) MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) ([]*entity.PluginInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MGetVersionPlugins", ctx, versionPlugins) ret0, _ := ret[0].([]*entity.PluginInfo) @@ -510,7 +510,7 @@ func (mr *MockPluginServiceMockRecorder) MGetVersionPlugins(ctx, versionPlugins } // MGetVersionTools mocks base method. -func (m *MockPluginService) MGetVersionTools(ctx context.Context, versionTools []entity.VersionTool) ([]*entity.ToolInfo, error) { +func (m *MockPluginService) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) ([]*entity.ToolInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MGetVersionTools", ctx, versionTools) ret0, _ := ret[0].([]*entity.ToolInfo) @@ -540,7 +540,7 @@ func (mr *MockPluginServiceMockRecorder) MoveAPPPluginToLibrary(ctx, pluginID an } // OAuthCode mocks base method. -func (m *MockPluginService) OAuthCode(ctx context.Context, code string, state *entity.OAuthState) error { +func (m *MockPluginService) OAuthCode(ctx context.Context, code string, state *dto0.OAuthState) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OAuthCode", ctx, code, state) ret0, _ := ret[0].(error) @@ -554,10 +554,10 @@ func (mr *MockPluginServiceMockRecorder) OAuthCode(ctx, code, state any) *gomock } // PublishAPPPlugins mocks base method. -func (m *MockPluginService) PublishAPPPlugins(ctx context.Context, req *plugin.PublishAPPPluginsRequest) (*plugin.PublishAPPPluginsResponse, error) { +func (m *MockPluginService) PublishAPPPlugins(ctx context.Context, req *model.PublishAPPPluginsRequest) (*model.PublishAPPPluginsResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PublishAPPPlugins", ctx, req) - ret0, _ := ret[0].(*plugin.PublishAPPPluginsResponse) + ret0, _ := ret[0].(*model.PublishAPPPluginsResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -583,7 +583,7 @@ func (mr *MockPluginServiceMockRecorder) PublishAgentTools(ctx, agentID, agentVe } // PublishPlugin mocks base method. -func (m *MockPluginService) PublishPlugin(ctx context.Context, req *plugin.PublishPluginRequest) error { +func (m *MockPluginService) PublishPlugin(ctx context.Context, req *model.PublishPluginRequest) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PublishPlugin", ctx, req) ret0, _ := ret[0].(error) @@ -597,7 +597,7 @@ func (mr *MockPluginServiceMockRecorder) PublishPlugin(ctx, req any) *gomock.Cal } // RevokeAccessToken mocks base method. -func (m *MockPluginService) RevokeAccessToken(ctx context.Context, meta *entity.AuthorizationCodeMeta) error { +func (m *MockPluginService) RevokeAccessToken(ctx context.Context, meta *dto0.AuthorizationCodeMeta) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RevokeAccessToken", ctx, meta) ret0, _ := ret[0].(error) @@ -611,7 +611,7 @@ func (mr *MockPluginServiceMockRecorder) RevokeAccessToken(ctx, meta any) *gomoc } // UpdateBotDefaultParams mocks base method. -func (m *MockPluginService) UpdateBotDefaultParams(ctx context.Context, req *dto.UpdateBotDefaultParamsRequest) error { +func (m *MockPluginService) UpdateBotDefaultParams(ctx context.Context, req *dto0.UpdateBotDefaultParamsRequest) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateBotDefaultParams", ctx, req) ret0, _ := ret[0].(error) @@ -625,21 +625,21 @@ func (mr *MockPluginServiceMockRecorder) UpdateBotDefaultParams(ctx, req any) *g } // UpdateDraftPlugin mocks base method. -func (m *MockPluginService) UpdateDraftPlugin(ctx context.Context, arg1 *dto.UpdateDraftPluginRequest) error { +func (m *MockPluginService) UpdateDraftPlugin(ctx context.Context, plugin *dto0.UpdateDraftPluginRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateDraftPlugin", ctx, arg1) + ret := m.ctrl.Call(m, "UpdateDraftPlugin", ctx, plugin) ret0, _ := ret[0].(error) return ret0 } // UpdateDraftPlugin indicates an expected call of UpdateDraftPlugin. -func (mr *MockPluginServiceMockRecorder) UpdateDraftPlugin(ctx, arg1 any) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) UpdateDraftPlugin(ctx, plugin any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDraftPlugin", reflect.TypeOf((*MockPluginService)(nil).UpdateDraftPlugin), ctx, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDraftPlugin", reflect.TypeOf((*MockPluginService)(nil).UpdateDraftPlugin), ctx, plugin) } // UpdateDraftPluginWithCode mocks base method. -func (m *MockPluginService) UpdateDraftPluginWithCode(ctx context.Context, req *dto.UpdateDraftPluginWithCodeRequest) error { +func (m *MockPluginService) UpdateDraftPluginWithCode(ctx context.Context, req *dto0.UpdateDraftPluginWithCodeRequest) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateDraftPluginWithCode", ctx, req) ret0, _ := ret[0].(error) @@ -653,7 +653,7 @@ func (mr *MockPluginServiceMockRecorder) UpdateDraftPluginWithCode(ctx, req any) } // UpdateDraftTool mocks base method. -func (m *MockPluginService) UpdateDraftTool(ctx context.Context, req *dto.UpdateDraftToolRequest) error { +func (m *MockPluginService) UpdateDraftTool(ctx context.Context, req *dto0.UpdateDraftToolRequest) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateDraftTool", ctx, req) ret0, _ := ret[0].(error) diff --git a/backend/types/ddl/gen_orm_query.go b/backend/types/ddl/gen_orm_query.go index 8403a4288..b16c93ae2 100644 --- a/backend/types/ddl/gen_orm_query.go +++ b/backend/types/ddl/gen_orm_query.go @@ -34,7 +34,7 @@ import ( "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/playground" - plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/dto" + plugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" appEntity "github.com/coze-dev/coze-studio/backend/domain/app/entity" variableEntity "github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity" )