Compare commits
3 Commits
chore/setu
...
v0.3.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
| 6e02eed1c8 | |||
| baebc5b148 | |||
| 57ec89d4f3 |
1
.github/workflows/ci.yml
vendored
1
.github/workflows/ci.yml
vendored
@ -15,7 +15,6 @@ permissions:
|
||||
contents: read
|
||||
actions: read
|
||||
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
strategy:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -7,7 +7,6 @@
|
||||
|
||||
.env
|
||||
.env.debug
|
||||
.env.oceanbase
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
26
Makefile
26
Makefile
@ -9,12 +9,9 @@ DUMP_DB_SCRIPT := $(SCRIPTS_DIR)/setup/db_migrate_dump.sh
|
||||
SETUP_DOCKER_SCRIPT := $(SCRIPTS_DIR)/setup/docker.sh
|
||||
SETUP_PYTHON_SCRIPT := $(SCRIPTS_DIR)/setup/python.sh
|
||||
COMPOSE_FILE := docker/docker-compose-debug.yml
|
||||
OCEANBASE_COMPOSE_FILE := docker/docker-compose-oceanbase.yml
|
||||
OCEANBASE_DEBUG_COMPOSE_FILE := docker/docker-compose-oceanbase_debug.yml
|
||||
MYSQL_SCHEMA := ./docker/volumes/mysql/schema.sql
|
||||
MYSQL_INIT_SQL := ./docker/volumes/mysql/sql_init.sql
|
||||
ENV_FILE := ./docker/.env.debug
|
||||
OCEANBASE_ENV_FILE := ./docker/.env.debug
|
||||
STATIC_DIR := ./bin/resources/static
|
||||
ES_INDEX_SCHEMA := ./docker/volumes/elasticsearch/es_index_schema
|
||||
ES_SETUP_SCRIPT := ./docker/volumes/elasticsearch/setup_es.sh
|
||||
@ -39,7 +36,6 @@ server: env
|
||||
@echo "Building and run server..."
|
||||
@APP_ENV=debug bash $(BUILD_SERVER_SCRIPT) -start
|
||||
|
||||
|
||||
build_server:
|
||||
@echo "Building server..."
|
||||
@bash $(BUILD_SERVER_SCRIPT)
|
||||
@ -104,23 +100,6 @@ setup_es_index:
|
||||
@. $(ENV_FILE); \
|
||||
bash $(ES_SETUP_SCRIPT) --index-dir $(ES_INDEX_SCHEMA) --docker-host false --es-address "$$ES_ADDR"
|
||||
|
||||
oceanbase_env:
|
||||
@bash scripts/setup/oceanbase_env.sh debug
|
||||
|
||||
oceanbase_debug: oceanbase_env oceanbase_middleware_debug python oceanbase_server_debug
|
||||
|
||||
oceanbase_middleware_debug:
|
||||
@echo "Starting OceanBase debug middleware..."
|
||||
@docker compose -f $(OCEANBASE_DEBUG_COMPOSE_FILE) --env-file $(ENV_FILE) --profile middleware up -d --wait
|
||||
|
||||
oceanbase_server_debug:
|
||||
@if [ ! -d "$(STATIC_DIR)" ]; then \
|
||||
echo "Static directory '$(STATIC_DIR)' not found, building frontend..."; \
|
||||
$(MAKE) fe; \
|
||||
fi
|
||||
@echo "Building and run OceanBase debug server..."
|
||||
@APP_ENV=debug bash $(BUILD_SERVER_SCRIPT) -start
|
||||
|
||||
help:
|
||||
@echo "Usage: make [target]"
|
||||
@echo ""
|
||||
@ -142,9 +121,4 @@ help:
|
||||
@echo " python - Setup python environment."
|
||||
@echo " atlas-hash - Rehash atlas migration files."
|
||||
@echo " setup_es_index - Setup elasticsearch index."
|
||||
@echo ""
|
||||
@echo "OceanBase Commands:"
|
||||
@echo " oceanbase_env - Setup OceanBase environment file (like 'env')."
|
||||
@echo " oceanbase_debug - Start OceanBase debug environment (like 'debug')."
|
||||
@echo ""
|
||||
@echo " help - Show this help message."
|
||||
|
||||
@ -43,7 +43,6 @@ import (
|
||||
"github.com/cloudwego/hertz/pkg/common/ut"
|
||||
"github.com/cloudwego/hertz/pkg/protocol"
|
||||
"github.com/cloudwego/hertz/pkg/protocol/sse"
|
||||
message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -53,7 +52,6 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
plugin2 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
pluginmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
@ -67,16 +65,10 @@ import (
|
||||
appplugin "github.com/coze-dev/coze-studio/backend/application/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/application/user"
|
||||
appworkflow "github.com/coze-dev/coze-studio/backend/application/workflow"
|
||||
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun/agentrunmock"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation/conversationmock"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge/knowledgemock"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/message/messagemock"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
mockmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr/modelmock"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
@ -84,9 +76,6 @@ import (
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
|
||||
agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
conventity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
msgentity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
entity4 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
|
||||
entity3 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
@ -111,7 +100,6 @@ import (
|
||||
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/testutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
@ -140,9 +128,6 @@ type wfTestRunner struct {
|
||||
knowledge *knowledgemock.MockKnowledge
|
||||
database *databasemock.MockDatabase
|
||||
pluginSrv *pluginmock.MockPluginService
|
||||
conversation *conversationmock.MockConversation
|
||||
message *messagemock.MockMessage
|
||||
agentRun *agentrunmock.MockAgentRun
|
||||
internalModel *testutil.UTChatModel
|
||||
publishPatcher *mockey.Mocker
|
||||
ctx context.Context
|
||||
@ -150,33 +135,29 @@ type wfTestRunner struct {
|
||||
}
|
||||
|
||||
var req2URL = map[reflect.Type]string{
|
||||
reflect.TypeOf(&workflow.NodeTemplateListRequest{}): "/api/workflow_api/node_template_list",
|
||||
reflect.TypeOf(&workflow.CreateWorkflowRequest{}): "/api/workflow_api/create",
|
||||
reflect.TypeOf(&workflow.SaveWorkflowRequest{}): "/api/workflow_api/save",
|
||||
reflect.TypeOf(&workflow.DeleteWorkflowRequest{}): "/api/workflow_api/delete",
|
||||
reflect.TypeOf(&workflow.GetCanvasInfoRequest{}): "/api/workflow_api/canvas",
|
||||
reflect.TypeOf(&workflow.WorkFlowTestRunRequest{}): "/api/workflow_api/test_run",
|
||||
reflect.TypeOf(&workflow.CancelWorkFlowRequest{}): "/api/workflow_api/cancel",
|
||||
reflect.TypeOf(&workflow.PublishWorkflowRequest{}): "/api/workflow_api/publish",
|
||||
reflect.TypeOf(&workflow.OpenAPIRunFlowRequest{}): "/v1/workflow/run",
|
||||
reflect.TypeOf(&workflow.ValidateTreeRequest{}): "/api/workflow_api/validate_tree",
|
||||
reflect.TypeOf(&workflow.WorkflowTestResumeRequest{}): "/api/workflow_api/test_resume",
|
||||
reflect.TypeOf(&workflow.WorkflowNodeDebugV2Request{}): "/api/workflow_api/nodeDebug",
|
||||
reflect.TypeOf(&workflow.QueryWorkflowNodeTypeRequest{}): "/api/workflow_api/node_type",
|
||||
reflect.TypeOf(&workflow.GetWorkFlowListRequest{}): "/api/workflow_api/workflow_list",
|
||||
reflect.TypeOf(&workflow.UpdateWorkflowMetaRequest{}): "/api/workflow_api/update_meta",
|
||||
reflect.TypeOf(&workflow.GetWorkflowDetailRequest{}): "/api/workflow_api/workflow_detail",
|
||||
reflect.TypeOf(&workflow.GetWorkflowDetailInfoRequest{}): "/api/workflow_api/workflow_detail_info",
|
||||
reflect.TypeOf(&workflow.GetLLMNodeFCSettingDetailRequest{}): "/api/workflow_api/llm_fc_setting_detail",
|
||||
reflect.TypeOf(&workflow.GetLLMNodeFCSettingsMergedRequest{}): "/api/workflow_api/llm_fc_setting_merged",
|
||||
reflect.TypeOf(&workflow.CopyWorkflowRequest{}): "/api/workflow_api/copy",
|
||||
reflect.TypeOf(&workflow.BatchDeleteWorkflowRequest{}): "/api/workflow_api/batch_delete",
|
||||
reflect.TypeOf(&workflow.GetHistorySchemaRequest{}): "/api/workflow_api/history_schema",
|
||||
reflect.TypeOf(&workflow.GetWorkflowReferencesRequest{}): "/api/workflow_api/workflow_references",
|
||||
reflect.TypeOf(&workflow.CreateProjectConversationDefRequest{}): "/api/workflow_api/project_conversation/create",
|
||||
reflect.TypeOf(&workflow.DeleteProjectConversationDefRequest{}): "/api/workflow_api/project_conversation/delete",
|
||||
reflect.TypeOf(&workflow.UpdateProjectConversationDefRequest{}): "/api/workflow_api/project_conversation/update",
|
||||
reflect.TypeOf(&workflow.ListProjectConversationRequest{}): "/api/workflow_api/project_conversation/list",
|
||||
reflect.TypeOf(&workflow.NodeTemplateListRequest{}): "/api/workflow_api/node_template_list",
|
||||
reflect.TypeOf(&workflow.CreateWorkflowRequest{}): "/api/workflow_api/create",
|
||||
reflect.TypeOf(&workflow.SaveWorkflowRequest{}): "/api/workflow_api/save",
|
||||
reflect.TypeOf(&workflow.DeleteWorkflowRequest{}): "/api/workflow_api/delete",
|
||||
reflect.TypeOf(&workflow.GetCanvasInfoRequest{}): "/api/workflow_api/canvas",
|
||||
reflect.TypeOf(&workflow.WorkFlowTestRunRequest{}): "/api/workflow_api/test_run",
|
||||
reflect.TypeOf(&workflow.CancelWorkFlowRequest{}): "/api/workflow_api/cancel",
|
||||
reflect.TypeOf(&workflow.PublishWorkflowRequest{}): "/api/workflow_api/publish",
|
||||
reflect.TypeOf(&workflow.OpenAPIRunFlowRequest{}): "/v1/workflow/run",
|
||||
reflect.TypeOf(&workflow.ValidateTreeRequest{}): "/api/workflow_api/validate_tree",
|
||||
reflect.TypeOf(&workflow.WorkflowTestResumeRequest{}): "/api/workflow_api/test_resume",
|
||||
reflect.TypeOf(&workflow.WorkflowNodeDebugV2Request{}): "/api/workflow_api/nodeDebug",
|
||||
reflect.TypeOf(&workflow.QueryWorkflowNodeTypeRequest{}): "/api/workflow_api/node_type",
|
||||
reflect.TypeOf(&workflow.GetWorkFlowListRequest{}): "/api/workflow_api/workflow_list",
|
||||
reflect.TypeOf(&workflow.UpdateWorkflowMetaRequest{}): "/api/workflow_api/update_meta",
|
||||
reflect.TypeOf(&workflow.GetWorkflowDetailRequest{}): "/api/workflow_api/workflow_detail",
|
||||
reflect.TypeOf(&workflow.GetWorkflowDetailInfoRequest{}): "/api/workflow_api/workflow_detail_info",
|
||||
reflect.TypeOf(&workflow.GetLLMNodeFCSettingDetailRequest{}): "/api/workflow_api/llm_fc_setting_detail",
|
||||
reflect.TypeOf(&workflow.GetLLMNodeFCSettingsMergedRequest{}): "/api/workflow_api/llm_fc_setting_merged",
|
||||
reflect.TypeOf(&workflow.CopyWorkflowRequest{}): "/api/workflow_api/copy",
|
||||
reflect.TypeOf(&workflow.BatchDeleteWorkflowRequest{}): "/api/workflow_api/batch_delete",
|
||||
reflect.TypeOf(&workflow.GetHistorySchemaRequest{}): "/api/workflow_api/history_schema",
|
||||
reflect.TypeOf(&workflow.GetWorkflowReferencesRequest{}): "/api/workflow_api/workflow_references",
|
||||
}
|
||||
|
||||
func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
@ -217,10 +198,6 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
h.GET("/v1/workflow/get_run_history", OpenAPIGetWorkflowRunHistory)
|
||||
h.POST("/api/workflow_api/history_schema", GetHistorySchema)
|
||||
h.POST("/api/workflow_api/workflow_references", GetWorkflowReferences)
|
||||
h.POST("/api/workflow_api/project_conversation/create", CreateProjectConversationDef)
|
||||
h.POST("/api/workflow_api/project_conversation/delete", DeleteProjectConversationDef)
|
||||
h.POST("/api/workflow_api/project_conversation/update", UpdateProjectConversationDef)
|
||||
h.POST("/api/workflow_api/project_conversation/list", ListProjectConversationDef)
|
||||
|
||||
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
|
||||
mockIDGen := mock.NewMockIDGenerator(ctrl)
|
||||
@ -325,13 +302,6 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
mockPluginSrv := pluginmock.NewMockPluginService(ctrl)
|
||||
crossplugin.SetDefaultSVC(mockPluginSrv)
|
||||
|
||||
mockConversation := conversationmock.NewMockConversation(ctrl)
|
||||
crossconversation.SetDefaultSVC(mockConversation)
|
||||
mockMessage := messagemock.NewMockMessage(ctrl)
|
||||
crossmessage.SetDefaultSVC(mockMessage)
|
||||
mockAgentRun := agentrunmock.NewMockAgentRun(ctrl)
|
||||
crossagentrun.SetDefaultSVC(mockAgentRun)
|
||||
|
||||
mockey.Mock((*user.UserApplicationService).MGetUserBasicInfo).Return(&playground.MGetUserBasicInfoResponse{
|
||||
UserBasicInfoMap: make(map[string]*playground.UserBasicInfo),
|
||||
}, nil).Build()
|
||||
@ -367,9 +337,6 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
closeFn: f,
|
||||
pluginSrv: mockPluginSrv,
|
||||
publishPatcher: publishPatcher,
|
||||
conversation: mockConversation,
|
||||
message: mockMessage,
|
||||
agentRun: mockAgentRun,
|
||||
}
|
||||
}
|
||||
|
||||
@ -646,42 +613,12 @@ func mustMarshalToString(t *testing.T, m any) string {
|
||||
return b
|
||||
}
|
||||
|
||||
type runOption struct {
|
||||
ProjectID *int64
|
||||
BotID *int64
|
||||
}
|
||||
type RunOptionFun func(options *runOption)
|
||||
|
||||
func withRunProjectID(pID int64) RunOptionFun {
|
||||
return func(options *runOption) {
|
||||
options.ProjectID = &pID
|
||||
}
|
||||
}
|
||||
|
||||
func withRunBotID(bID int64) RunOptionFun {
|
||||
return func(options *runOption) {
|
||||
options.BotID = &bID
|
||||
}
|
||||
}
|
||||
|
||||
func (r *wfTestRunner) testRun(id string, input map[string]string, opts ...RunOptionFun) string {
|
||||
opt := &runOption{}
|
||||
for _, o := range opts {
|
||||
o(opt)
|
||||
}
|
||||
func (r *wfTestRunner) testRun(id string, input map[string]string) string {
|
||||
testRunReq := &workflow.WorkFlowTestRunRequest{
|
||||
WorkflowID: id,
|
||||
Input: input,
|
||||
}
|
||||
|
||||
if opt.ProjectID != nil {
|
||||
testRunReq.ProjectID = ptr.Of(strconv.FormatInt(ptr.From(opt.ProjectID), 10))
|
||||
}
|
||||
|
||||
if opt.BotID != nil {
|
||||
testRunReq.BotID = ptr.Of(strconv.FormatInt(ptr.From(opt.BotID), 10))
|
||||
}
|
||||
|
||||
testRunResponse := post[workflow.WorkFlowTestRunResponse](r, testRunReq)
|
||||
return testRunResponse.Data.ExecuteID
|
||||
}
|
||||
@ -829,26 +766,13 @@ func (r *wfTestRunner) openapiAsyncRun(id string, input any) string {
|
||||
return runResp.GetExecuteID()
|
||||
}
|
||||
|
||||
func (r *wfTestRunner) openapiSyncRun(id string, input any, opts ...RunOptionFun) (map[string]any, string) {
|
||||
opt := &runOption{}
|
||||
for _, o := range opts {
|
||||
o(opt)
|
||||
}
|
||||
|
||||
func (r *wfTestRunner) openapiSyncRun(id string, input any) (map[string]any, string) {
|
||||
runReq := &workflow.OpenAPIRunFlowRequest{
|
||||
WorkflowID: id,
|
||||
Parameters: ptr.Of(mustMarshalToString(r.t, input)),
|
||||
IsAsync: ptr.Of(false),
|
||||
}
|
||||
|
||||
if opt.ProjectID != nil {
|
||||
runReq.ProjectID = ptr.Of(strconv.FormatInt(ptr.From(opt.ProjectID), 10))
|
||||
}
|
||||
|
||||
if opt.BotID != nil {
|
||||
runReq.BotID = ptr.Of(strconv.FormatInt(ptr.From(opt.BotID), 10))
|
||||
}
|
||||
|
||||
runResp := post[workflow.OpenAPIRunFlowResponse](r, runReq)
|
||||
output := runResp.GetData()
|
||||
var m map[string]any
|
||||
@ -4834,380 +4758,3 @@ func TestHttpImplicitDependencies(t *testing.T) {
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessageNodes(t *testing.T) {
|
||||
mockey.PatchConvey("create message in dynamic conversation", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
idStr := r.load("message/create_message.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
|
||||
ret, _ := r.openapiSyncRun(idStr, map[string]string{
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}, withRunProjectID(123))
|
||||
assert.Equal(t, true, ret["output"])
|
||||
assert.Equal(t, strconv.FormatInt(mID, 10), ret["mID"])
|
||||
})
|
||||
|
||||
mockey.PatchConvey("create message in static conversation", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
createReq := &workflow.CreateProjectConversationDefRequest{
|
||||
ProjectID: "123",
|
||||
ConversationName: "name" + strconv.FormatInt(cID, 10),
|
||||
SpaceID: "123",
|
||||
}
|
||||
post[workflow.CreateProjectConversationDefResponse](r, createReq)
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
idStr := r.load("message/create_message.json")
|
||||
testInput := map[string]string{
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}
|
||||
exeID := r.testRun(idStr, testInput, withRunProjectID(123))
|
||||
e := r.getProcess(idStr, exeID)
|
||||
e.assertSuccess()
|
||||
output := e.output
|
||||
var result map[string]any
|
||||
err := sonic.Unmarshal([]byte(output), &result)
|
||||
assert.NoError(t, err, "Failed to unmarshal output JSON")
|
||||
|
||||
assert.Equal(t, true, result["output"])
|
||||
assert.Equal(t, strconv.FormatInt(mID, 10), result["mID"])
|
||||
})
|
||||
|
||||
mockey.PatchConvey("create message in Bot scene", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
cID := time.Now().Unix()
|
||||
idStr := r.load("message/create_message_in_agent.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
|
||||
testInput := map[string]string{
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}
|
||||
exeID := r.testRun(idStr, testInput, withRunBotID(123))
|
||||
e := r.getProcess(idStr, exeID)
|
||||
assert.Equal(t, e.status, workflow.WorkflowExeStatus_Fail)
|
||||
assert.Contains(t, e.reason, "Only default conversation allow in agent scenario")
|
||||
})
|
||||
|
||||
mockey.PatchConvey("create message without binding app nor bot", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
idStr := r.load("message/create_message_in_agent.json")
|
||||
|
||||
testInput := map[string]string{
|
||||
"CONVERSATION_NAME": "Default",
|
||||
}
|
||||
exeID := r.testRun(idStr, testInput)
|
||||
e := r.getProcess(idStr, exeID)
|
||||
output := e.output
|
||||
var result map[string]any
|
||||
err := sonic.Unmarshal([]byte(output), &result)
|
||||
assert.NoError(t, err, "Failed to unmarshal output JSON")
|
||||
|
||||
assert.Equal(t, false, result["isSuccess"])
|
||||
})
|
||||
|
||||
mockey.PatchConvey("query message list in dynamic conversation", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
r.message.EXPECT().MessageList(gomock.Any(), gomock.Any()).Return(&message0.MessageListResponse{
|
||||
Messages: []*message0.WfMessage{
|
||||
{
|
||||
ID: mID,
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Text: ptr.Of("hello"),
|
||||
},
|
||||
},
|
||||
}, nil).AnyTimes()
|
||||
|
||||
idStr := r.load("message/message_list.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
ret, _ := r.openapiSyncRun(idStr, map[string]string{
|
||||
"USER_INPUT": "hello",
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}, withRunProjectID(123))
|
||||
|
||||
mIDStr := strconv.FormatInt(mID, 10)
|
||||
expected := []any{
|
||||
map[string]any{
|
||||
"messageId": mIDStr,
|
||||
"role": "user",
|
||||
"contentType": "text",
|
||||
"content": "hello",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, expected, ret["output"])
|
||||
})
|
||||
|
||||
mockey.PatchConvey("query message list in static conversation", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
createReq := &workflow.CreateProjectConversationDefRequest{
|
||||
ProjectID: "123",
|
||||
ConversationName: "name" + strconv.FormatInt(cID, 10),
|
||||
SpaceID: "123",
|
||||
}
|
||||
post[workflow.CreateProjectConversationDefResponse](r, createReq)
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
r.message.EXPECT().MessageList(gomock.Any(), gomock.Any()).Return(&message0.MessageListResponse{
|
||||
Messages: []*message0.WfMessage{
|
||||
{
|
||||
ID: mID,
|
||||
Role: "user",
|
||||
ContentType: "text",
|
||||
Text: ptr.Of("hello"),
|
||||
},
|
||||
},
|
||||
}, nil).AnyTimes()
|
||||
|
||||
idStr := r.load("message/message_list.json")
|
||||
testInput := map[string]string{
|
||||
"USER_INPUT": "hello",
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}
|
||||
exeID := r.testRun(idStr, testInput, withRunProjectID(123))
|
||||
e := r.getProcess(idStr, exeID)
|
||||
e.assertSuccess()
|
||||
output := e.output
|
||||
var result map[string]any
|
||||
err := sonic.Unmarshal([]byte(output), &result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mIDStr := strconv.FormatInt(mID, 10)
|
||||
expected := []any{
|
||||
map[string]any{
|
||||
"messageId": mIDStr,
|
||||
"role": "user",
|
||||
"contentType": "text",
|
||||
"content": "hello",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, expected, result["output"])
|
||||
})
|
||||
|
||||
mockey.PatchConvey("edit message in dynamic conversation", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
r.message.EXPECT().Edit(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
ConversationID: cID,
|
||||
}, nil).AnyTimes()
|
||||
r.message.EXPECT().GetMessageByID(gomock.Any(), gomock.Any()).Return(&msgentity.Message{
|
||||
ID: mID,
|
||||
ConversationID: cID,
|
||||
Content: "123",
|
||||
}, nil).AnyTimes()
|
||||
|
||||
idStr := r.load("message/edit_message.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
ret, _ := r.openapiSyncRun(idStr, map[string]string{
|
||||
"USER_INPUT": "hello",
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}, withRunProjectID(123))
|
||||
|
||||
assert.Equal(t, true, ret["isSuccess"])
|
||||
})
|
||||
|
||||
mockey.PatchConvey("edit message in static conversation", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
createReq := &workflow.CreateProjectConversationDefRequest{
|
||||
ProjectID: "123",
|
||||
ConversationName: "name" + strconv.FormatInt(cID, 10),
|
||||
SpaceID: "123",
|
||||
}
|
||||
post[workflow.CreateProjectConversationDefResponse](r, createReq)
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
r.message.EXPECT().Edit(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
ConversationID: cID,
|
||||
}, nil).AnyTimes()
|
||||
r.message.EXPECT().GetMessageByID(gomock.Any(), gomock.Any()).Return(&msgentity.Message{
|
||||
ID: mID,
|
||||
ConversationID: cID,
|
||||
Content: "123",
|
||||
}, nil).AnyTimes()
|
||||
|
||||
idStr := r.load("message/edit_message.json")
|
||||
testInput := map[string]string{
|
||||
"USER_INPUT": "hello",
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}
|
||||
exeID := r.testRun(idStr, testInput, withRunProjectID(123))
|
||||
e := r.getProcess(idStr, exeID)
|
||||
e.assertSuccess()
|
||||
output := e.output
|
||||
var result map[string]any
|
||||
err := sonic.Unmarshal([]byte(output), &result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, true, result["isSuccess"])
|
||||
})
|
||||
|
||||
mockey.PatchConvey("edit message no permission", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
err := errorx.New(errno.ErrMessageNodeOperationFail, errorx.KV("cause", "message not found"))
|
||||
r.message.EXPECT().Edit(gomock.Any(), gomock.Any()).Return(&message.Message{}, err).AnyTimes()
|
||||
r.message.EXPECT().GetMessageByID(gomock.Any(), gomock.Any()).Return(&msgentity.Message{
|
||||
ConversationID: cID,
|
||||
Content: "123456",
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
|
||||
idStr := r.load("message/edit_message_no_permission.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
|
||||
testInput := map[string]string{
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}
|
||||
exeID := r.testRun(idStr, testInput, withRunProjectID(123))
|
||||
e := r.getProcess(idStr, exeID)
|
||||
assert.Equal(t, e.status, workflow.WorkflowExeStatus_Fail)
|
||||
assert.Contains(t, e.reason, "Message node operation failure: message not found")
|
||||
})
|
||||
|
||||
mockey.PatchConvey("delete message", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
cID := time.Now().Unix()
|
||||
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
}, nil).AnyTimes()
|
||||
mID := time.Now().Unix()
|
||||
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
|
||||
ID: mID,
|
||||
}, nil).AnyTimes()
|
||||
rID := time.Now().UnixNano()
|
||||
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
|
||||
ID: rID,
|
||||
}, nil).AnyTimes()
|
||||
sID := time.Now().UnixNano()
|
||||
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
|
||||
ID: cID,
|
||||
SectionID: sID,
|
||||
}, nil).AnyTimes()
|
||||
r.message.EXPECT().Delete(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
idStr := r.load("message/delete_message.json")
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
ret, _ := r.openapiSyncRun(idStr, map[string]string{
|
||||
"USER_INPUT": "hello",
|
||||
"CONVERSATION_NAME": "name" + strconv.FormatInt(cID, 10),
|
||||
}, withRunProjectID(123))
|
||||
|
||||
assert.Equal(t, true, ret["isSuccess"])
|
||||
})
|
||||
}
|
||||
|
||||
@ -329,8 +329,7 @@ type CreateDocumentResponse struct {
|
||||
}
|
||||
|
||||
type DeleteDocumentRequest struct {
|
||||
DocumentID int64
|
||||
KnowledgeID int64
|
||||
DocumentID string
|
||||
}
|
||||
|
||||
type DeleteDocumentResponse struct {
|
||||
|
||||
@ -62,7 +62,6 @@ import (
|
||||
vikingReranker "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/vikingdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
|
||||
embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
|
||||
@ -73,7 +72,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
|
||||
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
|
||||
oceanbaseClient "github.com/coze-dev/coze-studio/backend/infra/impl/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
@ -524,86 +522,6 @@ func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
|
||||
|
||||
return mgr, nil
|
||||
|
||||
case "oceanbase":
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
host = os.Getenv("OCEANBASE_HOST")
|
||||
port = os.Getenv("OCEANBASE_PORT")
|
||||
user = os.Getenv("OCEANBASE_USER")
|
||||
password = os.Getenv("OCEANBASE_PASSWORD")
|
||||
database = os.Getenv("OCEANBASE_DATABASE")
|
||||
)
|
||||
if host == "" || port == "" || user == "" || password == "" || database == "" {
|
||||
return nil, fmt.Errorf("invalid oceanbase configuration: host, port, user, password, database are required")
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
user, password, host, port, database)
|
||||
|
||||
client, err := oceanbaseClient.NewOceanBaseClient(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase client failed, err=%w", err)
|
||||
}
|
||||
|
||||
if err := client.InitDatabase(ctx); err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase database failed, err=%w", err)
|
||||
}
|
||||
|
||||
// Get configuration from environment variables with defaults
|
||||
batchSize := 100
|
||||
if bs := os.Getenv("OCEANBASE_BATCH_SIZE"); bs != "" {
|
||||
if bsInt, err := strconv.Atoi(bs); err == nil {
|
||||
batchSize = bsInt
|
||||
}
|
||||
}
|
||||
|
||||
enableCache := true
|
||||
if ec := os.Getenv("OCEANBASE_ENABLE_CACHE"); ec != "" {
|
||||
if ecBool, err := strconv.ParseBool(ec); err == nil {
|
||||
enableCache = ecBool
|
||||
}
|
||||
}
|
||||
|
||||
cacheTTL := 300 * time.Second
|
||||
if ct := os.Getenv("OCEANBASE_CACHE_TTL"); ct != "" {
|
||||
if ctInt, err := strconv.Atoi(ct); err == nil {
|
||||
cacheTTL = time.Duration(ctInt) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
maxConnections := 100
|
||||
if mc := os.Getenv("OCEANBASE_MAX_CONNECTIONS"); mc != "" {
|
||||
if mcInt, err := strconv.Atoi(mc); err == nil {
|
||||
maxConnections = mcInt
|
||||
}
|
||||
}
|
||||
|
||||
connTimeout := 30 * time.Second
|
||||
if ct := os.Getenv("OCEANBASE_CONN_TIMEOUT"); ct != "" {
|
||||
if ctInt, err := strconv.Atoi(ct); err == nil {
|
||||
connTimeout = time.Duration(ctInt) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
managerConfig := &oceanbase.ManagerConfig{
|
||||
Client: client,
|
||||
Embedding: emb,
|
||||
BatchSize: batchSize,
|
||||
EnableCache: enableCache,
|
||||
CacheTTL: cacheTTL,
|
||||
MaxConnections: maxConnections,
|
||||
ConnTimeout: connTimeout,
|
||||
}
|
||||
mgr, err := oceanbase.NewManager(managerConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase vector store failed, err=%w", err)
|
||||
}
|
||||
return mgr, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
|
||||
}
|
||||
|
||||
@ -127,7 +127,7 @@ func (a *OpenapiAgentRunApplication) checkAgent(ctx context.Context, ar *run.Cha
|
||||
}
|
||||
|
||||
if agentInfo == nil {
|
||||
return nil, errorx.New(errno.ErrAgentNotExists)
|
||||
return nil, errors.New("agent info is nil")
|
||||
}
|
||||
return agentInfo, nil
|
||||
}
|
||||
|
||||
@ -291,13 +291,12 @@ func (v *VariableApplicationService) DeleteVariableInstance(ctx context.Context,
|
||||
|
||||
bizType := ternary.IFElse(req.BotID == 0, project_memory.VariableConnector_Project, project_memory.VariableConnector_Bot)
|
||||
bizID := ternary.IFElse(req.BotID == 0, req.ProjectID, fmt.Sprintf("%d", req.BotID))
|
||||
connectId := ternary.IFElse(req.ConnectorID == nil, consts.CozeConnectorID, req.GetConnectorID())
|
||||
|
||||
e := entity.NewUserVariableMeta(&model.UserVariableMeta{
|
||||
BizType: bizType,
|
||||
BizID: bizID,
|
||||
Version: "",
|
||||
ConnectorID: connectId,
|
||||
ConnectorID: req.GetConnectorID(),
|
||||
ConnectorUID: fmt.Sprintf("%d", *uid),
|
||||
})
|
||||
|
||||
|
||||
@ -730,18 +730,15 @@ func (s *SingleAgentApplicationService) getAgentInfo(ctx context.Context, botID
|
||||
AgentID: ptr.Of(si.ObjectID),
|
||||
Command: si.ShortcutCommand,
|
||||
Components: slices.Transform(si.Components, func(i *playground.Components) *bot_common.ShortcutCommandComponent {
|
||||
sc := &bot_common.ShortcutCommandComponent{
|
||||
return &bot_common.ShortcutCommandComponent{
|
||||
Name: i.Name,
|
||||
Description: i.Description,
|
||||
Type: i.InputType.String(),
|
||||
ToolParameter: ptr.Of(i.Parameter),
|
||||
Options: i.Options,
|
||||
DefaultValue: ptr.Of(i.DefaultValue.Value),
|
||||
IsHide: i.Hide,
|
||||
}
|
||||
if i.DefaultValue != nil {
|
||||
sc.DefaultValue = ptr.Of(i.DefaultValue.Value)
|
||||
}
|
||||
return sc
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
@ -20,9 +20,8 @@ import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
"os"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
@ -93,7 +92,7 @@ func InitService(_ context.Context, components *ServiceComponents) (*Application
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
workflow.SetRepository(workflowRepo)
|
||||
|
||||
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
|
||||
|
||||
@ -3631,25 +3631,7 @@ func toWorkflowAPIParameterAssistType(ty vo.FileSubType) workflow.AssistParamete
|
||||
}
|
||||
}
|
||||
|
||||
func toVariableSlice(params []*workflow.APIParameter) ([]*vo.Variable, error) {
|
||||
if len(params) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
res := make([]*vo.Variable, 0, len(params))
|
||||
for _, p := range params {
|
||||
v, err := toVariable(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res = append(res, v)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func toVariable(p *workflow.APIParameter) (*vo.Variable, error) {
|
||||
if p == nil {
|
||||
return nil, nil
|
||||
}
|
||||
v := &vo.Variable{
|
||||
Name: p.Name,
|
||||
Description: p.Desc,
|
||||
@ -3671,33 +3653,38 @@ func toVariable(p *workflow.APIParameter) (*vo.Variable, error) {
|
||||
v.Type = vo.VariableTypeBoolean
|
||||
case workflow.ParameterType_Array:
|
||||
v.Type = vo.VariableTypeList
|
||||
if p.SubType == nil {
|
||||
return nil, fmt.Errorf("array parameter '%s' is missing a SubType", p.Name)
|
||||
}
|
||||
// The schema of a list variable is a single variable describing the items.
|
||||
itemSchema := &vo.Variable{
|
||||
Type: vo.VariableType(strings.ToLower(p.SubType.String())),
|
||||
}
|
||||
// If the items in the array are objects, describe their structure.
|
||||
if *p.SubType == workflow.ParameterType_Object {
|
||||
itemFields, err := toVariableSlice(p.SubParameters)
|
||||
if len(p.SubParameters) == 1 && p.SubType != nil && *p.SubType != workflow.ParameterType_Object {
|
||||
av, err := toVariable(p.SubParameters[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
itemSchema.Schema = itemFields
|
||||
v.Schema = &av
|
||||
} else {
|
||||
if len(p.SubParameters) > 0 && p.SubParameters[0].AssistType != nil {
|
||||
itemSchema.AssistType = vo.AssistType(*p.SubParameters[0].AssistType)
|
||||
subVs := make([]any, 0)
|
||||
for _, ap := range p.SubParameters {
|
||||
av, err := toVariable(ap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subVs = append(subVs, av)
|
||||
}
|
||||
v.Schema = &vo.Variable{
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: subVs,
|
||||
}
|
||||
}
|
||||
v.Schema = itemSchema
|
||||
case workflow.ParameterType_Object:
|
||||
v.Type = vo.VariableTypeObject
|
||||
subVars, err := toVariableSlice(p.SubParameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
vs := make([]*vo.Variable, 0)
|
||||
for _, v := range p.SubParameters {
|
||||
objV, err := toVariable(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vs = append(vs, objV)
|
||||
|
||||
}
|
||||
v.Schema = subVars
|
||||
v.Schema = vs
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown workflow api parameter type: %v", p.Type)
|
||||
}
|
||||
|
||||
@ -1,158 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToVariable(t *testing.T) {
|
||||
fileAssistType := workflow.AssistParameterType_DEFAULT
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input *workflow.APIParameter
|
||||
expected *vo.Variable
|
||||
expectErr bool
|
||||
expectedErrAs any
|
||||
}{
|
||||
{
|
||||
name: "Nil Input",
|
||||
input: nil,
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Simple String",
|
||||
input: &workflow.APIParameter{
|
||||
Name: "prompt", Type: workflow.ParameterType_String, IsRequired: true,
|
||||
},
|
||||
expected: &vo.Variable{
|
||||
Name: "prompt", Type: vo.VariableTypeString, Required: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Simple Object",
|
||||
input: &workflow.APIParameter{
|
||||
Name: "user",
|
||||
Type: workflow.ParameterType_Object,
|
||||
SubParameters: []*workflow.APIParameter{
|
||||
{Name: "name", Type: workflow.ParameterType_String},
|
||||
{Name: "age", Type: workflow.ParameterType_Integer},
|
||||
},
|
||||
},
|
||||
expected: &vo.Variable{
|
||||
Name: "user",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "name", Type: vo.VariableTypeString},
|
||||
{Name: "age", Type: vo.VariableTypeInteger},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Array of Objects",
|
||||
input: &workflow.APIParameter{
|
||||
Name: "items",
|
||||
Type: workflow.ParameterType_Array,
|
||||
SubType: ptr.Of(workflow.ParameterType_Object),
|
||||
SubParameters: []*workflow.APIParameter{
|
||||
{Name: "id", Type: workflow.ParameterType_String},
|
||||
{Name: "price", Type: workflow.ParameterType_Number},
|
||||
},
|
||||
},
|
||||
expected: &vo.Variable{
|
||||
Name: "items",
|
||||
Type: vo.VariableTypeList,
|
||||
Schema: &vo.Variable{
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "id", Type: vo.VariableTypeString},
|
||||
{Name: "price", Type: vo.VariableTypeFloat},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Array of Primitives (File)",
|
||||
input: &workflow.APIParameter{
|
||||
Name: "attachments",
|
||||
Type: workflow.ParameterType_Array,
|
||||
SubType: ptr.Of(workflow.ParameterType_String),
|
||||
SubParameters: []*workflow.APIParameter{
|
||||
{AssistType: &fileAssistType},
|
||||
},
|
||||
},
|
||||
expected: &vo.Variable{
|
||||
Name: "attachments",
|
||||
Type: vo.VariableTypeList,
|
||||
Schema: &vo.Variable{
|
||||
Type: vo.VariableTypeString,
|
||||
AssistType: vo.AssistType(fileAssistType),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Array of Primitives (String)",
|
||||
input: &workflow.APIParameter{
|
||||
Name: "tags",
|
||||
Type: workflow.ParameterType_Array,
|
||||
SubType: ptr.Of(workflow.ParameterType_String),
|
||||
},
|
||||
expected: &vo.Variable{
|
||||
Name: "tags",
|
||||
Type: vo.VariableTypeList,
|
||||
Schema: &vo.Variable{
|
||||
Type: vo.VariableTypeString,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Array with missing SubType",
|
||||
input: &workflow.APIParameter{
|
||||
Name: "bad_array",
|
||||
Type: workflow.ParameterType_Array,
|
||||
},
|
||||
expectErr: true,
|
||||
expectedErrAs: "missing a SubType",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual, err := toVariable(tc.input)
|
||||
|
||||
if tc.expectErr {
|
||||
require.Error(t, err)
|
||||
if tc.expectedErrAs != nil {
|
||||
assert.True(t, strings.Contains(err.Error(), fmt.Sprint(tc.expectedErrAs)))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -355,7 +355,7 @@
|
||||
type: oauth
|
||||
sub_type: authorization_code
|
||||
# client_id and client_secret apply to https://open.larkoffice.com/app
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"im:message im:message.group_msg offline_access","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"im:message im:message.group_msg","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
logo_url: official_plugin_icon/plugin_lark_message.png
|
||||
api:
|
||||
type: openapi
|
||||
@ -437,7 +437,7 @@
|
||||
type: oauth
|
||||
sub_type: authorization_code
|
||||
# client_id and client_secret apply to https://open.larkoffice.com/app
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"bitable:app wiki:wiki offline_access","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"bitable:app wiki:wiki","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
logo_url: official_plugin_icon/plugin_lark_base.png
|
||||
api:
|
||||
type: openapi
|
||||
@ -549,7 +549,7 @@
|
||||
type: oauth
|
||||
sub_type: authorization_code
|
||||
# client_id and client_secret apply to https://open.larkoffice.com/app
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"drive:drive offline_access","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"drive:drive","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
logo_url: official_plugin_icon/plugin_lark_sheet.png
|
||||
api:
|
||||
type: openapi
|
||||
@ -637,7 +637,7 @@
|
||||
type: oauth
|
||||
sub_type: authorization_code
|
||||
# client_id and client_secret apply to https://open.larkoffice.com/app
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"task:task:write offline_access","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"task:task:write","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
logo_url: official_plugin_icon/plugin_lark_task.png
|
||||
api:
|
||||
type: openapi
|
||||
@ -711,7 +711,7 @@
|
||||
type: oauth
|
||||
sub_type: authorization_code
|
||||
# client_id and client_secret apply to https://open.larkoffice.com/app
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"drive:drive offline_access","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"drive:drive","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
logo_url: official_plugin_icon/plugin_lark_docx.png
|
||||
api:
|
||||
type: openapi
|
||||
@ -775,7 +775,7 @@
|
||||
type: oauth
|
||||
sub_type: authorization_code
|
||||
# client_id and client_secret apply to https://open.larkoffice.com/app
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"wiki:wiki offline_access","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"wiki:wiki","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
logo_url: official_plugin_icon/plugin_lark_wiki.png
|
||||
api:
|
||||
type: openapi
|
||||
@ -807,7 +807,7 @@
|
||||
type: oauth
|
||||
sub_type: authorization_code
|
||||
# client_id and client_secret apply to https://open.larkoffice.com/app
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"calendar:calendar calendar:calendar:read offline_access","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
payload: '{"client_id":"","client_secret":"","client_url":"https://accounts.feishu.cn/open-apis/authen/v1/authorize","scope":"calendar:calendar calendar:calendar:read","authorization_url":"https://open.larkoffice.com/open-apis/authen/v2/oauth/token","authorization_content_type":"application/json"}'
|
||||
logo_url: official_plugin_icon/plugin_lark_calendar.png
|
||||
api:
|
||||
type: openapi
|
||||
|
||||
@ -22,7 +22,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination agentrunmock/agent_run_mock.go --package agentrunmock -source agent_run.go
|
||||
type AgentRun interface {
|
||||
Delete(ctx context.Context, runID []int64) error
|
||||
List(ctx context.Context, ListMeta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error)
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: agent_run.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination agentrunmock/agent_run_mock.go --package agentrunmock -source agent_run.go
|
||||
//
|
||||
|
||||
// Package agentrunmock is a generated GoMock package.
|
||||
package agentrunmock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
entity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockAgentRun is a mock of AgentRun interface.
|
||||
type MockAgentRun struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAgentRunMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockAgentRunMockRecorder is the mock recorder for MockAgentRun.
|
||||
type MockAgentRunMockRecorder struct {
|
||||
mock *MockAgentRun
|
||||
}
|
||||
|
||||
// NewMockAgentRun creates a new mock instance.
|
||||
func NewMockAgentRun(ctrl *gomock.Controller) *MockAgentRun {
|
||||
mock := &MockAgentRun{ctrl: ctrl}
|
||||
mock.recorder = &MockAgentRunMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockAgentRun) EXPECT() *MockAgentRunMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockAgentRun) Create(ctx context.Context, runRecord *entity.AgentRunMeta) (*entity.RunRecordMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Create", ctx, runRecord)
|
||||
ret0, _ := ret[0].(*entity.RunRecordMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create.
|
||||
func (mr *MockAgentRunMockRecorder) Create(ctx, runRecord any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockAgentRun)(nil).Create), ctx, runRecord)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockAgentRun) Delete(ctx context.Context, runID []int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, runID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockAgentRunMockRecorder) Delete(ctx, runID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockAgentRun)(nil).Delete), ctx, runID)
|
||||
}
|
||||
|
||||
// List mocks base method.
|
||||
func (m *MockAgentRun) List(ctx context.Context, ListMeta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List", ctx, ListMeta)
|
||||
ret0, _ := ret[0].([]*entity.RunRecordMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List.
|
||||
func (mr *MockAgentRunMockRecorder) List(ctx, ListMeta any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockAgentRun)(nil).List), ctx, ListMeta)
|
||||
}
|
||||
@ -20,7 +20,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
)
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
@ -28,9 +29,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
var defaultSVC crossknowledge.Knowledge
|
||||
@ -133,24 +132,13 @@ func (i *impl) Store(ctx context.Context, document *model.CreateDocumentRequest)
|
||||
}
|
||||
|
||||
func (i *impl) Delete(ctx context.Context, r *model.DeleteDocumentRequest) (*model.DeleteDocumentResponse, error) {
|
||||
if r.KnowledgeID == 0 {
|
||||
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "knowledge id cannot be 0"))
|
||||
}
|
||||
|
||||
docs, err := i.DomainSVC.ListDocument(ctx, &service.ListDocumentRequest{
|
||||
KnowledgeID: r.KnowledgeID,
|
||||
DocumentIDs: []int64{r.DocumentID},
|
||||
SelectAll: true,
|
||||
})
|
||||
docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(docs.Documents) == 0 {
|
||||
return nil, errorx.New(errno.ErrKnowledgeDocumentNotExistCode, errorx.KV("msg", "the specified document is not part of this knowledge base"))
|
||||
return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
|
||||
}
|
||||
|
||||
err = i.DomainSVC.DeleteDocument(ctx, &service.DeleteDocumentRequest{
|
||||
DocumentID: r.DocumentID,
|
||||
DocumentID: docID,
|
||||
})
|
||||
if err != nil {
|
||||
return &model.DeleteDocumentResponse{IsSuccess: false}, err
|
||||
|
||||
@ -22,7 +22,6 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
|
||||
@ -21,9 +21,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
@ -31,6 +28,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockWorkflowRepo struct {
|
||||
|
||||
@ -571,6 +571,7 @@ func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParame
|
||||
if parameter.SubType != nil {
|
||||
p.SubType = ptr.Of(workflow3.ParameterType(*parameter.SubType))
|
||||
}
|
||||
|
||||
if parameter.DefaultParamSource != nil {
|
||||
p.DefaultParamSource = ptr.Of(workflow3.DefaultParamSource(*parameter.DefaultParamSource))
|
||||
}
|
||||
@ -578,22 +579,23 @@ func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParame
|
||||
p.AssistType = ptr.Of(workflow3.AssistParameterType(*parameter.AssistType))
|
||||
}
|
||||
|
||||
// Check if it's a specially wrapped array that needs unwrapping.
|
||||
// Check if it's an array that needs unwrapping.
|
||||
if parameter.Type == common.ParameterType_Array && len(parameter.SubParameters) == 1 && parameter.SubParameters[0].Name == "[Array Item]" {
|
||||
arrayItem := parameter.SubParameters[0]
|
||||
// The actual type of array elements is the type of the "[Array Item]".
|
||||
p.SubType = ptr.Of(workflow3.ParameterType(arrayItem.Type))
|
||||
// If the array elements are objects, their sub-parameters (fields) are lifted up.
|
||||
// If the "[Array Item]" is an object, its sub-parameters become the array's sub-parameters.
|
||||
if arrayItem.Type == common.ParameterType_Object {
|
||||
p.SubParameters = make([]*workflow3.APIParameter, 0, len(arrayItem.SubParameters))
|
||||
for _, subParam := range arrayItem.SubParameters {
|
||||
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
|
||||
}
|
||||
} else {
|
||||
// The array's SubType is the Type of the "[Array Item]".
|
||||
p.SubParameters = make([]*workflow3.APIParameter, 0, 1)
|
||||
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(arrayItem))
|
||||
p.SubParameters[0].Name = "" // Remove the "[Array Item]" name.
|
||||
}
|
||||
} else if len(parameter.SubParameters) > 0 {
|
||||
} else if len(parameter.SubParameters) > 0 { // A simple object or a non-wrapped array.
|
||||
p.SubParameters = make([]*workflow3.APIParameter, 0, len(parameter.SubParameters))
|
||||
for _, subParam := range parameter.SubParameters {
|
||||
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
|
||||
|
||||
@ -1,187 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestToWorkflowAPIParameter(t *testing.T) {
|
||||
fileAssistType := common.AssistParameterType_DEFAULT
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input *common.APIParameter
|
||||
expected *workflow3.APIParameter
|
||||
}{
|
||||
{
|
||||
name: "Nil Input",
|
||||
input: nil,
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Simple String Parameter",
|
||||
input: &common.APIParameter{
|
||||
Name: "prompt",
|
||||
Type: common.ParameterType_String,
|
||||
},
|
||||
expected: &workflow3.APIParameter{
|
||||
Name: "prompt",
|
||||
Type: workflow3.ParameterType_String,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Simple Object Parameter",
|
||||
input: &common.APIParameter{
|
||||
Name: "user",
|
||||
Type: common.ParameterType_Object,
|
||||
SubParameters: []*common.APIParameter{
|
||||
{Name: "name", Type: common.ParameterType_String},
|
||||
{Name: "age", Type: common.ParameterType_Integer},
|
||||
},
|
||||
},
|
||||
expected: &workflow3.APIParameter{
|
||||
Name: "user",
|
||||
Type: workflow3.ParameterType_Object,
|
||||
SubParameters: []*workflow3.APIParameter{
|
||||
{Name: "name", Type: workflow3.ParameterType_String},
|
||||
{Name: "age", Type: workflow3.ParameterType_Integer},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Wrapped Array of Primitives (String)",
|
||||
input: &common.APIParameter{
|
||||
Name: "tags",
|
||||
Type: common.ParameterType_Array,
|
||||
SubParameters: []*common.APIParameter{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: common.ParameterType_String,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &workflow3.APIParameter{
|
||||
Name: "tags",
|
||||
Type: workflow3.ParameterType_Array,
|
||||
SubType: ptr.Of(workflow3.ParameterType_String),
|
||||
SubParameters: []*workflow3.APIParameter{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: workflow3.ParameterType_String,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Wrapped Array of Primitives with AssistType (File)",
|
||||
input: &common.APIParameter{
|
||||
Name: "documents",
|
||||
Type: common.ParameterType_Array,
|
||||
SubParameters: []*common.APIParameter{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: common.ParameterType_String,
|
||||
AssistType: &fileAssistType,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &workflow3.APIParameter{
|
||||
Name: "documents",
|
||||
Type: workflow3.ParameterType_Array,
|
||||
SubType: ptr.Of(workflow3.ParameterType_String),
|
||||
SubParameters: []*workflow3.APIParameter{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: workflow3.ParameterType_String,
|
||||
AssistType: ptr.Of(workflow3.AssistParameterType(fileAssistType)),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Wrapped Array of Objects",
|
||||
input: &common.APIParameter{
|
||||
Name: "users",
|
||||
Type: common.ParameterType_Array,
|
||||
SubParameters: []*common.APIParameter{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: common.ParameterType_Object,
|
||||
SubParameters: []*common.APIParameter{
|
||||
{Name: "name", Type: common.ParameterType_String},
|
||||
{Name: "email", Type: common.ParameterType_String},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &workflow3.APIParameter{
|
||||
Name: "users",
|
||||
Type: workflow3.ParameterType_Array,
|
||||
SubType: ptr.Of(workflow3.ParameterType_Object),
|
||||
SubParameters: []*workflow3.APIParameter{
|
||||
{Name: "name", Type: workflow3.ParameterType_String},
|
||||
{Name: "email", Type: workflow3.ParameterType_String},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := toWorkflowAPIParameter(tc.input)
|
||||
|
||||
// Use require for nil checks to stop test early if it fails
|
||||
if tc.expected == nil {
|
||||
assert.Nil(t, actual)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, actual)
|
||||
|
||||
assert.Equal(t, tc.expected.Name, actual.Name, "Name should match")
|
||||
assert.Equal(t, tc.expected.Type, actual.Type, "Type should match")
|
||||
|
||||
if tc.expected.SubType != nil {
|
||||
assert.NotNil(t, actual.SubType, "SubType should not be nil")
|
||||
assert.Equal(t, *tc.expected.SubType, *actual.SubType, "SubType value should match")
|
||||
} else {
|
||||
assert.Nil(t, actual.SubType, "SubType should be nil")
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tc.expected.SubParameters), len(actual.SubParameters), "Number of sub-parameters should match")
|
||||
|
||||
for i := range tc.expected.SubParameters {
|
||||
expectedSub := tc.expected.SubParameters[i]
|
||||
actualSub := actual.SubParameters[i]
|
||||
assert.Equal(t, expectedSub.Name, actualSub.Name, "Sub-parameter name should match")
|
||||
assert.Equal(t, expectedSub.Type, actualSub.Type, "Sub-parameter type should match")
|
||||
|
||||
if expectedSub.AssistType != nil {
|
||||
assert.NotNil(t, actualSub.AssistType, "Sub-parameter AssistType should not be nil")
|
||||
assert.Equal(t, *expectedSub.AssistType, *actualSub.AssistType, "Sub-parameter AssistType value should match")
|
||||
} else {
|
||||
assert.Nil(t, actualSub.AssistType, "Sub-parameter AssistType should be nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -97,8 +97,5 @@ func (c *impl) ObtainAgentByIdentity(ctx context.Context, identity *model.AgentI
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if agentInfo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return agentInfo.SingleAgent, nil
|
||||
}
|
||||
|
||||
@ -18,7 +18,6 @@ package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
@ -21,9 +21,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func getAgentHistoryRounds(agentInfo *singleagent.SingleAgent) int32 {
|
||||
@ -34,18 +32,14 @@ func getAgentHistoryRounds(agentInfo *singleagent.SingleAgent) int32 {
|
||||
return conversationTurns
|
||||
}
|
||||
|
||||
func getAgentInfo(ctx context.Context, agentID int64, isDraft bool, connID int64) (*singleagent.SingleAgent, error) {
|
||||
func getAgentInfo(ctx context.Context, agentID int64, isDraft bool) (*singleagent.SingleAgent, error) {
|
||||
agentInfo, err := crossagent.DefaultSVC().ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
|
||||
AgentID: agentID,
|
||||
IsDraft: isDraft,
|
||||
ConnectorID: connID,
|
||||
AgentID: agentID,
|
||||
IsDraft: isDraft,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if agentInfo == nil {
|
||||
return nil, errorx.New(errno.ErrAgentNotExists)
|
||||
}
|
||||
|
||||
return agentInfo, nil
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ func (rd *AgentRuntime) GetHistory() []*msgEntity.Message {
|
||||
|
||||
func (art *AgentRuntime) Run(ctx context.Context) (err error) {
|
||||
|
||||
agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft, art.GetRunMeta().ConnectorID)
|
||||
agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -1,156 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./oauth_repository.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=./oauth_repository.go -package=mock_plugin_oauth -destination=./mock/mock_oauth_repository.go
|
||||
//
|
||||
|
||||
// Package mock_plugin_oauth is a generated GoMock package.
|
||||
package mock_plugin_oauth
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
entity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockOAuthRepository is a mock of OAuthRepository interface.
|
||||
type MockOAuthRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockOAuthRepositoryMockRecorder
|
||||
}
|
||||
|
||||
// MockOAuthRepositoryMockRecorder is the mock recorder for MockOAuthRepository.
|
||||
type MockOAuthRepositoryMockRecorder struct {
|
||||
mock *MockOAuthRepository
|
||||
}
|
||||
|
||||
// NewMockOAuthRepository creates a new mock instance.
|
||||
func NewMockOAuthRepository(ctrl *gomock.Controller) *MockOAuthRepository {
|
||||
mock := &MockOAuthRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockOAuthRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockOAuthRepository) EXPECT() *MockOAuthRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BatchDeleteAuthorizationCodeByIDs mocks base method.
|
||||
func (m *MockOAuthRepository) BatchDeleteAuthorizationCodeByIDs(ctx context.Context, ids []int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BatchDeleteAuthorizationCodeByIDs", ctx, ids)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BatchDeleteAuthorizationCodeByIDs indicates an expected call of BatchDeleteAuthorizationCodeByIDs.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) BatchDeleteAuthorizationCodeByIDs(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchDeleteAuthorizationCodeByIDs", reflect.TypeOf((*MockOAuthRepository)(nil).BatchDeleteAuthorizationCodeByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// DeleteAuthorizationCode mocks base method.
|
||||
func (m *MockOAuthRepository) DeleteAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAuthorizationCode", ctx, meta)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAuthorizationCode indicates an expected call of DeleteAuthorizationCode.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) DeleteAuthorizationCode(ctx, meta any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthorizationCode", reflect.TypeOf((*MockOAuthRepository)(nil).DeleteAuthorizationCode), ctx, meta)
|
||||
}
|
||||
|
||||
// DeleteExpiredAuthorizationCodeTokens mocks base method.
|
||||
func (m *MockOAuthRepository) DeleteExpiredAuthorizationCodeTokens(ctx context.Context, expireAt int64, limit int) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteExpiredAuthorizationCodeTokens", ctx, expireAt, limit)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteExpiredAuthorizationCodeTokens indicates an expected call of DeleteExpiredAuthorizationCodeTokens.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) DeleteExpiredAuthorizationCodeTokens(ctx, expireAt, limit any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiredAuthorizationCodeTokens", reflect.TypeOf((*MockOAuthRepository)(nil).DeleteExpiredAuthorizationCodeTokens), ctx, expireAt, limit)
|
||||
}
|
||||
|
||||
// DeleteInactiveAuthorizationCodeTokens mocks base method.
|
||||
func (m *MockOAuthRepository) DeleteInactiveAuthorizationCodeTokens(ctx context.Context, lastActiveAt int64, limit int) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteInactiveAuthorizationCodeTokens", ctx, lastActiveAt, limit)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteInactiveAuthorizationCodeTokens indicates an expected call of DeleteInactiveAuthorizationCodeTokens.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) DeleteInactiveAuthorizationCodeTokens(ctx, lastActiveAt, limit any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteInactiveAuthorizationCodeTokens", reflect.TypeOf((*MockOAuthRepository)(nil).DeleteInactiveAuthorizationCodeTokens), ctx, lastActiveAt, limit)
|
||||
}
|
||||
|
||||
// GetAuthorizationCode mocks base method.
|
||||
func (m *MockOAuthRepository) GetAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) (*entity.AuthorizationCodeInfo, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizationCode", ctx, meta)
|
||||
ret0, _ := ret[0].(*entity.AuthorizationCodeInfo)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetAuthorizationCode indicates an expected call of GetAuthorizationCode.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) GetAuthorizationCode(ctx, meta any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationCode", reflect.TypeOf((*MockOAuthRepository)(nil).GetAuthorizationCode), ctx, meta)
|
||||
}
|
||||
|
||||
// GetAuthorizationCodeRefreshTokens mocks base method.
|
||||
func (m *MockOAuthRepository) GetAuthorizationCodeRefreshTokens(ctx context.Context, nextRefreshAt int64, limit int) ([]*entity.AuthorizationCodeInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizationCodeRefreshTokens", ctx, nextRefreshAt, limit)
|
||||
ret0, _ := ret[0].([]*entity.AuthorizationCodeInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthorizationCodeRefreshTokens indicates an expected call of GetAuthorizationCodeRefreshTokens.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) GetAuthorizationCodeRefreshTokens(ctx, nextRefreshAt, limit any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationCodeRefreshTokens", reflect.TypeOf((*MockOAuthRepository)(nil).GetAuthorizationCodeRefreshTokens), ctx, nextRefreshAt, limit)
|
||||
}
|
||||
|
||||
// UpdateAuthorizationCodeLastActiveAt mocks base method.
|
||||
func (m *MockOAuthRepository) UpdateAuthorizationCodeLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAuthorizationCodeLastActiveAt", ctx, meta, lastActiveAtMs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAuthorizationCodeLastActiveAt indicates an expected call of UpdateAuthorizationCodeLastActiveAt.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) UpdateAuthorizationCodeLastActiveAt(ctx, meta, lastActiveAtMs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAuthorizationCodeLastActiveAt", reflect.TypeOf((*MockOAuthRepository)(nil).UpdateAuthorizationCodeLastActiveAt), ctx, meta, lastActiveAtMs)
|
||||
}
|
||||
|
||||
// UpsertAuthorizationCode mocks base method.
|
||||
func (m *MockOAuthRepository) UpsertAuthorizationCode(ctx context.Context, info *entity.AuthorizationCodeInfo) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertAuthorizationCode", ctx, info)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertAuthorizationCode indicates an expected call of UpsertAuthorizationCode.
|
||||
func (mr *MockOAuthRepositoryMockRecorder) UpsertAuthorizationCode(ctx, info any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAuthorizationCode", reflect.TypeOf((*MockOAuthRepository)(nil).UpsertAuthorizationCode), ctx, info)
|
||||
}
|
||||
@ -22,7 +22,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -source=./oauth_repository.go -package=mock_plugin_oauth -destination=./mock/mock_oauth_repository.go
|
||||
type OAuthRepository interface {
|
||||
GetAuthorizationCode(ctx context.Context, meta *entity.AuthorizationCodeMeta) (info *entity.AuthorizationCodeInfo, exist bool, err error)
|
||||
UpsertAuthorizationCode(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error)
|
||||
|
||||
@ -43,7 +43,6 @@ import (
|
||||
var (
|
||||
initOnce = sync.Once{}
|
||||
lastActiveInterval = 15 * 24 * time.Hour
|
||||
failedCache = sync.Map{}
|
||||
)
|
||||
|
||||
func (p *pluginServiceImpl) processOAuthAccessToken(ctx context.Context) {
|
||||
@ -124,69 +123,60 @@ func (p *pluginServiceImpl) refreshToken(ctx context.Context, info *entity.Autho
|
||||
|
||||
source := config.TokenSource(ctx, token)
|
||||
|
||||
token, err := source.Token()
|
||||
var (
|
||||
err error
|
||||
newToken *oauth2.Token
|
||||
)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
newToken, err = source.Token()
|
||||
if err == nil {
|
||||
token = newToken
|
||||
break
|
||||
}
|
||||
<-time.After(time.Second)
|
||||
}
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "refreshToken failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
p.refreshTokenFailedHandler(ctx, info.RecordID, err)
|
||||
logs.CtxInfof(ctx, "refreshToken failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, []int64{info.RecordID})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var expiredAtMS int64
|
||||
if !token.Expiry.IsZero() && token.Expiry.After(time.Now()) {
|
||||
expiredAtMS = token.Expiry.UnixMilli()
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
var expiredAtMS int64
|
||||
if !token.Expiry.IsZero() && token.Expiry.After(time.Now()) {
|
||||
expiredAtMS = token.Expiry.UnixMilli()
|
||||
}
|
||||
|
||||
err = p.oauthRepo.UpsertAuthorizationCode(ctx, &entity.AuthorizationCodeInfo{
|
||||
Meta: &entity.AuthorizationCodeMeta{
|
||||
UserID: info.Meta.UserID,
|
||||
PluginID: info.Meta.PluginID,
|
||||
IsDraft: info.Meta.IsDraft,
|
||||
},
|
||||
Config: info.Config,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
TokenExpiredAtMS: expiredAtMS,
|
||||
NextTokenRefreshAtMS: ptr.Of(getNextTokenRefreshAtMS(expiredAtMS)),
|
||||
})
|
||||
err = p.oauthRepo.UpsertAuthorizationCode(ctx, &entity.AuthorizationCodeInfo{
|
||||
Meta: &entity.AuthorizationCodeMeta{
|
||||
UserID: info.Meta.UserID,
|
||||
PluginID: info.Meta.PluginID,
|
||||
IsDraft: info.Meta.IsDraft,
|
||||
},
|
||||
Config: info.Config,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
TokenExpiredAtMS: expiredAtMS,
|
||||
NextTokenRefreshAtMS: ptr.Of(getNextTokenRefreshAtMS(expiredAtMS)),
|
||||
})
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
<-time.After(time.Second)
|
||||
}
|
||||
if err != nil {
|
||||
logs.CtxInfof(ctx, "UpsertAuthorizationCode failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
p.refreshTokenFailedHandler(ctx, info.RecordID, err)
|
||||
return
|
||||
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, []int64{info.RecordID})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) refreshTokenFailedHandler(ctx context.Context, recordID int64, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
const maxFailedTimes = 5
|
||||
|
||||
failedTimes, ok := failedCache.Load(recordID)
|
||||
if !ok {
|
||||
failedCache.Store(recordID, 1)
|
||||
return
|
||||
}
|
||||
|
||||
failedTimes_ := failedTimes.(int) + 1
|
||||
failedCache.Store(recordID, failedTimes_)
|
||||
|
||||
if failedTimes_ < maxFailedTimes {
|
||||
return
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "refreshToken exceeds max failed times, recordID=%d, err=%v", recordID, err)
|
||||
|
||||
failedCache.Delete(recordID)
|
||||
|
||||
err_ := p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, []int64{recordID})
|
||||
if err_ != nil {
|
||||
logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", recordID, err_)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetAccessToken(ctx context.Context, oa *entity.OAuthInfo) (accessToken string, err error) {
|
||||
switch oa.OAuthMode {
|
||||
case model.AuthzSubTypeOfOAuthAuthorizationCode:
|
||||
|
||||
@ -1,84 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository/mock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
type pluginOAuthSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
ctx context.Context
|
||||
|
||||
mockOauthRepo *mock_plugin_oauth.MockOAuthRepository
|
||||
}
|
||||
|
||||
func TestPluginOAuthSuite(t *testing.T) {
|
||||
suite.Run(t, &pluginOAuthSuite{})
|
||||
}
|
||||
|
||||
func (s *pluginOAuthSuite) SetupSuite() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.mockOauthRepo = mock_plugin_oauth.NewMockOAuthRepository(s.ctrl)
|
||||
}
|
||||
|
||||
func (s *pluginOAuthSuite) TearDownSuite() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func (s *pluginOAuthSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
}
|
||||
|
||||
func (s *pluginOAuthSuite) TearDownTest() {
|
||||
|
||||
}
|
||||
|
||||
func (s *pluginOAuthSuite) TestRefreshTokenFailedHandler() {
|
||||
mockRecordID := int64(123)
|
||||
mockErr := fmt.Errorf("mock error")
|
||||
mockSVC := &pluginServiceImpl{
|
||||
oauthRepo: s.mockOauthRepo,
|
||||
}
|
||||
|
||||
mockSVC.refreshTokenFailedHandler(s.ctx, mockRecordID, mockErr)
|
||||
failedTimes, ok := failedCache.Load(mockRecordID)
|
||||
assert.True(s.T(), ok)
|
||||
assert.Equal(s.T(), 1, failedTimes.(int))
|
||||
|
||||
for i := 2; i < 5; i++ {
|
||||
mockSVC.refreshTokenFailedHandler(s.ctx, mockRecordID, mockErr)
|
||||
failedTimes, ok = failedCache.Load(mockRecordID)
|
||||
assert.True(s.T(), ok)
|
||||
assert.Equal(s.T(), i, failedTimes.(int))
|
||||
}
|
||||
|
||||
s.mockOauthRepo.EXPECT().BatchDeleteAuthorizationCodeByIDs(gomock.Any(), gomock.Any()).
|
||||
Return(nil).Times(1)
|
||||
|
||||
mockSVC.refreshTokenFailedHandler(s.ctx, mockRecordID, mockErr)
|
||||
_, ok = failedCache.Load(mockRecordID)
|
||||
assert.False(s.T(), ok)
|
||||
}
|
||||
@ -18,6 +18,7 @@ package adaptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -27,8 +28,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@ -1,189 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": -222,
|
||||
"y": 48.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "176091",
|
||||
"name": "isSuccess",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "isSuccess"
|
||||
}
|
||||
],
|
||||
"terminatePlan": "returnVariables"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1027.5,
|
||||
"y": 46.510710144593645
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "user",
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "role"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "123",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "content"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于创建消息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-创建消息.jpg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "创建消息",
|
||||
"title": "创建消息"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "message",
|
||||
"schema": [
|
||||
{
|
||||
"name": "messageId",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "role",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "contentType",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "content",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "176091",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 438.75,
|
||||
"y": 35.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "55"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "176091",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "176091",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -1,317 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": -212,
|
||||
"y": -22.069277108433766
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "140767",
|
||||
"name": "isSuccess",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "isSuccess"
|
||||
}
|
||||
],
|
||||
"terminatePlan": "returnVariables"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1027.5,
|
||||
"y": 46.510710144593645
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于创建会话",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "创建会话",
|
||||
"title": "创建会话"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "isExisted",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "conversationId",
|
||||
"type": "string"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100911",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 23.75,
|
||||
"y": 207.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "39"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "user",
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "role"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "123",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "content"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于创建消息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-创建消息.jpg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "创建消息",
|
||||
"title": "创建消息"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "message",
|
||||
"schema": [
|
||||
{
|
||||
"name": "messageId",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "role",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "contentType",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "content",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "176091",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 272.75,
|
||||
"y": -173.27928348192006
|
||||
}
|
||||
},
|
||||
"type": "55"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "176091",
|
||||
"name": "message.messageId",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "messageId"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于删除消息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-删除消息.jpg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "删除消息",
|
||||
"title": "删除消息"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "140767",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 616.75,
|
||||
"y": 73.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "57"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "100911",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "140767",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "100911",
|
||||
"targetNodeID": "176091",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "176091",
|
||||
"targetNodeID": "140767",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -1,333 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": -212,
|
||||
"y": -22.069277108433766
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "191842",
|
||||
"name": "isSuccess",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "isSuccess"
|
||||
}
|
||||
],
|
||||
"terminatePlan": "returnVariables"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1027.5,
|
||||
"y": 46.510710144593645
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于创建会话",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "创建会话",
|
||||
"title": "创建会话"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "isExisted",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "conversationId",
|
||||
"type": "string"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100911",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 23.75,
|
||||
"y": 207.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "39"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "user",
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "role"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "123",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "content"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于创建消息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-创建消息.jpg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "创建消息",
|
||||
"title": "创建消息"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "message",
|
||||
"schema": [
|
||||
{
|
||||
"name": "messageId",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "role",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "contentType",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "content",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "176091",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 272.75,
|
||||
"y": -173.27928348192006
|
||||
}
|
||||
},
|
||||
"type": "55"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "176091",
|
||||
"name": "message.messageId",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "messageId"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "修改消息",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "newContent"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于修改消息的内容",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-修改消息.jpg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "修改消息",
|
||||
"title": "修改消息"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "191842",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 541.75,
|
||||
"y": 46.510710144593645
|
||||
}
|
||||
},
|
||||
"type": "56"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "100911",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "191842",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "100911",
|
||||
"targetNodeID": "176091",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "176091",
|
||||
"targetNodeID": "191842",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -1,228 +0,0 @@
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "USER_INPUT",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"defaultValue": "Default",
|
||||
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"trigger_parameters": []
|
||||
},
|
||||
"edges": null,
|
||||
"id": "100001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": -222,
|
||||
"y": 48.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "1"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "boolean",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "112184",
|
||||
"name": "isSuccess",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 3
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "isSuccess"
|
||||
}
|
||||
],
|
||||
"terminatePlan": "returnVariables"
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
}
|
||||
},
|
||||
"edges": null,
|
||||
"id": "900001",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1092.75,
|
||||
"y": 46.510710144593645
|
||||
}
|
||||
},
|
||||
"type": "2"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "123",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "messageId"
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": "123",
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "literal"
|
||||
}
|
||||
},
|
||||
"name": "newContent"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于修改消息的内容",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-修改消息.jpg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "修改消息",
|
||||
"title": "修改消息"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "112184",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 598,
|
||||
"y": 35.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "56"
|
||||
},
|
||||
{
|
||||
"blocks": [],
|
||||
"data": {
|
||||
"inputs": {
|
||||
"inputParameters": [
|
||||
{
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"content": {
|
||||
"blockID": "100001",
|
||||
"name": "CONVERSATION_NAME",
|
||||
"source": "block-output"
|
||||
},
|
||||
"rawMeta": {
|
||||
"type": 1
|
||||
},
|
||||
"type": "ref"
|
||||
}
|
||||
},
|
||||
"name": "conversationName"
|
||||
}
|
||||
]
|
||||
},
|
||||
"nodeMeta": {
|
||||
"description": "用于创建会话",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
|
||||
"mainColor": "#F2B600",
|
||||
"subTitle": "创建会话",
|
||||
"title": "创建会话"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "isSuccess",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "isExisted",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"name": "conversationId",
|
||||
"type": "string"
|
||||
}
|
||||
]
|
||||
},
|
||||
"edges": null,
|
||||
"id": "146209",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 188,
|
||||
"y": 35.72071651807994
|
||||
}
|
||||
},
|
||||
"type": "39"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"sourceNodeID": "100001",
|
||||
"targetNodeID": "146209",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "112184",
|
||||
"targetNodeID": "900001",
|
||||
"sourcePortID": ""
|
||||
},
|
||||
{
|
||||
"sourceNodeID": "146209",
|
||||
"targetNodeID": "112184",
|
||||
"sourcePortID": ""
|
||||
}
|
||||
],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}
|
||||
@ -31,11 +31,10 @@ import (
|
||||
model2 "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
|
||||
@ -99,22 +99,20 @@ func (c *ClearConversationHistory) Invoke(ctx context.Context, in map[string]any
|
||||
}
|
||||
var conversationID int64
|
||||
if existed {
|
||||
var sc *entity.StaticConversation
|
||||
sc, existed, err = wf.GetRepository().GetStaticConversationByTemplateID(ctx, env, userID, connectorID, t.TemplateID)
|
||||
ret, existed, err := wf.GetRepository().GetStaticConversationByTemplateID(ctx, env, userID, connectorID, t.TemplateID)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrConversationNodeOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
||||
}
|
||||
if existed {
|
||||
conversationID = sc.ConversationID
|
||||
conversationID = ret.ConversationID
|
||||
}
|
||||
} else {
|
||||
var dc *entity.DynamicConversation
|
||||
dc, existed, err = wf.GetRepository().GetDynamicConversationByName(ctx, env, *appID, connectorID, userID, conversationName)
|
||||
ret, existed, err := wf.GetRepository().GetDynamicConversationByName(ctx, env, *appID, connectorID, userID, conversationName)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrConversationNodeOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
||||
}
|
||||
if existed {
|
||||
conversationID = dc.ConversationID
|
||||
conversationID = ret.ConversationID
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
|
||||
@ -20,7 +20,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
conventity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
|
||||
@ -28,7 +27,6 @@ import (
|
||||
"sync/atomic"
|
||||
|
||||
einoSchema "github.com/cloudwego/eino/schema"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
|
||||
|
||||
@ -20,9 +20,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"strconv"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
|
||||
@ -21,8 +21,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
@ -34,8 +34,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']")
|
||||
|
||||
type CustomSQLConfig struct {
|
||||
DatabaseInfoID int64
|
||||
SQLTemplate string
|
||||
@ -113,60 +111,47 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?"))
|
||||
templateSQL := ""
|
||||
if len(templateParts) > 0 {
|
||||
if len(templateParts) == 0 {
|
||||
templateSQL = templateParts[0].Value
|
||||
templateParts := nodes.ParseTemplate(c.sqlTemplate)
|
||||
sqlParams := make([]database.SQLParam, 0, len(templateParts))
|
||||
var nilError = errors.New("field is nil")
|
||||
for _, templatePart := range templateParts {
|
||||
if !templatePart.IsVariable {
|
||||
templateSQL += templatePart.Value
|
||||
continue
|
||||
}
|
||||
|
||||
templateSQL += "?"
|
||||
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) {
|
||||
return "", nilError
|
||||
}),
|
||||
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
|
||||
b := val.(bool)
|
||||
if b {
|
||||
return "1", nil
|
||||
}
|
||||
return "0", nil
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, nilError) {
|
||||
return nil, err
|
||||
}
|
||||
sqlParams = append(sqlParams, database.SQLParam{
|
||||
IsNull: true,
|
||||
})
|
||||
} else {
|
||||
for _, templatePart := range templateParts {
|
||||
if !templatePart.IsVariable {
|
||||
templateSQL += templatePart.Value
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
|
||||
b := val.(bool)
|
||||
if b {
|
||||
return "1", nil
|
||||
}
|
||||
return "0", nil
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
templateSQL += val
|
||||
|
||||
}
|
||||
sqlParams = append(sqlParams, database.SQLParam{
|
||||
Value: val,
|
||||
IsNull: false,
|
||||
})
|
||||
}
|
||||
|
||||
} else {
|
||||
return nil, fmt.Errorf("parse template invalid")
|
||||
}
|
||||
|
||||
sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1)
|
||||
sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings))
|
||||
for _, s := range sqlParamStrings {
|
||||
parts := nodes.ParseTemplate(s)
|
||||
for _, part := range parts {
|
||||
if part.IsVariable {
|
||||
val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
|
||||
b := val.(bool)
|
||||
if b {
|
||||
return "1", nil
|
||||
}
|
||||
return "0", nil
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlParams = append(sqlParams, database.SQLParam{
|
||||
Value: val,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// replace sql template '?' to ?
|
||||
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
|
||||
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
|
||||
req.SQL = templateSQL
|
||||
req.Params = sqlParams
|
||||
response, err := crossdatabase.DefaultSVC().Execute(ctx, req)
|
||||
|
||||
@ -61,12 +61,12 @@ func TestCustomSQL_Execute(t *testing.T) {
|
||||
validate: func(req *database.CustomSQLRequest) {
|
||||
assert.Equal(t, int64(111), req.DatabaseInfoID)
|
||||
ps := []database.SQLParam{
|
||||
{Value: "v1_value"},
|
||||
{Value: "v2_value"},
|
||||
{Value: "v3_value"},
|
||||
{Value: "1"},
|
||||
}
|
||||
assert.Equal(t, ps, req.Params)
|
||||
assert.Equal(t, "select * from v1 where v1 = v1_value and v2 = ? and v3 = ? and v4 = ?", req.SQL)
|
||||
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
|
||||
},
|
||||
}
|
||||
|
||||
@ -86,7 +86,7 @@ func TestCustomSQL_Execute(t *testing.T) {
|
||||
|
||||
cfg := &CustomSQLConfig{
|
||||
DatabaseInfoID: 111,
|
||||
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}` and v4 = '{{v4}}'",
|
||||
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
|
||||
}
|
||||
|
||||
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
@ -104,7 +104,6 @@ func TestCustomSQL_Execute(t *testing.T) {
|
||||
"v1": "v1_value",
|
||||
"v2": "v2_value",
|
||||
"v3": "v3_value",
|
||||
"v4": true,
|
||||
})
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
@ -19,7 +19,6 @@ package intentdetector
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
|
||||
@ -19,8 +19,6 @@ package knowledge
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
@ -29,12 +27,9 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
type DeleterConfig struct {
|
||||
KnowledgeID int64
|
||||
}
|
||||
type DeleterConfig struct{}
|
||||
|
||||
func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
@ -44,18 +39,6 @@ func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOpt
|
||||
Configs: d,
|
||||
}
|
||||
|
||||
inputs := n.Data.Inputs
|
||||
datasetListInfoParam := inputs.DatasetParam[0]
|
||||
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
if len(datasetIDs) == 0 {
|
||||
return nil, fmt.Errorf("dataset ids is required")
|
||||
}
|
||||
knowledgeID, err := cast.ToInt64E(datasetIDs[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.KnowledgeID = knowledgeID
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -68,29 +51,19 @@ func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOpt
|
||||
}
|
||||
|
||||
func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Deleter{
|
||||
KnowledgeID: d.KnowledgeID,
|
||||
}, nil
|
||||
return &Deleter{}, nil
|
||||
}
|
||||
|
||||
type Deleter struct {
|
||||
KnowledgeID int64
|
||||
}
|
||||
type Deleter struct{}
|
||||
|
||||
func (d *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
documentID, ok := input["documentID"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("documentID is required and must be a string")
|
||||
}
|
||||
|
||||
docID, err := strconv.ParseInt(documentID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid document id: %s", documentID)
|
||||
}
|
||||
|
||||
req := &knowledge.DeleteDocumentRequest{
|
||||
DocumentID: docID,
|
||||
KnowledgeID: d.KnowledgeID,
|
||||
DocumentID: documentID,
|
||||
}
|
||||
|
||||
response, err := crossknowledge.DefaultSVC().Delete(ctx, req)
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
|
||||
@ -19,7 +19,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
conventity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
|
||||
|
||||
@ -29,7 +29,6 @@ import (
|
||||
"strconv"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
|
||||
@ -1,158 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBatchSize = 100
|
||||
defaultTopK = 10
|
||||
defaultVectorDimension = 2048
|
||||
defaultVectorMemoryLimitPercentage = 30
|
||||
defaultMaxOpenConns = 100
|
||||
defaultMaxIdleConns = 10
|
||||
defaultConnMaxLifetime = 3600
|
||||
defaultConnMaxIdleTime = 1800
|
||||
defaultCacheTTL = 300
|
||||
defaultConnTimeout = 30
|
||||
defaultMaxRetries = 3
|
||||
defaultRetryDelay = 1
|
||||
maxVectorDimension = 4096
|
||||
maxCollectionNameLength = 255
|
||||
maxSQLIdentifierLength = 64
|
||||
maxContentLength = 65535
|
||||
maxBatchSize = 1000
|
||||
|
||||
enableCacheDefault = true
|
||||
enableMetricsDefault = true
|
||||
enableSlowQueryLogDefault = true
|
||||
slowQueryThreshold = 1000
|
||||
|
||||
ErrCodeInvalidConfig = "INVALID_CONFIG"
|
||||
ErrCodeConnectionFailed = "CONNECTION_FAILED"
|
||||
ErrCodeQueryTimeout = "QUERY_TIMEOUT"
|
||||
ErrCodeVectorDimensionMismatch = "VECTOR_DIMENSION_MISMATCH"
|
||||
ErrCodeCollectionNotFound = "COLLECTION_NOT_FOUND"
|
||||
ErrCodeDuplicateCollection = "DUPLICATE_COLLECTION"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
|
||||
VectorDimension int
|
||||
MetricType string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
VectorMemoryLimitPercentage int
|
||||
BatchSize int
|
||||
|
||||
EnableCache bool
|
||||
CacheTTL time.Duration
|
||||
EnableMetrics bool
|
||||
EnableSlowQueryLog bool
|
||||
MaxRetries int
|
||||
RetryDelay time.Duration
|
||||
ConnTimeout time.Duration
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Host: getEnv("OCEANBASE_HOST", "localhost"),
|
||||
Port: getEnvAsInt("OCEANBASE_PORT", 2881),
|
||||
User: getEnv("OCEANBASE_USER", "root"),
|
||||
Password: getEnv("OCEANBASE_PASSWORD", ""),
|
||||
Database: getEnv("OCEANBASE_DATABASE", "test"),
|
||||
VectorDimension: getVectorDimension(),
|
||||
MetricType: "cosine",
|
||||
MaxOpenConns: getEnvAsInt("OCEANBASE_MAX_OPEN_CONNS", defaultMaxOpenConns),
|
||||
MaxIdleConns: getEnvAsInt("OCEANBASE_MAX_IDLE_CONNS", defaultMaxIdleConns),
|
||||
ConnMaxLifetime: time.Duration(getEnvAsInt("OCEANBASE_CONN_MAX_LIFETIME", defaultConnMaxLifetime)) * time.Second,
|
||||
ConnMaxIdleTime: time.Duration(getEnvAsInt("OCEANBASE_CONN_MAX_IDLE_TIME", defaultConnMaxIdleTime)) * time.Second,
|
||||
VectorMemoryLimitPercentage: getEnvAsInt("OCEANBASE_VECTOR_MEMORY_LIMIT_PERCENTAGE", defaultVectorMemoryLimitPercentage),
|
||||
BatchSize: getEnvAsInt("OCEANBASE_BATCH_SIZE", defaultBatchSize),
|
||||
EnableCache: getEnvAsBool("OCEANBASE_ENABLE_CACHE", enableCacheDefault),
|
||||
CacheTTL: time.Duration(getEnvAsInt("OCEANBASE_CACHE_TTL", defaultCacheTTL)) * time.Second,
|
||||
EnableMetrics: getEnvAsBool("OCEANBASE_ENABLE_METRICS", enableMetricsDefault),
|
||||
EnableSlowQueryLog: getEnvAsBool("OCEANBASE_ENABLE_SLOW_QUERY_LOG", enableSlowQueryLogDefault),
|
||||
MaxRetries: getEnvAsInt("OCEANBASE_MAX_RETRIES", defaultMaxRetries),
|
||||
RetryDelay: time.Duration(getEnvAsInt("OCEANBASE_RETRY_DELAY", defaultRetryDelay)) * time.Second,
|
||||
ConnTimeout: time.Duration(getEnvAsInt("OCEANBASE_CONN_TIMEOUT", defaultConnTimeout)) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.Host == "" {
|
||||
return fmt.Errorf("host cannot be empty")
|
||||
}
|
||||
if c.Port <= 0 || c.Port > 65535 {
|
||||
return fmt.Errorf("port must be between 1 and 65535")
|
||||
}
|
||||
if c.User == "" {
|
||||
return fmt.Errorf("user cannot be empty")
|
||||
}
|
||||
if c.Database == "" {
|
||||
return fmt.Errorf("database cannot be empty")
|
||||
}
|
||||
if c.VectorDimension <= 0 || c.VectorDimension > maxVectorDimension {
|
||||
return fmt.Errorf("vector dimension must be between 1 and %d", maxVectorDimension)
|
||||
}
|
||||
if c.BatchSize <= 0 || c.BatchSize > maxBatchSize {
|
||||
return fmt.Errorf("batch size must be between 1 and %d", maxBatchSize)
|
||||
}
|
||||
if c.MaxOpenConns <= 0 {
|
||||
return fmt.Errorf("max open connections must be positive")
|
||||
}
|
||||
if c.MaxIdleConns <= 0 || c.MaxIdleConns > c.MaxOpenConns {
|
||||
return fmt.Errorf("max idle connections must be positive and not greater than max open connections")
|
||||
}
|
||||
if c.CacheTTL <= 0 {
|
||||
return fmt.Errorf("cache TTL must be positive")
|
||||
}
|
||||
if c.MaxRetries < 0 {
|
||||
return fmt.Errorf("max retries cannot be negative")
|
||||
}
|
||||
if c.RetryDelay < 0 {
|
||||
return fmt.Errorf("retry delay cannot be negative")
|
||||
}
|
||||
if c.ConnTimeout <= 0 {
|
||||
return fmt.Errorf("connection timeout must be positive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getVectorDimension() int {
|
||||
if dims := getEnvAsInt("ARK_EMBEDDING_DIMS", 0); dims > 0 {
|
||||
return dims
|
||||
}
|
||||
if dims := getEnvAsInt("OPENAI_EMBEDDING_DIMS", 0); dims > 0 {
|
||||
return dims
|
||||
}
|
||||
return defaultVectorDimension
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,239 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TableName(collectionName string) string {
|
||||
cleanName := strings.Map(func(r rune) rune {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' {
|
||||
return r
|
||||
}
|
||||
return '_'
|
||||
}, collectionName)
|
||||
return fmt.Sprintf("vector_%s", strings.ToLower(cleanName))
|
||||
}
|
||||
|
||||
func ExtractContent(doc *schema.Document) string {
|
||||
if doc.Content != "" {
|
||||
return strings.TrimSpace(doc.Content)
|
||||
}
|
||||
if doc.MetaData != nil {
|
||||
if content, ok := doc.MetaData["content"].(string); ok && content != "" {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
if text, ok := doc.MetaData["text"].(string); ok && text != "" {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func BuildMetadata(doc *schema.Document) map[string]interface{} {
|
||||
metadata := make(map[string]interface{})
|
||||
if doc.MetaData != nil {
|
||||
for k, v := range doc.MetaData {
|
||||
metadata[k] = v
|
||||
}
|
||||
}
|
||||
metadata["document_id"] = doc.ID
|
||||
metadata["content"] = doc.Content
|
||||
metadata["content_length"] = len(doc.Content)
|
||||
return metadata
|
||||
}
|
||||
|
||||
func MetadataToJSON(metadata map[string]interface{}) (string, error) {
|
||||
if metadata == nil {
|
||||
return "{}", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
func JSONToMetadata(jsonStr string) (map[string]interface{}, error) {
|
||||
if jsonStr == "" {
|
||||
return make(map[string]interface{}), nil
|
||||
}
|
||||
var metadata map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal metadata: %w", err)
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func ValidateCollectionName(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("collection name cannot be empty")
|
||||
}
|
||||
if len(name) > maxCollectionNameLength {
|
||||
return fmt.Errorf("collection name too long (max %d characters)", maxCollectionNameLength)
|
||||
}
|
||||
|
||||
if len(name) > 0 && unicode.IsDigit(rune(name[0])) {
|
||||
return fmt.Errorf("collection name cannot start with a digit")
|
||||
}
|
||||
|
||||
for _, r := range name {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-') {
|
||||
return fmt.Errorf("collection name contains invalid character: %c", r)
|
||||
}
|
||||
}
|
||||
|
||||
if isReservedWord(name) {
|
||||
return fmt.Errorf("collection name is a reserved word: %s", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func BuildInClause(values []string) string {
|
||||
if len(values) == 0 {
|
||||
return "()"
|
||||
}
|
||||
quoted := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
quoted[i] = fmt.Sprintf("'%s'", v)
|
||||
}
|
||||
return fmt.Sprintf("(%s)", strings.Join(quoted, ","))
|
||||
}
|
||||
|
||||
func ConvertToFloat32(f64 []float64) []float32 {
|
||||
f32 := make([]float32, len(f64))
|
||||
for i, v := range f64 {
|
||||
f32[i] = float32(v)
|
||||
}
|
||||
return f32
|
||||
}
|
||||
|
||||
func ConvertToFloat64(f32 []float32) []float64 {
|
||||
f64 := make([]float64, len(f32))
|
||||
for i, v := range f32 {
|
||||
f64[i] = float64(v)
|
||||
}
|
||||
return f64
|
||||
}
|
||||
|
||||
func SanitizeString(s string) string {
|
||||
s = strings.Map(func(r rune) rune {
|
||||
if r < 32 || r == 127 {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, s)
|
||||
|
||||
s = strings.Join(strings.Fields(s), " ")
|
||||
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func TruncateString(s string, maxLength int) string {
|
||||
if len(s) <= maxLength {
|
||||
return s
|
||||
}
|
||||
return s[:maxLength-3] + "..."
|
||||
}
|
||||
|
||||
func IsValidVector(vector []float32) error {
|
||||
if len(vector) == 0 {
|
||||
return fmt.Errorf("vector cannot be empty")
|
||||
}
|
||||
if len(vector) > maxVectorDimension {
|
||||
return fmt.Errorf("vector dimension too large (max %d)", maxVectorDimension)
|
||||
}
|
||||
|
||||
for i, v := range vector {
|
||||
if v != v { // NaN check
|
||||
return fmt.Errorf("vector contains NaN at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeVector(vector []float32) []float32 {
|
||||
if len(vector) == 0 {
|
||||
return vector
|
||||
}
|
||||
|
||||
var sum float32
|
||||
for _, v := range vector {
|
||||
sum += v * v
|
||||
}
|
||||
|
||||
if sum == 0 {
|
||||
return vector
|
||||
}
|
||||
|
||||
norm := float32(1.0 / math.Sqrt(float64(sum)))
|
||||
normalized := make([]float32, len(vector))
|
||||
for i, v := range vector {
|
||||
normalized[i] = v * norm
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
var reservedWords = map[string]bool{
|
||||
"select": true, "from": true, "where": true, "insert": true, "update": true,
|
||||
"delete": true, "drop": true, "create": true, "alter": true, "table": true,
|
||||
"index": true, "primary": true, "foreign": true, "key": true, "constraint": true,
|
||||
"order": true, "by": true, "group": true, "having": true, "union": true,
|
||||
"all": true, "distinct": true, "as": true, "in": true, "between": true,
|
||||
"like": true, "is": true, "null": true, "not": true, "and": true, "or": true,
|
||||
"vector": true, "embedding": true, "collection": true,
|
||||
}
|
||||
|
||||
func isReservedWord(name string) bool {
|
||||
return reservedWords[strings.ToLower(name)]
|
||||
}
|
||||
|
||||
func GenerateTableName(collectionName string, suffix string) string {
|
||||
baseName := TableName(collectionName)
|
||||
if suffix != "" {
|
||||
return fmt.Sprintf("%s_%s", baseName, suffix)
|
||||
}
|
||||
return baseName
|
||||
}
|
||||
|
||||
func ValidateSQLIdentifier(identifier string) error {
|
||||
if identifier == "" {
|
||||
return fmt.Errorf("SQL identifier cannot be empty")
|
||||
}
|
||||
|
||||
if len(identifier) > 64 {
|
||||
return fmt.Errorf("SQL identifier too long (max 64 characters)")
|
||||
}
|
||||
|
||||
matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, identifier)
|
||||
if !matched {
|
||||
return fmt.Errorf("SQL identifier format invalid: %s", identifier)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,80 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type Factory struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewFactory(config *Config) *Factory {
|
||||
return &Factory{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Factory) CreateManager(ctx context.Context, embedder embedding.Embedder) (searchstore.Manager, error) {
|
||||
if err := f.config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
f.config.User, f.config.Password, f.config.Host, f.config.Port, f.config.Database)
|
||||
|
||||
client, err := oceanbase.NewOceanBaseClient(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OceanBase client: %w", err)
|
||||
}
|
||||
|
||||
managerConfig := &ManagerConfig{
|
||||
Client: client,
|
||||
Embedding: embedder,
|
||||
BatchSize: f.config.BatchSize,
|
||||
EnableCache: f.config.EnableCache,
|
||||
CacheTTL: f.config.CacheTTL,
|
||||
MaxConnections: f.config.MaxOpenConns,
|
||||
ConnTimeout: f.config.ConnTimeout,
|
||||
}
|
||||
|
||||
manager, err := NewManager(managerConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OceanBase manager: %w", err)
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "Created OceanBase vector store manager with config: %s:%d/%s (dimension: %d, cache: %v, batchSize: %d)",
|
||||
f.config.Host, f.config.Port, f.config.Database, f.config.VectorDimension,
|
||||
f.config.EnableCache, f.config.BatchSize)
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func (f *Factory) GetType() searchstore.SearchStoreType {
|
||||
return searchstore.TypeVectorStore
|
||||
}
|
||||
|
||||
func (f *Factory) GetConfig() *Config {
|
||||
return f.config
|
||||
}
|
||||
@ -1,367 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type ManagerConfig struct {
|
||||
Client *oceanbase.OceanBaseClient
|
||||
Embedding embedding.Embedder
|
||||
BatchSize int
|
||||
|
||||
EnableCache bool
|
||||
CacheTTL time.Duration
|
||||
MaxConnections int
|
||||
ConnTimeout time.Duration
|
||||
|
||||
EnableConnectionPool bool
|
||||
PoolMaxIdle int
|
||||
PoolMaxLifetime time.Duration
|
||||
QueryTimeout time.Duration
|
||||
MaxRetries int // optional: default 3
|
||||
RetryDelay time.Duration // optional: default 1s
|
||||
}
|
||||
|
||||
// Create an OceanBase vector storage manager
|
||||
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
|
||||
if config.Client == nil {
|
||||
return nil, fmt.Errorf("[NewManager] oceanbase client not provided")
|
||||
}
|
||||
if config.Embedding == nil {
|
||||
return nil, fmt.Errorf("[NewManager] oceanbase embedder not provided")
|
||||
}
|
||||
|
||||
if config.BatchSize == 0 {
|
||||
config.BatchSize = defaultBatchSize
|
||||
}
|
||||
|
||||
if config.CacheTTL == 0 {
|
||||
config.CacheTTL = 5 * time.Minute
|
||||
}
|
||||
if config.MaxConnections == 0 {
|
||||
config.MaxConnections = defaultMaxOpenConns
|
||||
}
|
||||
if config.ConnTimeout == 0 {
|
||||
config.ConnTimeout = 30 * time.Second
|
||||
}
|
||||
|
||||
|
||||
if config.PoolMaxIdle == 0 {
|
||||
config.PoolMaxIdle = 10
|
||||
}
|
||||
if config.PoolMaxLifetime == 0 {
|
||||
config.PoolMaxLifetime = 1 * time.Hour
|
||||
}
|
||||
if config.QueryTimeout == 0 {
|
||||
config.QueryTimeout = 30 * time.Second
|
||||
}
|
||||
if config.MaxRetries == 0 {
|
||||
config.MaxRetries = 3
|
||||
}
|
||||
if config.RetryDelay == 0 {
|
||||
config.RetryDelay = 1 * time.Second
|
||||
}
|
||||
|
||||
manager := &oceanbaseManager{
|
||||
config: config,
|
||||
cache: make(map[string]*cachedSearchStore),
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
if config.EnableCache {
|
||||
go manager.startCacheCleaner()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.ConnTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := config.Client.InitDatabase(ctx); err != nil {
|
||||
logs.CtxWarnf(ctx, "Failed to initialize OceanBase database: %v", err)
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "Created OceanBase vector store manager with cache=%v, batchSize=%d, pool=%v",
|
||||
config.EnableCache, config.BatchSize, config.EnableConnectionPool)
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
type oceanbaseManager struct {
|
||||
config *ManagerConfig
|
||||
cache map[string]*cachedSearchStore
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
// cachedSearchStore 缓存的搜索存储
|
||||
type cachedSearchStore struct {
|
||||
store searchstore.SearchStore
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
if err := ValidateCollectionName(req.CollectionName); err != nil {
|
||||
return fmt.Errorf("[Create] invalid collection name: %w", err)
|
||||
}
|
||||
|
||||
tableName := TableName(req.CollectionName)
|
||||
|
||||
dimension := m.getVectorDimension()
|
||||
|
||||
logs.CtxInfof(ctx, "[Create] Using dimension: %d for collection: %s", dimension, req.CollectionName)
|
||||
|
||||
if err := m.config.Client.CreateCollection(ctx, req.CollectionName, dimension); err != nil {
|
||||
return fmt.Errorf("[Create] create vector collection failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.recordCollection(ctx, req.CollectionName, tableName); err != nil {
|
||||
logs.CtxWarnf(ctx, "Failed to record collection: %v", err)
|
||||
}
|
||||
|
||||
m.clearCache(req.CollectionName)
|
||||
|
||||
logs.CtxInfof(ctx, "Created OceanBase collection: %s (table: %s)", req.CollectionName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
|
||||
if err := ValidateCollectionName(req.CollectionName); err != nil {
|
||||
return fmt.Errorf("[Drop] invalid collection name: %w", err)
|
||||
}
|
||||
|
||||
tableName := TableName(req.CollectionName)
|
||||
|
||||
if err := m.config.Client.DropCollection(ctx, req.CollectionName); err != nil {
|
||||
return fmt.Errorf("[Drop] drop collection failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.removeCollection(ctx, req.CollectionName); err != nil {
|
||||
logs.CtxWarnf(ctx, "Failed to remove collection record: %v", err)
|
||||
}
|
||||
|
||||
m.clearCache(req.CollectionName)
|
||||
|
||||
logs.CtxInfof(ctx, "Deleted OceanBase collection: %s (table: %s)", req.CollectionName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) GetType() searchstore.SearchStoreType {
|
||||
return searchstore.TypeVectorStore
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
|
||||
if err := ValidateCollectionName(collectionName); err != nil {
|
||||
return nil, fmt.Errorf("[GetSearchStore] invalid collection name: %w", err)
|
||||
}
|
||||
|
||||
if m.config.EnableCache {
|
||||
if cached := m.getCachedStore(collectionName); cached != nil {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
store := &oceanbaseSearchStore{
|
||||
manager: m,
|
||||
collectionName: collectionName,
|
||||
tableName: TableName(collectionName),
|
||||
}
|
||||
|
||||
if m.config.EnableCache {
|
||||
m.cacheStore(collectionName, store)
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) recordCollection(ctx context.Context, collectionName, tableName string) error {
|
||||
// Create collections metadata table if not exists
|
||||
createTableSQL := `
|
||||
CREATE TABLE IF NOT EXISTS oceanbase_collections (
|
||||
collection_name VARCHAR(255) PRIMARY KEY,
|
||||
table_name VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
status ENUM('active', 'deleted') DEFAULT 'active'
|
||||
)`
|
||||
|
||||
if err := m.config.Client.GetDB().WithContext(ctx).Exec(createTableSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create collections metadata table: %w", err)
|
||||
}
|
||||
|
||||
// Insert or update collection record
|
||||
upsertSQL := `
|
||||
INSERT INTO oceanbase_collections (collection_name, table_name, status)
|
||||
VALUES (?, ?, 'active')
|
||||
ON DUPLICATE KEY UPDATE
|
||||
table_name = VALUES(table_name),
|
||||
status = 'active',
|
||||
updated_at = CURRENT_TIMESTAMP`
|
||||
|
||||
if err := m.config.Client.GetDB().WithContext(ctx).Exec(upsertSQL, collectionName, tableName).Error; err != nil {
|
||||
return fmt.Errorf("failed to record collection metadata: %w", err)
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "Recorded collection: %s (table: %s)", collectionName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) removeCollection(ctx context.Context, collectionName string) error {
|
||||
// Soft delete collection record by setting status to 'deleted'
|
||||
updateSQL := `
|
||||
UPDATE oceanbase_collections
|
||||
SET status = 'deleted', updated_at = CURRENT_TIMESTAMP
|
||||
WHERE collection_name = ?`
|
||||
|
||||
if err := m.config.Client.GetDB().WithContext(ctx).Exec(updateSQL, collectionName).Error; err != nil {
|
||||
return fmt.Errorf("failed to remove collection metadata: %w", err)
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "Removed collection record: %s", collectionName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) getCachedStore(collectionName string) searchstore.SearchStore {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if cached, exists := m.cache[collectionName]; exists {
|
||||
if time.Since(cached.lastUsed) < m.config.CacheTTL {
|
||||
cached.lastUsed = time.Now()
|
||||
return cached.store
|
||||
}
|
||||
delete(m.cache, collectionName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) cacheStore(collectionName string, store searchstore.SearchStore) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.cache[collectionName] = &cachedSearchStore{
|
||||
store: store,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) clearCache(collectionName string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
delete(m.cache, collectionName)
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) startCacheCleaner() {
|
||||
ticker := time.NewTicker(m.config.CacheTTL / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
m.cleanExpiredCache()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) cleanExpiredCache() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, cached := range m.cache {
|
||||
if now.Sub(cached.lastUsed) > m.config.CacheTTL {
|
||||
delete(m.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *oceanbaseManager) getVectorDimension() int {
|
||||
embeddingType := os.Getenv("EMBEDDING_TYPE")
|
||||
|
||||
logs.Infof("[getVectorDimension] EMBEDDING_TYPE: %s", embeddingType)
|
||||
|
||||
switch embeddingType {
|
||||
case "ark":
|
||||
if dimStr := os.Getenv("ARK_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
case "openai":
|
||||
if dimStr := os.Getenv("OPENAI_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
case "ollama":
|
||||
if dimStr := os.Getenv("OLLAMA_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
case "http":
|
||||
if dimStr := os.Getenv("HTTP_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
case "gemini":
|
||||
if dimStr := os.Getenv("GEMINI_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if dimStr := os.Getenv("ARK_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
|
||||
if dimStr := os.Getenv("OPENAI_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
|
||||
if dimStr := os.Getenv("OLLAMA_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
|
||||
if dimStr := os.Getenv("HTTP_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
|
||||
if dimStr := os.Getenv("GEMINI_EMBEDDING_DIMS"); dimStr != "" {
|
||||
if dim, err := strconv.Atoi(dimStr); err == nil {
|
||||
return dim
|
||||
}
|
||||
}
|
||||
|
||||
return 1024
|
||||
}
|
||||
@ -1,366 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type oceanbaseSearchStore struct {
|
||||
manager *oceanbaseManager
|
||||
collectionName string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) {
|
||||
if len(docs) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
logs.CtxInfof(ctx, "Store operation completed in %v for %d documents",
|
||||
time.Since(startTime), len(docs))
|
||||
}()
|
||||
|
||||
var ids []string
|
||||
var vectorDataList []*vectorData
|
||||
|
||||
for _, doc := range docs {
|
||||
content := ExtractContent(doc)
|
||||
if content == "" {
|
||||
logs.CtxWarnf(ctx, "Document %s has no content, skipping", doc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
embeddings, err := s.manager.config.Embedding.EmbedStrings(ctx, []string{content})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Store] failed to embed document: %w", err)
|
||||
}
|
||||
|
||||
if len(embeddings) == 0 {
|
||||
logs.CtxWarnf(ctx, "Failed to generate embedding for document %s", doc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
metadata := BuildMetadata(doc)
|
||||
metadataJSON, err := MetadataToJSON(metadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Store] failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
vectorData := &vectorData{
|
||||
VectorID: doc.ID,
|
||||
Content: content,
|
||||
Metadata: metadataJSON,
|
||||
Embedding: ConvertToFloat32(embeddings[0]),
|
||||
}
|
||||
|
||||
vectorDataList = append(vectorDataList, vectorData)
|
||||
ids = append(ids, doc.ID)
|
||||
}
|
||||
|
||||
if len(vectorDataList) > 0 {
|
||||
if err := s.batchInsertWithRetry(ctx, vectorDataList); err != nil {
|
||||
return nil, fmt.Errorf("[Store] failed to batch insert vector data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "Stored %d documents to OceanBase collection: %s", len(ids), s.collectionName)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
logs.CtxInfof(ctx, "Retrieve operation completed in %v", time.Since(startTime))
|
||||
}()
|
||||
|
||||
options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(10)}, opts...)
|
||||
|
||||
embeddings, err := s.manager.config.Embedding.EmbedStrings(ctx, []string{query})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Retrieve] failed to embed query: %w", err)
|
||||
}
|
||||
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("[Retrieve] failed to generate embedding for query")
|
||||
}
|
||||
|
||||
results, err := s.manager.config.Client.SearchVectors(
|
||||
ctx,
|
||||
s.collectionName,
|
||||
embeddings[0],
|
||||
ptr.From(options.TopK),
|
||||
0.1,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Retrieve] failed to search vectors: %w", err)
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "OceanBase returned %d results", len(results))
|
||||
|
||||
documents := make([]*schema.Document, 0, len(results))
|
||||
for _, result := range results {
|
||||
metadata, err := JSONToMetadata(result.Metadata)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "Failed to parse metadata for result %s: %v", result.VectorID, err)
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
|
||||
doc := &schema.Document{
|
||||
ID: result.VectorID,
|
||||
Content: result.Content,
|
||||
MetaData: metadata,
|
||||
}
|
||||
|
||||
similarityScore := result.SimilarityScore
|
||||
logs.CtxInfof(ctx, "Setting score for document %s: %f", result.VectorID, similarityScore)
|
||||
doc.WithScore(similarityScore)
|
||||
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
sort.Slice(documents, func(i, j int) bool {
|
||||
return documents[i].Score() > documents[j].Score()
|
||||
})
|
||||
|
||||
if len(documents) > 0 {
|
||||
s.normalizeScores(documents)
|
||||
}
|
||||
|
||||
if len(documents) > ptr.From(options.TopK) {
|
||||
documents = documents[:ptr.From(options.TopK)]
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "Retrieved %d documents from OceanBase collection: %s", len(documents), s.collectionName)
|
||||
for i, doc := range documents {
|
||||
logs.CtxInfof(ctx, "Document %d: ID=%s, Score=%.6f, Content=%s",
|
||||
i+1, doc.ID, doc.Score(), doc.Content[:min(len(doc.Content), 50)])
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) Delete(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
logs.CtxInfof(ctx, "Delete operation completed in %v for %d documents",
|
||||
time.Since(startTime), len(ids))
|
||||
}()
|
||||
|
||||
batchSize := s.manager.config.BatchSize
|
||||
for i := 0; i < len(ids); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
|
||||
batch := ids[i:end]
|
||||
if err := s.deleteBatch(ctx, batch); err != nil {
|
||||
return fmt.Errorf("[Delete] failed to delete batch: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "Deleted %d documents from OceanBase collection: %s", len(ids), s.collectionName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) batchInsertWithRetry(ctx context.Context, data []*vectorData) error {
|
||||
maxRetries := s.manager.config.MaxRetries
|
||||
retryDelay := s.manager.config.RetryDelay
|
||||
batchSize := s.manager.config.BatchSize
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
err := s.batchInsert(ctx, data, batchSize)
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if attempt == maxRetries {
|
||||
return fmt.Errorf("batch insert failed after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
logs.CtxWarnf(ctx, "Batch insert attempt %d failed, retrying in %v: %v", attempt, retryDelay, err)
|
||||
time.Sleep(retryDelay)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) batchInsert(ctx context.Context, data []*vectorData, batchSize int) error {
|
||||
var vectors []oceanbase.VectorResult
|
||||
for _, item := range data {
|
||||
embedding64 := make([]float64, len(item.Embedding))
|
||||
for i, v := range item.Embedding {
|
||||
embedding64[i] = float64(v)
|
||||
}
|
||||
|
||||
var metadata map[string]interface{}
|
||||
if item.Metadata != "" && item.Metadata != "{}" {
|
||||
if err := json.Unmarshal([]byte(item.Metadata), &metadata); err != nil {
|
||||
logs.CtxWarnf(ctx, "Failed to parse metadata for %s: %v", item.VectorID, err)
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
} else {
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
|
||||
metadataStr := "{}"
|
||||
if len(metadata) > 0 {
|
||||
if metadataBytes, err := json.Marshal(metadata); err == nil {
|
||||
metadataStr = string(metadataBytes)
|
||||
}
|
||||
}
|
||||
|
||||
vectors = append(vectors, oceanbase.VectorResult{
|
||||
VectorID: item.VectorID,
|
||||
Content: item.Content,
|
||||
Metadata: metadataStr,
|
||||
Embedding: embedding64,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return s.manager.config.Client.InsertVectors(ctx, s.collectionName, vectors)
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) searchVectorsWithRetry(ctx context.Context, queryEmbedding []float32, limit int, threshold float64) ([]*vectorResult, error) {
|
||||
maxRetries := s.manager.config.MaxRetries
|
||||
retryDelay := s.manager.config.RetryDelay
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
results, err := s.searchVectors(ctx, queryEmbedding, limit, threshold)
|
||||
if err == nil {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
if attempt == maxRetries {
|
||||
return nil, fmt.Errorf("search vectors failed after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
logs.CtxWarnf(ctx, "Search vectors attempt %d failed, retrying in %v: %v", attempt, retryDelay, err)
|
||||
time.Sleep(retryDelay)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) searchVectors(ctx context.Context, queryEmbedding []float32, limit int, threshold float64) ([]*vectorResult, error) {
|
||||
embedding64 := make([]float64, len(queryEmbedding))
|
||||
for i, v := range queryEmbedding {
|
||||
embedding64[i] = float64(v)
|
||||
}
|
||||
|
||||
results, err := s.manager.config.Client.SearchVectors(ctx, s.collectionName, embedding64, limit, threshold)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search vectors: %w", err)
|
||||
}
|
||||
|
||||
var vectorResults []*vectorResult
|
||||
for _, result := range results {
|
||||
metadataStr := result.Metadata
|
||||
if metadataStr == "" {
|
||||
metadataStr = "{}"
|
||||
}
|
||||
|
||||
vectorResults = append(vectorResults, &vectorResult{
|
||||
VectorID: result.VectorID,
|
||||
Content: result.Content,
|
||||
Metadata: metadataStr,
|
||||
Distance: result.SimilarityScore,
|
||||
})
|
||||
}
|
||||
|
||||
return vectorResults, nil
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) deleteBatch(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
if err := s.manager.config.Client.DeleteVector(ctx, s.collectionName, id); err != nil {
|
||||
return fmt.Errorf("failed to delete vector %s: %w", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oceanbaseSearchStore) normalizeScores(documents []*schema.Document) {
|
||||
if len(documents) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logs.CtxInfof(context.Background(), "Normalizing scores for %d documents", len(documents))
|
||||
|
||||
|
||||
for i := range documents {
|
||||
originalScore := documents[i].Score()
|
||||
logs.CtxInfof(context.Background(), "Document %d original score: %f", i+1, originalScore)
|
||||
|
||||
if originalScore < 0 {
|
||||
documents[i].WithScore(0.0)
|
||||
logs.CtxInfof(context.Background(), "Document %d score adjusted from %f to 0.0", i+1, originalScore)
|
||||
} else if originalScore > 1 {
|
||||
documents[i].WithScore(1.0)
|
||||
logs.CtxInfof(context.Background(), "Document %d score adjusted from %f to 1.0", i+1, originalScore)
|
||||
} else {
|
||||
logs.CtxInfof(context.Background(), "Document %d score unchanged: %f", i+1, originalScore)
|
||||
}
|
||||
}
|
||||
|
||||
logs.CtxInfof(context.Background(), "Score normalization completed")
|
||||
}
|
||||
|
||||
type vectorData struct {
|
||||
VectorID string
|
||||
Content string
|
||||
Metadata string
|
||||
Embedding []float32
|
||||
}
|
||||
|
||||
type vectorResult struct {
|
||||
ID int64 `json:"id"`
|
||||
VectorID string `json:"vector_id"`
|
||||
Content string `json:"content"`
|
||||
Metadata string `json:"metadata"`
|
||||
Distance float64 `json:"distance"`
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@ -1,81 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func CreateOceanBaseVectorStore(
|
||||
config Config,
|
||||
embedding embedding.Embedder,
|
||||
) (searchstore.Manager, error) {
|
||||
factory := NewFactory(&config)
|
||||
|
||||
manager, err := factory.CreateManager(context.Background(), embedding)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.Infof("Successfully created OceanBase vector store with type: %s", searchstore.TypeVectorStore)
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func CreateOceanBaseVectorStoreWithEnv(
|
||||
embedding embedding.Embedder,
|
||||
) (searchstore.Manager, error) {
|
||||
config := Config{
|
||||
Host: getEnv("OCEANBASE_HOST", "localhost"),
|
||||
Port: getEnvAsInt("OCEANBASE_PORT", 2881),
|
||||
User: getEnv("OCEANBASE_USER", "root"),
|
||||
Password: getEnv("OCEANBASE_PASSWORD", ""),
|
||||
Database: getEnv("OCEANBASE_DATABASE", "test"),
|
||||
}
|
||||
|
||||
return CreateOceanBaseVectorStore(config, embedding)
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvAsInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvAsBool(key string, defaultValue bool) bool {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if boolValue, err := strconv.ParseBool(value); err == nil {
|
||||
return boolValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
@ -1,90 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type OceanBaseClient struct {
|
||||
official *OceanBaseOfficialClient
|
||||
}
|
||||
|
||||
func NewOceanBaseClient(dsn string) (*OceanBaseClient, error) {
|
||||
official, err := NewOceanBaseOfficialClient(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &OceanBaseClient{official: official}, nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) CreateCollection(ctx context.Context, collectionName string, dimension int) error {
|
||||
return c.official.CreateCollection(ctx, collectionName, dimension)
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) InsertVectors(ctx context.Context, collectionName string, vectors []VectorResult) error {
|
||||
return c.official.InsertVectors(ctx, collectionName, vectors)
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) SearchVectors(ctx context.Context, collectionName string, queryVector []float64, topK int, threshold float64) ([]VectorResult, error) {
|
||||
return c.official.SearchVectors(ctx, collectionName, queryVector, topK, threshold)
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) SearchVectorsWithStrategy(ctx context.Context, collectionName string, queryVector []float64, topK int, threshold float64, strategy SearchStrategy) ([]VectorResult, error) {
|
||||
return c.official.SearchVectors(ctx, collectionName, queryVector, topK, threshold)
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) GetDB() *gorm.DB {
|
||||
return c.official.GetDB()
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) DebugCollectionData(ctx context.Context, collectionName string) error {
|
||||
return c.official.DebugCollectionData(ctx, collectionName)
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) BatchInsertVectors(ctx context.Context, collectionName string, vectors []VectorResult) error {
|
||||
return c.official.InsertVectors(ctx, collectionName, vectors)
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) DeleteVector(ctx context.Context, collectionName string, vectorID string) error {
|
||||
return c.official.GetDB().WithContext(ctx).Exec("DELETE FROM "+collectionName+" WHERE vector_id = ?", vectorID).Error
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) InitDatabase(ctx context.Context) error {
|
||||
return c.official.GetDB().WithContext(ctx).Exec("SELECT 1").Error
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) DropCollection(ctx context.Context, collectionName string) error {
|
||||
return c.official.GetDB().WithContext(ctx).Exec("DROP TABLE IF EXISTS " + collectionName).Error
|
||||
}
|
||||
|
||||
type SearchStrategy interface {
|
||||
GetThreshold() float64
|
||||
}
|
||||
|
||||
type DefaultSearchStrategy struct{}
|
||||
|
||||
func NewDefaultSearchStrategy() *DefaultSearchStrategy {
|
||||
return &DefaultSearchStrategy{}
|
||||
}
|
||||
|
||||
func (s *DefaultSearchStrategy) GetThreshold() float64 {
|
||||
return 0.0
|
||||
}
|
||||
@ -1,373 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type OceanBaseOfficialClient struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
type VectorResult struct {
|
||||
VectorID string `json:"vector_id"`
|
||||
Content string `json:"content"`
|
||||
Metadata string `json:"metadata"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
SimilarityScore float64 `json:"similarity_score"`
|
||||
Distance float64 `json:"distance"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type CollectionInfo struct {
|
||||
Name string `json:"name"`
|
||||
Dimension int `json:"dimension"`
|
||||
IndexType string `json:"index_type"`
|
||||
}
|
||||
|
||||
func NewOceanBaseOfficialClient(dsn string) (*OceanBaseOfficialClient, error) {
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to OceanBase: %v", err)
|
||||
}
|
||||
|
||||
client := &OceanBaseOfficialClient{db: db}
|
||||
|
||||
if err := client.setVectorParameters(); err != nil {
|
||||
log.Printf("Warning: Failed to set vector parameters: %v", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) setVectorParameters() error {
|
||||
params := map[string]string{
|
||||
"ob_vector_memory_limit_percentage": "30",
|
||||
"ob_query_timeout": "86400000000",
|
||||
"max_allowed_packet": "1073741824",
|
||||
}
|
||||
|
||||
for param, value := range params {
|
||||
if err := c.db.Exec(fmt.Sprintf("SET GLOBAL %s = %s", param, value)).Error; err != nil {
|
||||
log.Printf("Warning: Failed to set %s: %v", param, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) CreateCollection(ctx context.Context, collectionName string, dimension int) error {
|
||||
createTableSQL := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
vector_id VARCHAR(255) PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSON,
|
||||
embedding VECTOR(%d) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_created_at (created_at),
|
||||
INDEX idx_content (content(100))
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
|
||||
`, collectionName, dimension)
|
||||
|
||||
if err := c.db.WithContext(ctx).Exec(createTableSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create table: %v", err)
|
||||
}
|
||||
|
||||
createIndexSQL := fmt.Sprintf(`
|
||||
CREATE VECTOR INDEX idx_%s_embedding ON %s(embedding)
|
||||
WITH (distance=cosine, type=hnsw, lib=vsag, m=16, ef_construction=200, ef_search=64)
|
||||
`, collectionName, collectionName)
|
||||
|
||||
if err := c.db.WithContext(ctx).Exec(createIndexSQL).Error; err != nil {
|
||||
log.Printf("Warning: Failed to create HNSW vector index, will use exact search: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("Successfully created collection '%s' with dimension %d", collectionName, dimension)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) InsertVectors(ctx context.Context, collectionName string, vectors []VectorResult) error {
|
||||
if len(vectors) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
const batchSize = 100
|
||||
for i := 0; i < len(vectors); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(vectors) {
|
||||
end = len(vectors)
|
||||
}
|
||||
batch := vectors[i:end]
|
||||
|
||||
if err := c.insertBatch(ctx, collectionName, batch); err != nil {
|
||||
return fmt.Errorf("failed to insert vectors batch %d-%d: %v", i, end-1, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Successfully inserted %d vectors into collection '%s'", len(vectors), collectionName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) insertBatch(ctx context.Context, collectionName string, batch []VectorResult) error {
|
||||
placeholders := make([]string, len(batch))
|
||||
values := make([]interface{}, 0, len(batch)*5)
|
||||
|
||||
for j, vector := range batch {
|
||||
placeholders[j] = "(?, ?, ?, ?, NOW())"
|
||||
values = append(values,
|
||||
vector.VectorID,
|
||||
vector.Content,
|
||||
vector.Metadata,
|
||||
c.vectorToString(vector.Embedding),
|
||||
)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(`
|
||||
INSERT INTO %s (vector_id, content, metadata, embedding, created_at)
|
||||
VALUES %s
|
||||
ON DUPLICATE KEY UPDATE
|
||||
content = VALUES(content),
|
||||
metadata = VALUES(metadata),
|
||||
embedding = VALUES(embedding),
|
||||
updated_at = NOW()
|
||||
`, collectionName, strings.Join(placeholders, ","))
|
||||
|
||||
return c.db.WithContext(ctx).Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) SearchVectors(
|
||||
ctx context.Context,
|
||||
collectionName string,
|
||||
queryVector []float64,
|
||||
topK int,
|
||||
threshold float64,
|
||||
) ([]VectorResult, error) {
|
||||
|
||||
var count int64
|
||||
if err := c.db.WithContext(ctx).Table(collectionName).Count(&count).Error; err != nil {
|
||||
return nil, fmt.Errorf("collection '%s' does not exist: %v", collectionName, err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
log.Printf("Collection '%s' is empty", collectionName)
|
||||
return []VectorResult{}, nil
|
||||
}
|
||||
|
||||
collectionInfo, err := c.getCollectionInfo(ctx, collectionName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get collection info: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("[Debug] Collection info: name=%s, dimension=%d, index_type=%s",
|
||||
collectionName, collectionInfo.Dimension, collectionInfo.IndexType)
|
||||
|
||||
query, params, err := c.buildOptimizedSearchQuery(collectionName, queryVector, topK)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build search query: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("[Debug] Built optimized query: %s", query)
|
||||
log.Printf("[Debug] Query params count: %d", len(params))
|
||||
|
||||
var results []VectorResult
|
||||
rows, err := c.db.WithContext(ctx).Raw(query, params...).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute search query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var result VectorResult
|
||||
var embeddingStr string
|
||||
if err := rows.Scan(
|
||||
&result.VectorID,
|
||||
&result.Content,
|
||||
&result.Metadata,
|
||||
&embeddingStr,
|
||||
&result.SimilarityScore,
|
||||
&result.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan result row: %v", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
log.Printf("[Debug] Raw search results count: %d", len(results))
|
||||
|
||||
finalResults := c.postProcessResults(results, topK, threshold)
|
||||
|
||||
log.Printf("[Debug] Final results count: %d", len(finalResults))
|
||||
return finalResults, nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) buildOptimizedSearchQuery(
|
||||
collectionName string,
|
||||
queryVector []float64,
|
||||
topK int,
|
||||
) (string, []interface{}, error) {
|
||||
|
||||
queryVectorStr := c.vectorToString(queryVector)
|
||||
|
||||
similarityExpr := "GREATEST(0, LEAST(1, 1 - COSINE_DISTANCE(embedding, ?)))"
|
||||
orderBy := "COSINE_DISTANCE(embedding, ?) ASC"
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
vector_id,
|
||||
content,
|
||||
metadata,
|
||||
embedding,
|
||||
%s as similarity_score,
|
||||
created_at
|
||||
FROM %s
|
||||
ORDER BY %s
|
||||
APPROXIMATE
|
||||
LIMIT %d
|
||||
`, similarityExpr, collectionName, orderBy, topK*2)
|
||||
|
||||
params := []interface{}{
|
||||
queryVectorStr,
|
||||
queryVectorStr,
|
||||
}
|
||||
|
||||
return query, params, nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) getCollectionInfo(ctx context.Context, collectionName string) (*CollectionInfo, error) {
|
||||
var dimension int
|
||||
|
||||
dimQuery := `
|
||||
SELECT
|
||||
SUBSTRING_INDEX(SUBSTRING_INDEX(COLUMN_TYPE, '(', -1), ')', 1) as dimension
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME = ? AND COLUMN_NAME = 'embedding'
|
||||
`
|
||||
|
||||
if err := c.db.WithContext(ctx).Raw(dimQuery, collectionName).Scan(&dimension).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get vector dimension: %v", err)
|
||||
}
|
||||
|
||||
var indexType string
|
||||
indexQuery := `
|
||||
SELECT INDEX_TYPE
|
||||
FROM INFORMATION_SCHEMA.STATISTICS
|
||||
WHERE TABLE_NAME = ? AND INDEX_NAME LIKE 'idx_%_embedding'
|
||||
`
|
||||
|
||||
if err := c.db.WithContext(ctx).Raw(indexQuery, collectionName).Scan(&indexType).Error; err != nil {
|
||||
indexType = "none"
|
||||
}
|
||||
|
||||
return &CollectionInfo{
|
||||
Name: collectionName,
|
||||
Dimension: dimension,
|
||||
IndexType: indexType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) vectorToString(vector []float64) string {
|
||||
if len(vector) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
|
||||
parts := make([]string, len(vector))
|
||||
for i, v := range vector {
|
||||
parts[i] = fmt.Sprintf("%.6f", v)
|
||||
}
|
||||
return "[" + strings.Join(parts, ",") + "]"
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) postProcessResults(results []VectorResult, topK int, threshold float64) []VectorResult {
|
||||
if len(results) == 0 {
|
||||
return results
|
||||
}
|
||||
|
||||
filtered := make([]VectorResult, 0, len(results))
|
||||
for _, result := range results {
|
||||
if result.SimilarityScore >= threshold {
|
||||
filtered = append(filtered, result)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(filtered, func(i, j int) bool {
|
||||
return filtered[i].SimilarityScore > filtered[j].SimilarityScore
|
||||
})
|
||||
|
||||
if len(filtered) > topK {
|
||||
filtered = filtered[:topK]
|
||||
}
|
||||
|
||||
log.Printf("[Debug] Post-processed results: %d results with threshold %.3f", len(filtered), threshold)
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) GetDB() *gorm.DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
func (c *OceanBaseOfficialClient) DebugCollectionData(ctx context.Context, collectionName string) error {
|
||||
var count int64
|
||||
if err := c.db.WithContext(ctx).Table(collectionName).Count(&count).Error; err != nil {
|
||||
log.Printf("[Debug] Collection '%s' does not exist: %v", collectionName, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("[Debug] Collection '%s' exists with %d vectors", collectionName, count)
|
||||
|
||||
log.Printf("[Debug] Sample data from collection '%s':", collectionName)
|
||||
rows, err := c.db.WithContext(ctx).Raw(`
|
||||
SELECT vector_id, content, created_at
|
||||
FROM ` + collectionName + `
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 5
|
||||
`).Rows()
|
||||
if err != nil {
|
||||
log.Printf("[Debug] Failed to get sample data: %v", err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var vectorID, content string
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&vectorID, &content, &createdAt); err != nil {
|
||||
log.Printf("[Debug] Failed to scan sample row: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("[Debug] Sample: ID=%s, Content=%s, Created=%s", vectorID, content[:min(50, len(content))], createdAt)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@ -1,112 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package oceanbase
|
||||
|
||||
type VectorIndexConfig struct {
|
||||
Distance string
|
||||
//Index types: hnsw, hnsw_sq, hnsw_bq, ivf_flat, ivf_sq8, ivf_pq
|
||||
Type string
|
||||
//Index library type: vsag, ob
|
||||
Lib string
|
||||
// HNSW Index parameters
|
||||
M *int
|
||||
EfConstruction *int
|
||||
EfSearch *int
|
||||
// IVF Index parameters
|
||||
Nlist *int
|
||||
Nbits *int
|
||||
IVFM *int
|
||||
}
|
||||
|
||||
type VectorData struct {
|
||||
ID int64 `json:"id"`
|
||||
CollectionName string `json:"collection_name"`
|
||||
VectorID string `json:"vector_id"`
|
||||
Content string `json:"content"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type VectorSearchResult struct {
|
||||
ID int64 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Metadata string `json:"metadata"`
|
||||
Distance float64 `json:"distance"`
|
||||
}
|
||||
|
||||
type VectorMemoryEstimate struct {
|
||||
MinMemoryMB int `json:"min_memory_mb"`
|
||||
RecommendedMemoryMB int `json:"recommended_memory_mb"`
|
||||
EstimatedMemoryMB int `json:"estimated_memory_mb"`
|
||||
}
|
||||
|
||||
const (
|
||||
VectorIndexTypeHNSW = "hnsw"
|
||||
VectorIndexTypeHNSWSQ = "hnsw_sq"
|
||||
VectorIndexTypeHNSWBQ = "hnsw_bq"
|
||||
VectorIndexTypeIVF = "ivf_flat"
|
||||
VectorIndexTypeIVFSQ = "ivf_sq8"
|
||||
VectorIndexTypeIVFPQ = "ivf_pq"
|
||||
)
|
||||
|
||||
const (
|
||||
VectorDistanceTypeL2 = "l2"
|
||||
VectorDistanceTypeCosine = "cosine"
|
||||
VectorDistanceTypeInnerProduct = "inner_product"
|
||||
)
|
||||
|
||||
const (
|
||||
VectorLibTypeVSAG = "vsag"
|
||||
VectorLibTypeOB = "ob"
|
||||
)
|
||||
|
||||
func DefaultVectorIndexConfig() *VectorIndexConfig {
|
||||
m := 16
|
||||
efConstruction := 200
|
||||
efSearch := 64
|
||||
|
||||
return &VectorIndexConfig{
|
||||
Distance: VectorDistanceTypeCosine,
|
||||
Type: VectorIndexTypeHNSW,
|
||||
Lib: VectorLibTypeVSAG,
|
||||
M: &m,
|
||||
EfConstruction: &efConstruction,
|
||||
EfSearch: &efSearch,
|
||||
}
|
||||
}
|
||||
|
||||
func HNSWVectorIndexConfig(distance string, m, efConstruction, efSearch int) *VectorIndexConfig {
|
||||
return &VectorIndexConfig{
|
||||
Distance: distance,
|
||||
Type: VectorIndexTypeHNSW,
|
||||
Lib: VectorLibTypeVSAG,
|
||||
M: &m,
|
||||
EfConstruction: &efConstruction,
|
||||
EfSearch: &efSearch,
|
||||
}
|
||||
}
|
||||
|
||||
func IVFVectorIndexConfig(distance string, nlist, nbits, m int) *VectorIndexConfig {
|
||||
return &VectorIndexConfig{
|
||||
Distance: distance,
|
||||
Type: VectorIndexTypeIVF,
|
||||
Lib: VectorLibTypeOB,
|
||||
Nlist: &nlist,
|
||||
Nbits: &nbits,
|
||||
IVFM: &m,
|
||||
}
|
||||
}
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
@ -31,7 +32,6 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/proxy"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
)
|
||||
@ -182,7 +182,8 @@ func (t *s3Client) PutObjectWithReader(ctx context.Context, objectKey string, co
|
||||
}
|
||||
|
||||
if option.Tagging != nil {
|
||||
input.Tagging = aws.String(goutil.MapToQuery(option.Tagging))
|
||||
tagging := mapToQueryParams(option.Tagging)
|
||||
input.Tagging = aws.String(tagging)
|
||||
}
|
||||
|
||||
// upload object
|
||||
@ -360,6 +361,17 @@ func (t *s3Client) ListObjectsPaginated(ctx context.Context, input *storage.List
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func mapToQueryParams(tagging map[string]string) string {
|
||||
if len(tagging) == 0 {
|
||||
return ""
|
||||
}
|
||||
params := url.Values{}
|
||||
for k, v := range tagging {
|
||||
params.Set(k, v)
|
||||
}
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
func tagsToMap(tags []types.Tag) map[string]string {
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
|
||||
@ -30,10 +30,8 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/proxy"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
)
|
||||
|
||||
type tosClient struct {
|
||||
@ -183,7 +181,7 @@ func (t *tosClient) PutObjectWithReader(ctx context.Context, objectKey string, c
|
||||
}
|
||||
|
||||
if len(option.Tagging) > 0 {
|
||||
input.Tagging = goutil.MapToQuery(option.Tagging)
|
||||
input.Meta = option.Tagging
|
||||
}
|
||||
|
||||
_, err := client.PutObjectV2(ctx, input)
|
||||
@ -279,38 +277,26 @@ func (t *tosClient) ListObjectsPaginated(ctx context.Context, input *storage.Lis
|
||||
continue
|
||||
}
|
||||
|
||||
var tagging map[string]string
|
||||
if obj.Meta != nil {
|
||||
obj.Meta.Range(func(key, value string) bool {
|
||||
if tagging == nil {
|
||||
tagging = make(map[string]string)
|
||||
}
|
||||
tagging[key] = value
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
files = append(files, &storage.FileInfo{
|
||||
Key: obj.Key,
|
||||
LastModified: obj.LastModified,
|
||||
ETag: obj.ETag,
|
||||
Size: obj.Size,
|
||||
Tagging: tagging,
|
||||
})
|
||||
}
|
||||
|
||||
if input.WithTagging {
|
||||
client := t.client
|
||||
taskGroup := taskgroup.NewTaskGroup(ctx, 5)
|
||||
for idx := range files {
|
||||
f := files[idx]
|
||||
taskGroup.Go(func() error {
|
||||
tagging, err := client.GetObjectTagging(ctx, &tos.GetObjectTaggingInput{
|
||||
Bucket: t.bucketName,
|
||||
Key: f.Key,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.Tagging = tagsToMap(tagging.TagSet.Tags)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := taskGroup.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &storage.ListObjectsPaginatedOutput{
|
||||
Files: files,
|
||||
Cursor: output.NextMarker,
|
||||
@ -359,16 +345,3 @@ func (t *tosClient) ListAllObjects(ctx context.Context, prefix string, withTaggi
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func tagsToMap(tags []tos.Tag) map[string]string {
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
m := make(map[string]string, len(tags))
|
||||
for _, tag := range tags {
|
||||
m[tag.Key] = tag.Value
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@ -1,31 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package goutil
|
||||
|
||||
import "net/url"
|
||||
|
||||
// MapToQuery converts a map[string]string to a URL-encoded query string.
|
||||
func MapToQuery(data map[string]string) string {
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
params := url.Values{}
|
||||
for k, v := range data {
|
||||
params.Set(k, v)
|
||||
}
|
||||
return params.Encode()
|
||||
}
|
||||
@ -31,7 +31,7 @@ export REDIS_PASSWORD=""
|
||||
|
||||
# This Upload component used in Agent / workflow File/Image With LLM , support the component of imagex / storage
|
||||
# default: storage, use the settings of storage component
|
||||
# if imagex, you must finish the configuration of <VolcEngine ImageX>
|
||||
# if imagex, you must finish the configuration of <VolcEngine ImageX>
|
||||
export FILE_UPLOAD_COMPONENT_TYPE="storage"
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ export VE_IMAGEX_DOMAIN=""
|
||||
export VE_IMAGEX_TEMPLATE=""
|
||||
export VE_IMAGEX_UPLOAD_HOST="https://imagex.volcengineapi.com"
|
||||
|
||||
# Storage component
|
||||
# Storage component
|
||||
export STORAGE_TYPE="minio" # minio / tos / s3
|
||||
export STORAGE_UPLOAD_HTTP_SCHEME="http" # http / https. If coze studio website is https, you must set it to https
|
||||
export STORAGE_BUCKET="opencoze"
|
||||
@ -84,7 +84,7 @@ export RMQ_ACCESS_KEY=""
|
||||
export RMQ_SECRET_KEY=""
|
||||
|
||||
# Settings for VectorStore
|
||||
# VectorStore type: milvus / vikingdb / oceanbase
|
||||
# VectorStore type: milvus / vikingdb
|
||||
# If you want to use vikingdb, you need to set up the vikingdb configuration.
|
||||
export VECTOR_STORE_TYPE="milvus"
|
||||
# milvus vector store
|
||||
@ -97,13 +97,6 @@ export VIKING_DB_SK=""
|
||||
export VIKING_DB_SCHEME=""
|
||||
export VIKING_DB_MODEL_NAME="" # if vikingdb model name is not set, you need to set Embedding settings
|
||||
|
||||
# oceanbase vector store
|
||||
export OCEANBASE_HOST="127.0.0.1"
|
||||
export OCEANBASE_PORT=2881
|
||||
export OCEANBASE_USER="root@test"
|
||||
export OCEANBASE_PASSWORD="coze123"
|
||||
export OCEANBASE_DATABASE="test"
|
||||
|
||||
# Settings for Embedding
|
||||
# The Embedding model relied on by knowledge base vectorization does not need to be configured
|
||||
# if the vector database comes with built-in Embedding functionality (such as VikingDB). Currently,
|
||||
|
||||
@ -27,7 +27,7 @@ export REDIS_PASSWORD=""
|
||||
|
||||
# This Upload component used in Agent / workflow File/Image With LLM , support the component of imagex / storage
|
||||
# default: storage, use the settings of storage component
|
||||
# if imagex, you must finish the configuration of <VolcEngine ImageX>
|
||||
# if imagex, you must finish the configuration of <VolcEngine ImageX>
|
||||
export FILE_UPLOAD_COMPONENT_TYPE="storage"
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ export VE_IMAGEX_DOMAIN=""
|
||||
export VE_IMAGEX_TEMPLATE=""
|
||||
export VE_IMAGEX_UPLOAD_HOST="https://imagex.volcengineapi.com"
|
||||
|
||||
# Storage component
|
||||
# Storage component
|
||||
export STORAGE_TYPE="minio" # minio / tos / s3
|
||||
export STORAGE_UPLOAD_HTTP_SCHEME="http" # http / https. If coze studio website is https, you must set it to https
|
||||
export STORAGE_BUCKET="opencoze"
|
||||
@ -80,7 +80,7 @@ export RMQ_ACCESS_KEY=""
|
||||
export RMQ_SECRET_KEY=""
|
||||
|
||||
# Settings for VectorStore
|
||||
# VectorStore type: milvus / vikingdb / oceanbase
|
||||
# VectorStore type: milvus / vikingdb
|
||||
# If you want to use vikingdb, you need to set up the vikingdb configuration.
|
||||
export VECTOR_STORE_TYPE="milvus"
|
||||
# milvus vector store
|
||||
@ -95,13 +95,6 @@ export VIKING_DB_SK=""
|
||||
export VIKING_DB_SCHEME=""
|
||||
export VIKING_DB_MODEL_NAME="" # if vikingdb model name is not set, you need to set Embedding settings
|
||||
|
||||
# oceanbase vector store
|
||||
export OCEANBASE_HOST="127.0.0.1"
|
||||
export OCEANBASE_PORT=2881
|
||||
export OCEANBASE_USER="root@test"
|
||||
export OCEANBASE_PASSWORD="coze123"
|
||||
export OCEANBASE_DATABASE="test"
|
||||
|
||||
# Settings for Embedding
|
||||
# The Embedding model relied on by knowledge base vectorization does not need to be configured
|
||||
# if the vector database comes with built-in Embedding functionality (such as VikingDB). Currently,
|
||||
|
||||
@ -1,380 +0,0 @@
|
||||
name: coze-studio
|
||||
# Environment file will be specified via --env-file parameter
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: mysql:8.4.5
|
||||
container_name: coze-mysql
|
||||
restart: always
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${MYSQL_ROOT_PASSWORD:-root}
|
||||
MYSQL_DATABASE: ${MYSQL_DATABASE:-opencoze}
|
||||
MYSQL_USER: ${MYSQL_USER:-coze}
|
||||
MYSQL_PASSWORD: ${MYSQL_PASSWORD:-coze123}
|
||||
# ports:
|
||||
# - '3306'
|
||||
volumes:
|
||||
- ./data/mysql:/var/lib/mysql
|
||||
- ./volumes/mysql/schema.sql:/docker-entrypoint-initdb.d/init.sql
|
||||
command:
|
||||
- --character-set-server=utf8mb4
|
||||
- --collation-server=utf8mb4_unicode_ci
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD',
|
||||
'mysqladmin',
|
||||
'ping',
|
||||
'-h',
|
||||
'localhost',
|
||||
'-u$${MYSQL_USER}',
|
||||
'-p$${MYSQL_PASSWORD}',
|
||||
]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 30s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
redis:
|
||||
image: bitnami/redis:8.0
|
||||
container_name: coze-redis
|
||||
restart: always
|
||||
user: root
|
||||
privileged: true
|
||||
environment:
|
||||
- REDIS_AOF_ENABLED=${REDIS_AOF_ENABLED:-no}
|
||||
- REDIS_PORT_NUMBER=${REDIS_PORT_NUMBER:-6379}
|
||||
- REDIS_IO_THREADS=${REDIS_IO_THREADS:-4}
|
||||
- ALLOW_EMPTY_PASSWORD=${ALLOW_EMPTY_PASSWORD:-yes}
|
||||
# ports:
|
||||
# - '6379'
|
||||
volumes:
|
||||
- ./data/bitnami/redis:/bitnami/redis/data:rw,Z
|
||||
command: >
|
||||
bash -c "
|
||||
/opt/bitnami/scripts/redis/setup.sh
|
||||
# Set proper permissions for data directories
|
||||
chown -R redis:redis /bitnami/redis/data
|
||||
chmod g+s /bitnami/redis/data
|
||||
|
||||
exec /opt/bitnami/scripts/redis/entrypoint.sh /opt/bitnami/scripts/redis/run.sh
|
||||
"
|
||||
healthcheck:
|
||||
test: ['CMD', 'redis-cli', 'ping']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
elasticsearch:
|
||||
image: bitnami/elasticsearch:8.18.0
|
||||
container_name: coze-elasticsearch
|
||||
restart: always
|
||||
user: root
|
||||
privileged: true
|
||||
environment:
|
||||
- TEST=1
|
||||
# Add Java certificate trust configuration
|
||||
# - ES_JAVA_OPTS=-Djdk.tls.client.protocols=TLSv1.2 -Dhttps.protocols=TLSv1.2 -Djavax.net.ssl.trustAll=true -Xms4096m -Xmx4096m
|
||||
# ports:
|
||||
# - '9200'
|
||||
volumes:
|
||||
- ./data/bitnami/elasticsearch:/bitnami/elasticsearch/data
|
||||
- ./volumes/elasticsearch/elasticsearch.yml:/opt/bitnami/elasticsearch/config/my_elasticsearch.yml
|
||||
- ./volumes/elasticsearch/analysis-smartcn.zip:/opt/bitnami/elasticsearch/analysis-smartcn.zip:rw,Z
|
||||
- ./volumes/elasticsearch/setup_es.sh:/setup_es.sh
|
||||
- ./volumes/elasticsearch/es_index_schema:/es_index_schema
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD-SHELL',
|
||||
'curl -f http://localhost:9200 && [ -f /tmp/es_plugins_ready ] && [ -f /tmp/es_init_complete ]',
|
||||
]
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
# Install smartcn analyzer plugin and initialize ES
|
||||
command: >
|
||||
bash -c "
|
||||
/opt/bitnami/scripts/elasticsearch/setup.sh
|
||||
# Set proper permissions for data directories
|
||||
chown -R elasticsearch:elasticsearch /bitnami/elasticsearch/data
|
||||
chmod g+s /bitnami/elasticsearch/data
|
||||
|
||||
# Create plugin directory
|
||||
mkdir -p /bitnami/elasticsearch/plugins;
|
||||
|
||||
# Unzip plugin to plugin directory and set correct permissions
|
||||
echo 'Installing smartcn plugin...';
|
||||
if [ ! -d /opt/bitnami/elasticsearch/plugins/analysis-smartcn ]; then
|
||||
|
||||
# Download plugin package locally
|
||||
echo 'Copying smartcn plugin...';
|
||||
cp /opt/bitnami/elasticsearch/analysis-smartcn.zip /tmp/analysis-smartcn.zip
|
||||
|
||||
elasticsearch-plugin install file:///tmp/analysis-smartcn.zip
|
||||
if [[ "$$?" != "0" ]]; then
|
||||
echo 'Plugin installation failed, exiting operation';
|
||||
rm -rf /opt/bitnami/elasticsearch/plugins/analysis-smartcn
|
||||
exit 1;
|
||||
fi;
|
||||
rm -f /tmp/analysis-smartcn.zip;
|
||||
fi;
|
||||
|
||||
# Create marker file indicating plugin installation success
|
||||
touch /tmp/es_plugins_ready;
|
||||
echo 'Plugin installation successful, marker file created';
|
||||
|
||||
# Start initialization script in background
|
||||
(
|
||||
echo 'Waiting for Elasticsearch to be ready...'
|
||||
until curl -s -f http://localhost:9200/_cat/health >/dev/null 2>&1; do
|
||||
echo 'Elasticsearch not ready, waiting...'
|
||||
sleep 2
|
||||
done
|
||||
echo 'Elasticsearch is ready!'
|
||||
|
||||
# Run ES initialization script
|
||||
echo 'Running Elasticsearch initialization...'
|
||||
sed 's/\r$$//' /setup_es.sh > /setup_es_fixed.sh
|
||||
chmod +x /setup_es_fixed.sh
|
||||
/setup_es_fixed.sh --index-dir /es_index_schema
|
||||
# Create marker file indicating initialization completion
|
||||
touch /tmp/es_init_complete
|
||||
echo 'Elasticsearch initialization completed successfully!'
|
||||
) &
|
||||
|
||||
# Start Elasticsearch
|
||||
exec /opt/bitnami/scripts/elasticsearch/entrypoint.sh /opt/bitnami/scripts/elasticsearch/run.sh
|
||||
echo -e "⏳ Adjusting Elasticsearch disk watermark settings..."
|
||||
"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-06-13T11-33-47Z-cpuv1
|
||||
container_name: coze-minio
|
||||
user: root
|
||||
privileged: true
|
||||
restart: always
|
||||
# ports:
|
||||
# - '9000'
|
||||
# - '9001'
|
||||
volumes:
|
||||
- ./data/minio:/data
|
||||
- ./volumes/minio/default_icon/:/default_icon
|
||||
- ./volumes/minio/official_plugin_icon/:/official_plugin_icon
|
||||
environment:
|
||||
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-minioadmin}
|
||||
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-minioadmin123}
|
||||
MINIO_DEFAULT_BUCKETS: ${STORAGE_BUCKET:-opencoze},${MINIO_DEFAULT_BUCKETS:-oceanbase}
|
||||
entrypoint:
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
# Run initialization in background
|
||||
(
|
||||
# Wait for MinIO to be ready
|
||||
until (/usr/bin/mc alias set localminio http://localhost:9000 $${MINIO_ROOT_USER} $${MINIO_ROOT_PASSWORD}) do
|
||||
echo "Waiting for MinIO to be ready..."
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Create bucket and copy files
|
||||
/usr/bin/mc mb --ignore-existing localminio/$${STORAGE_BUCKET}
|
||||
/usr/bin/mc cp --recursive /default_icon/ localminio/$${STORAGE_BUCKET}/default_icon/
|
||||
/usr/bin/mc cp --recursive /official_plugin_icon/ localminio/$${STORAGE_BUCKET}/official_plugin_icon/
|
||||
|
||||
echo "MinIO initialization complete."
|
||||
) &
|
||||
|
||||
# Start minio server in foreground
|
||||
exec minio server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD-SHELL',
|
||||
'/usr/bin/mc alias set health_check http://localhost:9000 ${MINIO_ROOT_USER} ${MINIO_ROOT_PASSWORD} && /usr/bin/mc ready health_check',
|
||||
]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
etcd:
|
||||
image: bitnami/etcd:3.5
|
||||
container_name: coze-etcd
|
||||
user: root
|
||||
restart: always
|
||||
privileged: true
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||
- ALLOW_NONE_AUTHENTICATION=yes
|
||||
# ports:
|
||||
# - '2379'
|
||||
# - '2380'
|
||||
volumes:
|
||||
- ./data/bitnami/etcd:/bitnami/etcd:rw,Z
|
||||
- ./volumes/etcd/etcd.conf.yml:/opt/bitnami/etcd/conf/etcd.conf.yml:ro,Z
|
||||
command: >
|
||||
bash -c "
|
||||
/opt/bitnami/scripts/etcd/setup.sh
|
||||
# Set proper permissions for data and config directories
|
||||
chown -R etcd:etcd /bitnami/etcd
|
||||
chmod g+s /bitnami/etcd
|
||||
|
||||
exec /opt/bitnami/scripts/etcd/entrypoint.sh /opt/bitnami/scripts/etcd/run.sh
|
||||
"
|
||||
healthcheck:
|
||||
test: ['CMD', 'etcdctl', 'endpoint', 'health']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
# OceanBase for vector storage
|
||||
oceanbase:
|
||||
image: oceanbase/oceanbase-ce:latest
|
||||
container_name: coze-oceanbase
|
||||
restart: always
|
||||
environment:
|
||||
MODE: SLIM
|
||||
OB_DATAFILE_SIZE: 1G
|
||||
OB_SYS_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_TENANT_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
ports:
|
||||
- '2881:2881'
|
||||
volumes:
|
||||
- ./data/oceanbase/ob:/root/ob
|
||||
- ./data/oceanbase/cluster:/root/.obd/cluster
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 4G
|
||||
reservations:
|
||||
memory: 2G
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD-SHELL',
|
||||
'obclient -h127.0.0.1 -P2881 -uroot@test -pcoze123 -e "SELECT 1;"',
|
||||
]
|
||||
interval: 10s
|
||||
retries: 30
|
||||
start_period: 30s
|
||||
timeout: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
nsqlookupd:
|
||||
image: nsqio/nsq:v1.2.1
|
||||
container_name: coze-nsqlookupd
|
||||
command: /nsqlookupd
|
||||
restart: always
|
||||
# ports:
|
||||
# - '4160'
|
||||
# - '4161'
|
||||
networks:
|
||||
- coze-network
|
||||
healthcheck:
|
||||
test: ['CMD-SHELL', 'nsqlookupd --version']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
|
||||
nsqd:
|
||||
image: nsqio/nsq:v1.2.1
|
||||
container_name: coze-nsqd
|
||||
command: /nsqd --lookupd-tcp-address=nsqlookupd:4160 --broadcast-address=nsqd
|
||||
restart: always
|
||||
# ports:
|
||||
# - '4150'
|
||||
# - '4151'
|
||||
depends_on:
|
||||
nsqlookupd:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- coze-network
|
||||
healthcheck:
|
||||
test: ['CMD-SHELL', '/nsqd --version']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
|
||||
nsqadmin:
|
||||
image: nsqio/nsq:v1.2.1
|
||||
container_name: coze-nsqadmin
|
||||
command: /nsqadmin --lookupd-http-address=nsqlookupd:4161
|
||||
restart: always
|
||||
# ports:
|
||||
# - '4171'
|
||||
depends_on:
|
||||
nsqlookupd:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
coze-server:
|
||||
# build:
|
||||
# context: ../
|
||||
# dockerfile: backend/Dockerfile
|
||||
image: cozedev/coze-studio-server:latest
|
||||
restart: always
|
||||
container_name: coze-server
|
||||
# environment:
|
||||
# LISTEN_ADDR: 0.0.0.0:8888
|
||||
networks:
|
||||
- coze-network
|
||||
# ports:
|
||||
# - '8888'
|
||||
# - '8889'
|
||||
volumes:
|
||||
- .env:/app/.env
|
||||
- ../backend/conf:/app/resources/conf
|
||||
# - ../backend/static:/app/resources/static
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
elasticsearch:
|
||||
condition: service_healthy
|
||||
minio:
|
||||
condition: service_healthy
|
||||
oceanbase:
|
||||
condition: service_healthy
|
||||
command: ['/app/opencoze']
|
||||
|
||||
coze-web:
|
||||
# build:
|
||||
# context: ..
|
||||
# dockerfile: frontend/Dockerfile
|
||||
image: cozedev/coze-studio-web:latest
|
||||
container_name: coze-web
|
||||
restart: always
|
||||
ports:
|
||||
- "${WEB_LISTEN_ADDR:-8888}:80"
|
||||
# - "443:443" # SSL port (uncomment if using SSL)
|
||||
volumes:
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro # Main nginx config
|
||||
- ./nginx/conf.d/default.conf:/etc/nginx/conf.d/default.conf:ro # Proxy config
|
||||
# - ./nginx/ssl:/etc/nginx/ssl:ro # SSL certificates (uncomment if using SSL)
|
||||
depends_on:
|
||||
- coze-server
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
networks:
|
||||
coze-network:
|
||||
driver: bridge
|
||||
@ -1,529 +0,0 @@
|
||||
name: coze-studio-debug
|
||||
|
||||
x-env-file: &env_file
|
||||
- .env.debug
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: mysql:8.4.5
|
||||
container_name: coze-mysql
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${MYSQL_ROOT_PASSWORD:-root}
|
||||
MYSQL_DATABASE: ${MYSQL_DATABASE:-opencoze}
|
||||
MYSQL_USER: ${MYSQL_USER:-coze}
|
||||
MYSQL_PASSWORD: ${MYSQL_PASSWORD:-coze123}
|
||||
profiles: ['middleware', 'mysql-setup', 'mysql']
|
||||
env_file: *env_file
|
||||
ports:
|
||||
- '3306:3306'
|
||||
volumes:
|
||||
- ./data/mysql:/var/lib/mysql
|
||||
command:
|
||||
- --character-set-server=utf8mb4
|
||||
- --collation-server=utf8mb4_unicode_ci
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD',
|
||||
'mysqladmin',
|
||||
'ping',
|
||||
'-h',
|
||||
'localhost',
|
||||
'-u$${MYSQL_USER}',
|
||||
'-p$${MYSQL_PASSWORD}',
|
||||
]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 30s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
redis:
|
||||
image: bitnami/redis:8.0
|
||||
container_name: coze-redis
|
||||
user: root
|
||||
privileged: true
|
||||
profiles: ['middleware']
|
||||
env_file: *env_file
|
||||
environment:
|
||||
- REDIS_AOF_ENABLED=${REDIS_AOF_ENABLED:-no}
|
||||
- REDIS_PORT_NUMBER=${REDIS_PORT_NUMBER:-6379}
|
||||
- REDIS_IO_THREADS=${REDIS_IO_THREADS:-4}
|
||||
- ALLOW_EMPTY_PASSWORD=${ALLOW_EMPTY_PASSWORD:-yes}
|
||||
ports:
|
||||
- '6379:6379'
|
||||
volumes:
|
||||
- ./data/bitnami/redis:/bitnami/redis/data:rw,Z
|
||||
command: >
|
||||
bash -c "
|
||||
/opt/bitnami/scripts/redis/setup.sh
|
||||
# Set proper permissions for data directories
|
||||
chown -R redis:redis /bitnami/redis/data
|
||||
chmod g+s /bitnami/redis/data
|
||||
|
||||
exec /opt/bitnami/scripts/redis/entrypoint.sh /opt/bitnami/scripts/redis/run.sh
|
||||
"
|
||||
depends_on:
|
||||
minio-setup:
|
||||
condition: service_completed_successfully
|
||||
mysql-setup-schema:
|
||||
condition: service_completed_successfully
|
||||
mysql-setup-init-sql:
|
||||
condition: service_completed_successfully
|
||||
healthcheck:
|
||||
test: ['CMD', 'redis-cli', 'ping']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
# rocketmq-namesrv:
|
||||
# image: apache/rocketmq:5.3.2
|
||||
# container_name: coze-rocketmq-namesrv
|
||||
# privileged: true
|
||||
# user: root
|
||||
# profiles: ['middleware']
|
||||
# env_file: *env_file
|
||||
# ports:
|
||||
# - '9876:9876'
|
||||
# volumes:
|
||||
# - ./data/rocketmq/namesrv/logs:/home/rocketmq/logs:rw,Z
|
||||
# - ./data/rocketmq/namesrv/store:/home/rocketmq/store:rw,Z
|
||||
# environment:
|
||||
# - ALLOW_ANONYMOUS_LOGIN=yes
|
||||
# command: >
|
||||
# bash -c "
|
||||
# # Set proper permissions for data directories
|
||||
# mkdir -p /home/rocketmq/logs /home/rocketmq/store
|
||||
# mkdir -p /home/rocketmq/logs/rocketmqlogs
|
||||
# touch /home/rocketmq/logs/rocketmqlogs/tools.log
|
||||
# touch /home/rocketmq/logs/rocketmqlogs/tools_default.log
|
||||
|
||||
# chown -R rocketmq:rocketmq /home/rocketmq/logs /home/rocketmq/store
|
||||
# chmod g+s /home/rocketmq/logs /home/rocketmq/store
|
||||
|
||||
# echo 'Starting RocketMQ NameServer...'
|
||||
# sh mqnamesrv
|
||||
# "
|
||||
# healthcheck:
|
||||
# test: ['CMD', 'sh', 'mqadmin', 'clusterList', '-n', 'localhost:9876']
|
||||
# interval: 5s
|
||||
# timeout: 10s
|
||||
# retries: 10
|
||||
# start_period: 10s
|
||||
# networks:
|
||||
# - coze-network
|
||||
# rocketmq-broker:
|
||||
# image: apache/rocketmq:5.3.2
|
||||
# container_name: coze-rocketmq-broker
|
||||
# privileged: true
|
||||
# user: root
|
||||
# profiles: ['middleware']
|
||||
# env_file: *env_file
|
||||
# ports:
|
||||
# - '10909:10909'
|
||||
# - '10911:10911'
|
||||
# - '10912:10912'
|
||||
# volumes:
|
||||
# - ./data/rocketmq/broker/logs:/home/rocketmq/logs:rw,Z
|
||||
# - ./data/rocketmq/broker/store:/home/rocketmq/store:rw,Z
|
||||
# - ./volumes/rocketmq/broker.conf:/home/rocketmq/conf/broker.conf:rw,Z
|
||||
# networks:
|
||||
# - coze-network
|
||||
# command: >
|
||||
# bash -c '
|
||||
# # Set proper permissions
|
||||
# mkdir -p /home/rocketmq/logs/rocketmqlogs /home/rocketmq/store
|
||||
# touch /home/rocketmq/logs/rocketmqlogs/tools.log \
|
||||
# /home/rocketmq/logs/rocketmqlogs/tools_default.log
|
||||
# chown -R rocketmq:rocketmq /home/rocketmq/logs /home/rocketmq/store
|
||||
# chmod g+s /home/rocketmq/logs /home/rocketmq/store
|
||||
|
||||
# echo "Starting RocketMQ Broker..."
|
||||
# sh mqbroker -n rocketmq-namesrv:9876 -c /home/rocketmq/conf/broker.conf &
|
||||
|
||||
# echo "Waiting for Broker registration..."
|
||||
# broker_ready=false
|
||||
# for i in {1..60}; do
|
||||
# if sh mqadmin clusterList -n rocketmq-namesrv:9876 \
|
||||
# | grep -q "DefaultCluster.*broker-a"; then
|
||||
# echo "Registered."
|
||||
# broker_ready=true
|
||||
# break
|
||||
# fi
|
||||
# echo "Not ready, retry $$i/60..."
|
||||
# sleep 1
|
||||
# done
|
||||
|
||||
# if [ "$$broker_ready" = false ]; then
|
||||
# echo "ERROR: registration timed out."
|
||||
# exit 1
|
||||
# fi
|
||||
|
||||
# echo "Creating topics..."
|
||||
# for t in opencoze_knowledge opencoze_search_app opencoze_search_resource \
|
||||
# %RETRY%cg_knowledge %RETRY%cg_search_app %RETRY%cg_search_resource; do
|
||||
# sh mqadmin updateTopic -n rocketmq-namesrv:9876 \
|
||||
# -c DefaultCluster -t "$$t"
|
||||
# done
|
||||
|
||||
# touch /tmp/rocketmq_ready
|
||||
# echo "Broker started successfully."
|
||||
# wait
|
||||
# '
|
||||
# depends_on:
|
||||
# - rocketmq-namesrv
|
||||
# healthcheck:
|
||||
# test: ['CMD-SHELL', '[ -f /tmp/rocketmq_ready ]']
|
||||
# interval: 10s
|
||||
# timeout: 10s
|
||||
# retries: 10
|
||||
# start_period: 10s
|
||||
|
||||
elasticsearch:
|
||||
image: bitnami/elasticsearch:8.18.0
|
||||
container_name: coze-elasticsearch
|
||||
user: root
|
||||
privileged: true
|
||||
profiles: ['middleware']
|
||||
env_file: *env_file
|
||||
environment:
|
||||
- TEST=1
|
||||
# Add Java certificate trust configuration
|
||||
# - ES_JAVA_OPTS=-Djdk.tls.client.protocols=TLSv1.2 -Dhttps.protocols=TLSv1.2 -Djavax.net.ssl.trustAll=true -Xms4096m -Xmx4096m
|
||||
ports:
|
||||
- '9200:9200'
|
||||
volumes:
|
||||
- ./data/bitnami/elasticsearch:/bitnami/elasticsearch/data
|
||||
- ./volumes/elasticsearch/elasticsearch.yml:/opt/bitnami/elasticsearch/config/my_elasticsearch.yml
|
||||
- ./volumes/elasticsearch/analysis-smartcn.zip:/opt/bitnami/elasticsearch/analysis-smartcn.zip:rw,Z
|
||||
- ./volumes/elasticsearch/setup_es.sh:/setup_es.sh
|
||||
- ./volumes/elasticsearch/es_index_schema:/es_index_schemas
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD-SHELL',
|
||||
'curl -f http://localhost:9200 && [ -f /tmp/es_plugins_ready ] && [ -f /tmp/es_init_complete ]',
|
||||
]
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
# Install smartcn analyzer plugin and initialize ES
|
||||
command: >
|
||||
bash -c "
|
||||
/opt/bitnami/scripts/elasticsearch/setup.sh
|
||||
# Set proper permissions for data directories
|
||||
chown -R elasticsearch:elasticsearch /bitnami/elasticsearch/data
|
||||
chmod g+s /bitnami/elasticsearch/data
|
||||
|
||||
# Create plugin directory
|
||||
mkdir -p /bitnami/elasticsearch/plugins;
|
||||
|
||||
# Unzip plugin to plugin directory and set correct permissions
|
||||
echo 'Installing smartcn plugin...';
|
||||
if [ ! -d /opt/bitnami/elasticsearch/plugins/analysis-smartcn ]; then
|
||||
|
||||
# Download plugin package locally
|
||||
echo 'Copying smartcn plugin...';
|
||||
cp /opt/bitnami/elasticsearch/analysis-smartcn.zip /tmp/analysis-smartcn.zip
|
||||
|
||||
elasticsearch-plugin install file:///tmp/analysis-smartcn.zip
|
||||
if [[ "$$?" != "0" ]]; then
|
||||
echo 'Plugin installation failed, exiting operation';
|
||||
rm -rf /opt/bitnami/elasticsearch/plugins/analysis-smartcn
|
||||
exit 1;
|
||||
fi;
|
||||
rm -f /tmp/analysis-smartcn.zip;
|
||||
fi;
|
||||
|
||||
# Create marker file indicating plugin installation success
|
||||
touch /tmp/es_plugins_ready;
|
||||
echo 'Plugin installation successful, marker file created';
|
||||
|
||||
# Start initialization script in background
|
||||
(
|
||||
echo 'Waiting for Elasticsearch to be ready...'
|
||||
until curl -s -f http://localhost:9200/_cat/health >/dev/null 2>&1; do
|
||||
echo 'Elasticsearch not ready, waiting...'
|
||||
sleep 2
|
||||
done
|
||||
echo 'Elasticsearch is ready!'
|
||||
|
||||
# Run ES initialization script
|
||||
echo 'Running Elasticsearch initialization...'
|
||||
sed 's/\r$$//' /setup_es.sh > /setup_es_fixed.sh
|
||||
chmod +x /setup_es_fixed.sh
|
||||
/setup_es_fixed.sh --index-dir /es_index_schemas
|
||||
# Create marker file indicating initialization completion
|
||||
touch /tmp/es_init_complete
|
||||
echo 'Elasticsearch initialization completed successfully!'
|
||||
) &
|
||||
|
||||
# Start Elasticsearch
|
||||
exec /opt/bitnami/scripts/elasticsearch/entrypoint.sh /opt/bitnami/scripts/elasticsearch/run.sh
|
||||
echo -e "⏳ Adjusting Elasticsearch disk watermark settings..."
|
||||
"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-06-13T11-33-47Z-cpuv1
|
||||
container_name: coze-minio
|
||||
user: root
|
||||
privileged: true
|
||||
profiles: ['middleware']
|
||||
env_file: *env_file
|
||||
ports:
|
||||
- '9000:9000'
|
||||
- '9001:9001'
|
||||
volumes:
|
||||
- ./data/minio:/data
|
||||
environment:
|
||||
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-minioadmin}
|
||||
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-minioadmin123}
|
||||
MINIO_DEFAULT_BUCKETS: ${MINIO_BUCKET:-opencoze},${MINIO_DEFAULT_BUCKETS:-oceanbase}
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD-SHELL',
|
||||
'/usr/bin/mc alias set health_check http://localhost:9000 ${MINIO_ROOT_USER} ${MINIO_ROOT_PASSWORD} && /usr/bin/mc ready health_check',
|
||||
]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
etcd:
|
||||
image: bitnami/etcd:3.5
|
||||
container_name: coze-etcd
|
||||
user: root
|
||||
privileged: true
|
||||
profiles: ['middleware']
|
||||
env_file: *env_file
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||
- ALLOW_NONE_AUTHENTICATION=yes
|
||||
ports:
|
||||
- 2379:2379
|
||||
- 2380:2380
|
||||
volumes:
|
||||
- ./data/bitnami/etcd:/bitnami/etcd:rw,Z
|
||||
- ./volumes/etcd/etcd.conf.yml:/opt/bitnami/etcd/conf/etcd.conf.yml:ro,Z
|
||||
command: >
|
||||
bash -c "
|
||||
/opt/bitnami/scripts/etcd/setup.sh
|
||||
# Set proper permissions for data and config directories
|
||||
chown -R etcd:etcd /bitnami/etcd
|
||||
chmod g+s /bitnami/etcd
|
||||
|
||||
exec /opt/bitnami/scripts/etcd/entrypoint.sh /opt/bitnami/scripts/etcd/run.sh
|
||||
"
|
||||
healthcheck:
|
||||
test: ['CMD', 'etcdctl', 'endpoint', 'health']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
# OceanBase for vector storage
|
||||
oceanbase:
|
||||
image: oceanbase/oceanbase-ce:latest
|
||||
container_name: coze-oceanbase
|
||||
environment:
|
||||
MODE: SLIM
|
||||
OB_DATAFILE_SIZE: 1G
|
||||
OB_SYS_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_TENANT_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
profiles: ['middleware']
|
||||
env_file: *env_file
|
||||
ports:
|
||||
- '2881:2881'
|
||||
volumes:
|
||||
- ./data/oceanbase/ob:/root/ob
|
||||
- ./data/oceanbase/cluster:/root/.obd/cluster
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 4G
|
||||
reservations:
|
||||
memory: 2G
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
'CMD-SHELL',
|
||||
'obclient -h127.0.0.1 -P2881 -uroot@test -pcoze123 -e "SELECT 1;"',
|
||||
]
|
||||
interval: 10s
|
||||
retries: 30
|
||||
start_period: 30s
|
||||
timeout: 10s
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
nsqlookupd:
|
||||
image: nsqio/nsq:v1.2.1
|
||||
container_name: coze-nsqlookupd
|
||||
command: /nsqlookupd
|
||||
profiles: ['middleware']
|
||||
ports:
|
||||
- '4160:4160'
|
||||
- '4161:4161'
|
||||
networks:
|
||||
- coze-network
|
||||
healthcheck:
|
||||
test: ['CMD-SHELL', 'nsqlookupd --version']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
|
||||
nsqd:
|
||||
image: nsqio/nsq:v1.2.1
|
||||
container_name: coze-nsqd
|
||||
command: /nsqd --lookupd-tcp-address=coze-nsqlookupd:4160 --broadcast-address=coze-nsqd
|
||||
profiles: ['middleware']
|
||||
ports:
|
||||
- '4150:4150'
|
||||
- '4151:4151'
|
||||
depends_on:
|
||||
nsqlookupd:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- coze-network
|
||||
healthcheck:
|
||||
test: ['CMD-SHELL', '/nsqd --version']
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
|
||||
nsqadmin:
|
||||
image: nsqio/nsq:v1.2.1
|
||||
container_name: coze-nsqadmin
|
||||
command: /nsqadmin --lookupd-http-address=coze-nsqlookupd:4161
|
||||
profiles: ['middleware']
|
||||
ports:
|
||||
- '4171:4171'
|
||||
depends_on:
|
||||
nsqlookupd:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- coze-network
|
||||
|
||||
minio-setup:
|
||||
image: minio/mc:RELEASE.2025-05-21T01-59-54Z-cpuv1
|
||||
container_name: coze-minio-setup
|
||||
profiles: ['middleware']
|
||||
env_file: *env_file
|
||||
depends_on:
|
||||
minio:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./volumes/minio/default_icon/:/default_icon
|
||||
- ./volumes/minio/official_plugin_icon/:/official_plugin_icon
|
||||
entrypoint: >
|
||||
/bin/sh -c "
|
||||
(/usr/bin/mc alias set localminio http://coze-minio:9000 ${MINIO_ROOT_USER} ${MINIO_ROOT_PASSWORD} && \
|
||||
/usr/bin/mc mb --ignore-existing localminio/${STORAGE_BUCKET} && \
|
||||
/usr/bin/mc cp --recursive /default_icon/ localminio/${STORAGE_BUCKET}/default_icon/ && \
|
||||
/usr/bin/mc cp --recursive /official_plugin_icon/ localminio/${STORAGE_BUCKET}/official_plugin_icon/ && \
|
||||
echo 'upload files to minio complete: Files uploaded to ${STORAGE_BUCKET} bucket.') || exit 1; \
|
||||
"
|
||||
networks:
|
||||
- coze-network
|
||||
restart: 'no'
|
||||
|
||||
mysql-setup-schema:
|
||||
image: arigaio/atlas:0.35.0-community-alpine
|
||||
container_name: coze-mysql-setup-schema
|
||||
profiles: ['middleware', 'mysql-setup', 'run-server']
|
||||
env_file: *env_file
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./atlas/opencoze_latest_schema.hcl:/opencoze_latest_schema.hcl
|
||||
entrypoint:
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
set -ex
|
||||
TMP_ATLAS_URL="${ATLAS_URL}"
|
||||
if [ "${MYSQL_HOST}" = "localhost" ] || [ "${MYSQL_HOST}" = "127.0.0.1" ]; then
|
||||
echo "MYSQL_HOST is localhost or 127.0.0.1, replacing with docker network address"
|
||||
TMP_ATLAS_URL="mysql://${MYSQL_USER}:${MYSQL_PASSWORD}@mysql:${MYSQL_PORT}/${MYSQL_DATABASE}?charset=utf8mb4&parseTime=True"
|
||||
fi
|
||||
|
||||
echo "final atlas url: $${TMP_ATLAS_URL}"
|
||||
for i in `seq 1 60`; do
|
||||
if atlas schema apply \
|
||||
-u "$${TMP_ATLAS_URL}" \
|
||||
--to file:///opencoze_latest_schema.hcl \
|
||||
--exclude "atlas_schema_revisions,table_*" \
|
||||
--auto-approve; then
|
||||
echo "MySQL setup complete."
|
||||
exit 0
|
||||
fi
|
||||
echo "atlas schema apply failed, retrying...($$i/60)"
|
||||
sleep 1
|
||||
done
|
||||
echo "MySQL setup failed after 60 retries."
|
||||
exit 1
|
||||
networks:
|
||||
- coze-network
|
||||
mysql-setup-init-sql:
|
||||
image: mysql:8.4.5
|
||||
container_name: coze-mysql-setup-init-sql
|
||||
profiles: ['middleware', 'mysql-setup', 'run-server', 'volcano-setup']
|
||||
env_file: *env_file
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
command:
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
set -ex
|
||||
for i in $$(seq 1 60); do
|
||||
DB_HOST="$${MYSQL_HOST}"
|
||||
if [ "$${MYSQL_HOST}" = "localhost" ] || [ "$${MYSQL_HOST}" = "127.0.0.1" ]; then
|
||||
DB_HOST="mysql"
|
||||
fi
|
||||
if mysql -h "$${DB_HOST}" -P"$${MYSQL_PORT}" -u"$${MYSQL_USER}" -p"$${MYSQL_PASSWORD}" "$${MYSQL_DATABASE}" < /schema.sql && \
|
||||
mysql -h "$${DB_HOST}" -P"$${MYSQL_PORT}" -u"$${MYSQL_USER}" -p"$${MYSQL_PASSWORD}" "$${MYSQL_DATABASE}" < /sql_init.sql; then
|
||||
echo 'MySQL init success.'
|
||||
exit 0
|
||||
fi
|
||||
echo "Retrying to connect to mysql... ($$i/60)"
|
||||
sleep 1
|
||||
done
|
||||
echo 'Failed to init mysql db.'
|
||||
exit 1
|
||||
volumes:
|
||||
- ./volumes/mysql/sql_init.sql:/sql_init.sql
|
||||
- ./volumes/mysql/schema.sql:/schema.sql
|
||||
networks:
|
||||
- coze-network
|
||||
restart: 'no'
|
||||
|
||||
coze-server:
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: backend/Dockerfile
|
||||
image: opencoze/opencoze:latest
|
||||
profiles: ['build-server']
|
||||
|
||||
networks:
|
||||
coze-network:
|
||||
driver: bridge
|
||||
@ -1,362 +0,0 @@
|
||||
# OceanBase Vector Database Integration Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This document provides a comprehensive guide to the integration of OceanBase vector database in Coze Studio, including architectural design, implementation details, configuration instructions, and usage guidelines.
|
||||
|
||||
## Integration Background
|
||||
|
||||
### Why Choose OceanBase?
|
||||
|
||||
1. **Transaction Support**: OceanBase provides complete ACID transaction support, ensuring data consistency
|
||||
2. **Simple Deployment**: Compared to specialized vector databases like Milvus, OceanBase deployment is simpler
|
||||
3. **MySQL Compatibility**: Compatible with MySQL protocol, low learning curve
|
||||
4. **Vector Extensions**: Native support for vector data types and indexing
|
||||
5. **Operations Friendly**: Low operational costs, suitable for small to medium-scale applications
|
||||
|
||||
### Comparison with Milvus
|
||||
|
||||
| Feature | OceanBase | Milvus |
|
||||
| ------------------------------- | -------------------- | --------------------------- |
|
||||
| **Deployment Complexity** | Low (Single Machine) | High (Requires etcd, MinIO) |
|
||||
| **Transaction Support** | Full ACID | Limited |
|
||||
| **Vector Search Speed** | Medium | Faster |
|
||||
| **Storage Efficiency** | Medium | Higher |
|
||||
| **Operational Cost** | Low | High |
|
||||
| **Learning Curve** | Gentle | Steep |
|
||||
|
||||
## Architectural Design
|
||||
|
||||
### Overall Architecture
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Coze Studio │ │ OceanBase │ │ Vector Store │
|
||||
│ Application │───▶│ Client │───▶│ Manager │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ OceanBase │
|
||||
│ Database │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
### Core Components
|
||||
|
||||
#### 1. OceanBase Client (`backend/infra/impl/oceanbase/`)
|
||||
|
||||
**Main Files**:
|
||||
|
||||
- `oceanbase.go` - Delegation client, providing backward-compatible interface
|
||||
- `oceanbase_official.go` - Core implementation, based on official documentation
|
||||
- `types.go` - Type definitions
|
||||
|
||||
**Core Functions**:
|
||||
|
||||
```go
|
||||
type OceanBaseClient interface {
|
||||
CreateCollection(ctx context.Context, collectionName string) error
|
||||
InsertVectors(ctx context.Context, collectionName string, vectors []VectorResult) error
|
||||
SearchVectors(ctx context.Context, collectionName string, queryVector []float64, topK int) ([]VectorResult, error)
|
||||
DeleteVector(ctx context.Context, collectionName string, vectorID string) error
|
||||
InitDatabase(ctx context.Context) error
|
||||
DropCollection(ctx context.Context, collectionName string) error
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. Search Store Manager (`backend/infra/impl/document/searchstore/oceanbase/`)
|
||||
|
||||
**Main Files**:
|
||||
|
||||
- `oceanbase_manager.go` - Manager implementation
|
||||
- `oceanbase_searchstore.go` - Search store implementation
|
||||
- `factory.go` - Factory pattern creation
|
||||
- `consts.go` - Constant definitions
|
||||
- `convert.go` - Data conversion
|
||||
- `register.go` - Registration functions
|
||||
|
||||
**Core Functions**:
|
||||
|
||||
```go
|
||||
type Manager interface {
|
||||
Create(ctx context.Context, collectionName string) (SearchStore, error)
|
||||
Get(ctx context.Context, collectionName string) (SearchStore, error)
|
||||
Delete(ctx context.Context, collectionName string) error
|
||||
}
|
||||
```
|
||||
|
||||
#### 3. Application Layer Integration (`backend/application/base/appinfra/`)
|
||||
|
||||
**File**: `app_infra.go`
|
||||
|
||||
**Integration Point**:
|
||||
|
||||
```go
|
||||
case "oceanbase":
|
||||
// Build DSN
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
user, password, host, port, database)
|
||||
|
||||
// Create client
|
||||
client, err := oceanbaseClient.NewOceanBaseClient(dsn)
|
||||
|
||||
// Initialize database
|
||||
if err := client.InitDatabase(ctx); err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase database failed, err=%w", err)
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Instructions
|
||||
|
||||
### Environment Variable Configuration
|
||||
|
||||
#### Required Configuration
|
||||
|
||||
```bash
|
||||
# Vector store type
|
||||
VECTOR_STORE_TYPE=oceanbase
|
||||
|
||||
# OceanBase connection configuration
|
||||
OCEANBASE_HOST=localhost
|
||||
OCEANBASE_PORT=2881
|
||||
OCEANBASE_USER=root
|
||||
OCEANBASE_PASSWORD=coze123
|
||||
OCEANBASE_DATABASE=test
|
||||
```
|
||||
|
||||
#### Optional Configuration
|
||||
|
||||
```bash
|
||||
# Performance optimization configuration
|
||||
OCEANBASE_VECTOR_MEMORY_LIMIT_PERCENTAGE=30
|
||||
OCEANBASE_BATCH_SIZE=100
|
||||
OCEANBASE_MAX_OPEN_CONNS=100
|
||||
OCEANBASE_MAX_IDLE_CONNS=10
|
||||
|
||||
# Cache configuration
|
||||
OCEANBASE_ENABLE_CACHE=true
|
||||
OCEANBASE_CACHE_TTL=300
|
||||
|
||||
# Monitoring configuration
|
||||
OCEANBASE_ENABLE_METRICS=true
|
||||
OCEANBASE_ENABLE_SLOW_QUERY_LOG=true
|
||||
|
||||
# Retry configuration
|
||||
OCEANBASE_MAX_RETRIES=3
|
||||
OCEANBASE_RETRY_DELAY=1
|
||||
OCEANBASE_CONN_TIMEOUT=30
|
||||
```
|
||||
|
||||
### Docker Configuration
|
||||
|
||||
#### docker-compose-oceanbase.yml
|
||||
|
||||
```yaml
|
||||
oceanbase:
|
||||
image: oceanbase/oceanbase-ce:latest
|
||||
container_name: coze-oceanbase
|
||||
environment:
|
||||
MODE: SLIM
|
||||
OB_DATAFILE_SIZE: 1G
|
||||
OB_SYS_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_TENANT_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
ports:
|
||||
- '2881:2881'
|
||||
volumes:
|
||||
- ./data/oceanbase/ob:/root/ob
|
||||
- ./data/oceanbase/cluster:/root/.obd/cluster
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 4G
|
||||
reservations:
|
||||
memory: 2G
|
||||
```
|
||||
|
||||
## Usage Guide
|
||||
|
||||
### 1. Quick Start
|
||||
|
||||
```bash
|
||||
# Clone the project
|
||||
git clone https://github.com/coze-dev/coze-studio.git
|
||||
cd coze-studio
|
||||
|
||||
# Setup OceanBase environment
|
||||
make oceanbase_env
|
||||
|
||||
# Start OceanBase debug environment
|
||||
make oceanbase_debug
|
||||
```
|
||||
|
||||
### 2. Verify Deployment
|
||||
|
||||
```bash
|
||||
# Check container status
|
||||
docker ps | grep oceanbase
|
||||
|
||||
# Test connection
|
||||
mysql -h localhost -P 2881 -u root -p -e "SELECT 1;"
|
||||
|
||||
# View databases
|
||||
mysql -h localhost -P 2881 -u root -p -e "SHOW DATABASES;"
|
||||
```
|
||||
|
||||
### 3. Create Knowledge Base
|
||||
|
||||
In the Coze Studio interface:
|
||||
|
||||
1. Enter knowledge base management
|
||||
2. Select OceanBase as vector storage
|
||||
3. Upload documents for vectorization
|
||||
4. Test vector retrieval functionality
|
||||
|
||||
### 4. Performance Monitoring
|
||||
|
||||
```bash
|
||||
# View container resource usage
|
||||
docker stats coze-oceanbase
|
||||
|
||||
# View slow query logs
|
||||
docker logs coze-oceanbase | grep "slow query"
|
||||
|
||||
# View connection count
|
||||
mysql -h localhost -P 2881 -u root -p -e "SHOW PROCESSLIST;"
|
||||
```
|
||||
|
||||
## Integration Features
|
||||
|
||||
### 1. Design Principles
|
||||
|
||||
#### Architecture Compatibility Design
|
||||
|
||||
- Strictly follow Coze Studio core architectural design principles, ensuring seamless integration of OceanBase adaptation layer with existing systems
|
||||
- Adopt delegation pattern (Delegation Pattern) to achieve backward compatibility, ensuring stability and consistency of existing interfaces
|
||||
- Maintain complete compatibility with existing vector storage interfaces, ensuring smooth system migration and upgrade
|
||||
|
||||
#### Performance First
|
||||
|
||||
- Use HNSW index to achieve efficient approximate nearest neighbor search
|
||||
- Batch operations reduce database interaction frequency
|
||||
- Connection pool management optimizes resource usage
|
||||
|
||||
#### Easy Deployment
|
||||
|
||||
- Single machine deployment, no complex cluster configuration required
|
||||
- Docker one-click deployment
|
||||
- Environment variable configuration, flexible and easy to use
|
||||
|
||||
### 2. Technical Highlights
|
||||
|
||||
#### Delegation Pattern Design
|
||||
|
||||
```go
|
||||
type OceanBaseClient struct {
|
||||
official *OceanBaseOfficialClient
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) CreateCollection(ctx context.Context, collectionName string) error {
|
||||
return c.official.CreateCollection(ctx, collectionName)
|
||||
}
|
||||
```
|
||||
|
||||
#### Intelligent Configuration Management
|
||||
|
||||
```go
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Host: getEnv("OCEANBASE_HOST", "localhost"),
|
||||
Port: getEnvAsInt("OCEANBASE_PORT", 2881),
|
||||
User: getEnv("OCEANBASE_USER", "root"),
|
||||
Password: getEnv("OCEANBASE_PASSWORD", ""),
|
||||
Database: getEnv("OCEANBASE_DATABASE", "test"),
|
||||
// ... other configurations
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Error Handling Optimization
|
||||
|
||||
```go
|
||||
func (c *OceanBaseOfficialClient) setVectorParameters() error {
|
||||
params := map[string]string{
|
||||
"ob_vector_memory_limit_percentage": "30",
|
||||
"ob_query_timeout": "86400000000",
|
||||
"max_allowed_packet": "1073741824",
|
||||
}
|
||||
|
||||
for param, value := range params {
|
||||
if err := c.db.Exec(fmt.Sprintf("SET GLOBAL %s = %s", param, value)).Error; err != nil {
|
||||
log.Printf("Warning: Failed to set %s: %v", param, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### 1. Common Issues
|
||||
|
||||
#### Connection Issues
|
||||
|
||||
```bash
|
||||
# Check container status
|
||||
docker ps | grep oceanbase
|
||||
|
||||
# Check port mapping
|
||||
docker port coze-oceanbase
|
||||
|
||||
# Test connection
|
||||
mysql -h localhost -P 2881 -u root -p -e "SELECT 1;"
|
||||
```
|
||||
|
||||
#### Vector Index Issues
|
||||
|
||||
```sql
|
||||
-- Check index status
|
||||
SHOW INDEX FROM test_vectors;
|
||||
|
||||
-- Rebuild index
|
||||
DROP INDEX idx_test_embedding ON test_vectors;
|
||||
CREATE VECTOR INDEX idx_test_embedding ON test_vectors(embedding)
|
||||
WITH (distance=cosine, type=hnsw, lib=vsag, m=16, ef_construction=200, ef_search=64);
|
||||
```
|
||||
|
||||
#### Performance Issues
|
||||
|
||||
```sql
|
||||
-- Adjust memory limit
|
||||
SET GLOBAL ob_vector_memory_limit_percentage = 50;
|
||||
|
||||
-- View slow queries
|
||||
SHOW VARIABLES LIKE 'slow_query_log';
|
||||
```
|
||||
|
||||
### 2. Log Analysis
|
||||
|
||||
```bash
|
||||
# View OceanBase logs
|
||||
docker logs coze-oceanbase
|
||||
|
||||
# View application logs
|
||||
tail -f logs/coze-studio.log | grep -i "oceanbase\|vector"
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
||||
The integration of OceanBase vector database in Coze Studio has achieved the following goals:
|
||||
|
||||
1. **Complete Functionality**: Supports complete vector storage and retrieval functionality
|
||||
2. **Good Performance**: Achieves efficient vector search through HNSW indexing
|
||||
3. **Simple Deployment**: Single machine deployment, no complex configuration required
|
||||
4. **Operations Friendly**: Low operational costs, easy monitoring and management
|
||||
5. **Strong Scalability**: Supports horizontal and vertical scaling
|
||||
|
||||
Through this integration, Coze Studio provides users with a simple, efficient, and reliable vector database solution, particularly suitable for scenarios requiring transaction support, simple deployment, and low operational costs.
|
||||
|
||||
## Related Links
|
||||
|
||||
- [OceanBase Official Documentation](https://www.oceanbase.com/docs)
|
||||
- [Coze Studio Project Repository](https://github.com/coze-dev/coze-studio)
|
||||
@ -1,364 +0,0 @@
|
||||
# OceanBase 向量数据库集成指南
|
||||
|
||||
## 概述
|
||||
|
||||
本文档详细介绍了 OceanBase 向量数据库在 Coze Studio 中的集成适配情况,包括架构设计、实现细节、配置说明和使用指南。
|
||||
|
||||
## 集成背景
|
||||
|
||||
### 为什么选择 OceanBase?
|
||||
|
||||
1. **事务支持**: OceanBase 提供完整的 ACID 事务支持,确保数据一致性
|
||||
2. **部署简单**: 相比 Milvus 等专用向量数据库,OceanBase 部署更简单
|
||||
3. **MySQL 兼容**: 兼容 MySQL 协议,学习成本低
|
||||
4. **向量扩展**: 原生支持向量数据类型和索引
|
||||
5. **运维友好**: 运维成本低,适合中小规模应用
|
||||
|
||||
### 与 Milvus 的对比
|
||||
|
||||
| 特性 | OceanBase | Milvus |
|
||||
| ---------------------- | -------------- | ---------------------- |
|
||||
| **部署复杂度** | 低(单机部署) | 高(需要 etcd、MinIO) |
|
||||
| **事务支持** | 完整 ACID | 有限 |
|
||||
| **向量检索速度** | 中等 | 更快 |
|
||||
| **存储效率** | 中等 | 更高 |
|
||||
| **运维成本** | 低 | 高 |
|
||||
| **学习曲线** | 平缓 | 陡峭 |
|
||||
|
||||
## 架构设计
|
||||
|
||||
### 整体架构
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Coze Studio │ │ OceanBase │ │ Vector Store │
|
||||
│ Application │───▶│ Client │───▶│ Manager │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ OceanBase │
|
||||
│ Database │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
### 核心组件
|
||||
|
||||
#### 1. OceanBase Client (`backend/infra/impl/oceanbase/`)
|
||||
|
||||
**主要文件**:
|
||||
|
||||
- `oceanbase.go` - 委托客户端,提供向后兼容接口
|
||||
- `oceanbase_official.go` - 核心实现,基于官方文档
|
||||
- `types.go` - 类型定义
|
||||
|
||||
**核心功能**:
|
||||
|
||||
```go
|
||||
type OceanBaseClient interface {
|
||||
CreateCollection(ctx context.Context, collectionName string) error
|
||||
InsertVectors(ctx context.Context, collectionName string, vectors []VectorResult) error
|
||||
SearchVectors(ctx context.Context, collectionName string, queryVector []float64, topK int) ([]VectorResult, error)
|
||||
DeleteVector(ctx context.Context, collectionName string, vectorID string) error
|
||||
InitDatabase(ctx context.Context) error
|
||||
DropCollection(ctx context.Context, collectionName string) error
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. Search Store Manager (`backend/infra/impl/document/searchstore/oceanbase/`)
|
||||
|
||||
**主要文件**:
|
||||
|
||||
- `oceanbase_manager.go` - 管理器实现
|
||||
- `oceanbase_searchstore.go` - 搜索存储实现
|
||||
- `factory.go` - 工厂模式创建
|
||||
- `consts.go` - 常量定义
|
||||
- `convert.go` - 数据转换
|
||||
- `register.go` - 注册函数
|
||||
|
||||
**核心功能**:
|
||||
|
||||
```go
|
||||
type Manager interface {
|
||||
Create(ctx context.Context, collectionName string) (SearchStore, error)
|
||||
Get(ctx context.Context, collectionName string) (SearchStore, error)
|
||||
Delete(ctx context.Context, collectionName string) error
|
||||
}
|
||||
```
|
||||
|
||||
#### 3. 应用层集成 (`backend/application/base/appinfra/`)
|
||||
|
||||
**文件**: `app_infra.go`
|
||||
|
||||
**集成点**:
|
||||
|
||||
```go
|
||||
case "oceanbase":
|
||||
// 构建 DSN
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
user, password, host, port, database)
|
||||
|
||||
// 创建客户端
|
||||
client, err := oceanbaseClient.NewOceanBaseClient(dsn)
|
||||
|
||||
// 初始化数据库
|
||||
if err := client.InitDatabase(ctx); err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase database failed, err=%w", err)
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 环境变量配置
|
||||
|
||||
#### 必需配置
|
||||
|
||||
```bash
|
||||
# 向量存储类型
|
||||
VECTOR_STORE_TYPE=oceanbase
|
||||
|
||||
# OceanBase 连接配置
|
||||
OCEANBASE_HOST=localhost
|
||||
OCEANBASE_PORT=2881
|
||||
OCEANBASE_USER=root
|
||||
OCEANBASE_PASSWORD=coze123
|
||||
OCEANBASE_DATABASE=test
|
||||
```
|
||||
|
||||
#### 可选配置
|
||||
|
||||
```bash
|
||||
# 性能优化配置
|
||||
OCEANBASE_VECTOR_MEMORY_LIMIT_PERCENTAGE=30
|
||||
OCEANBASE_BATCH_SIZE=100
|
||||
OCEANBASE_MAX_OPEN_CONNS=100
|
||||
OCEANBASE_MAX_IDLE_CONNS=10
|
||||
|
||||
# 缓存配置
|
||||
OCEANBASE_ENABLE_CACHE=true
|
||||
OCEANBASE_CACHE_TTL=300
|
||||
|
||||
# 监控配置
|
||||
OCEANBASE_ENABLE_METRICS=true
|
||||
OCEANBASE_ENABLE_SLOW_QUERY_LOG=true
|
||||
|
||||
# 重试配置
|
||||
OCEANBASE_MAX_RETRIES=3
|
||||
OCEANBASE_RETRY_DELAY=1
|
||||
OCEANBASE_CONN_TIMEOUT=30
|
||||
```
|
||||
|
||||
### Docker 配置
|
||||
|
||||
#### docker-compose-oceanbase.yml
|
||||
|
||||
```yaml
|
||||
oceanbase:
|
||||
image: oceanbase/oceanbase-ce:latest
|
||||
container_name: coze-oceanbase
|
||||
environment:
|
||||
MODE: SLIM
|
||||
OB_DATAFILE_SIZE: 1G
|
||||
OB_SYS_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
OB_TENANT_PASSWORD: ${OCEANBASE_PASSWORD:-coze123}
|
||||
ports:
|
||||
- '2881:2881'
|
||||
volumes:
|
||||
- ./data/oceanbase/ob:/root/ob
|
||||
- ./data/oceanbase/cluster:/root/.obd/cluster
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 4G
|
||||
reservations:
|
||||
memory: 2G
|
||||
```
|
||||
|
||||
## 使用指南
|
||||
|
||||
### 1. 快速启动
|
||||
|
||||
```bash
|
||||
# 克隆项目
|
||||
git clone https://github.com/coze-dev/coze-studio.git
|
||||
cd coze-studio
|
||||
|
||||
# 设置 OceanBase 环境文件
|
||||
make oceanbase_env
|
||||
|
||||
# 启动 OceanBase 调试环境
|
||||
make oceanbase_debug
|
||||
```
|
||||
|
||||
### 2. 验证部署
|
||||
|
||||
```bash
|
||||
# 检查容器状态
|
||||
docker ps | grep oceanbase
|
||||
|
||||
# 测试连接
|
||||
mysql -h localhost -P 2881 -u root -p -e "SELECT 1;"
|
||||
|
||||
# 查看数据库
|
||||
mysql -h localhost -P 2881 -u root -p -e "SHOW DATABASES;"
|
||||
```
|
||||
|
||||
### 3. 创建知识库
|
||||
|
||||
在 Coze Studio 界面中:
|
||||
|
||||
1. 进入知识库管理
|
||||
2. 选择 OceanBase 作为向量存储
|
||||
3. 上传文档进行向量化
|
||||
4. 测试向量检索功能
|
||||
|
||||
### 4. 性能监控
|
||||
|
||||
```bash
|
||||
# 查看容器资源使用
|
||||
docker stats coze-oceanbase
|
||||
|
||||
# 查看慢查询日志
|
||||
docker logs coze-oceanbase | grep "slow query"
|
||||
|
||||
# 查看连接数
|
||||
mysql -h localhost -P 2881 -u root -p -e "SHOW PROCESSLIST;"
|
||||
```
|
||||
|
||||
## 适配特点
|
||||
|
||||
### 1. 设计原则
|
||||
|
||||
#### 架构兼容性设计
|
||||
|
||||
- 严格遵循 Coze Studio 核心架构设计原则,确保 OceanBase 适配层与现有系统无缝集成
|
||||
- 采用委托模式(Delegation Pattern)实现向后兼容,保证现有接口的稳定性和一致性
|
||||
- 保持与现有向量存储接口的完全兼容,确保系统平滑迁移和升级
|
||||
|
||||
#### 性能优先
|
||||
|
||||
- 使用 HNSW 索引实现高效的近似最近邻搜索
|
||||
- 批量操作减少数据库交互次数
|
||||
- 连接池管理优化资源使用
|
||||
|
||||
#### 易于部署
|
||||
|
||||
- 单机部署,无需复杂的集群配置
|
||||
- Docker 一键部署
|
||||
- 环境变量配置,灵活易用
|
||||
|
||||
### 2. 技术亮点
|
||||
|
||||
#### 委托模式设计
|
||||
|
||||
```go
|
||||
type OceanBaseClient struct {
|
||||
official *OceanBaseOfficialClient
|
||||
}
|
||||
|
||||
func (c *OceanBaseClient) CreateCollection(ctx context.Context, collectionName string) error {
|
||||
return c.official.CreateCollection(ctx, collectionName)
|
||||
}
|
||||
```
|
||||
|
||||
#### 智能配置管理
|
||||
|
||||
```go
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Host: getEnv("OCEANBASE_HOST", "localhost"),
|
||||
Port: getEnvAsInt("OCEANBASE_PORT", 2881),
|
||||
User: getEnv("OCEANBASE_USER", "root"),
|
||||
Password: getEnv("OCEANBASE_PASSWORD", ""),
|
||||
Database: getEnv("OCEANBASE_DATABASE", "test"),
|
||||
// ... 其他配置
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 错误处理优化
|
||||
|
||||
```go
|
||||
func (c *OceanBaseOfficialClient) setVectorParameters() error {
|
||||
params := map[string]string{
|
||||
"ob_vector_memory_limit_percentage": "30",
|
||||
"ob_query_timeout": "86400000000",
|
||||
"max_allowed_packet": "1073741824",
|
||||
}
|
||||
|
||||
for param, value := range params {
|
||||
if err := c.db.Exec(fmt.Sprintf("SET GLOBAL %s = %s", param, value)).Error; err != nil {
|
||||
log.Printf("Warning: Failed to set %s: %v", param, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 1. 常见问题
|
||||
|
||||
#### 连接问题
|
||||
|
||||
```bash
|
||||
# 检查容器状态
|
||||
docker ps | grep oceanbase
|
||||
|
||||
# 检查端口映射
|
||||
docker port coze-oceanbase
|
||||
|
||||
# 测试连接
|
||||
mysql -h localhost -P 2881 -u root -p -e "SELECT 1;"
|
||||
```
|
||||
|
||||
#### 向量索引问题
|
||||
|
||||
```sql
|
||||
-- 检查索引状态
|
||||
SHOW INDEX FROM test_vectors;
|
||||
|
||||
-- 重建索引
|
||||
DROP INDEX idx_test_embedding ON test_vectors;
|
||||
CREATE VECTOR INDEX idx_test_embedding ON test_vectors(embedding)
|
||||
WITH (distance=cosine, type=hnsw, lib=vsag, m=16, ef_construction=200, ef_search=64);
|
||||
```
|
||||
|
||||
#### 性能问题
|
||||
|
||||
```sql
|
||||
-- 调整内存限制
|
||||
SET GLOBAL ob_vector_memory_limit_percentage = 50;
|
||||
|
||||
-- 查看慢查询
|
||||
SHOW VARIABLES LIKE 'slow_query_log';
|
||||
```
|
||||
|
||||
### 2. 日志分析
|
||||
|
||||
```bash
|
||||
# 查看 OceanBase 日志
|
||||
docker logs coze-oceanbase
|
||||
|
||||
# 查看应用日志
|
||||
tail -f logs/coze-studio.log | grep -i "oceanbase\|vector"
|
||||
```
|
||||
|
||||
## 总结
|
||||
|
||||
OceanBase 向量数据库在 Coze Studio 中的集成实现了以下目标:
|
||||
|
||||
1. **功能完整**: 支持完整的向量存储和检索功能
|
||||
2. **性能良好**: 通过 HNSW 索引实现高效的向量搜索
|
||||
3. **部署简单**: 单机部署,无需复杂配置
|
||||
4. **运维友好**: 低运维成本,易于监控和管理
|
||||
5. **扩展性强**: 支持水平扩展和垂直扩展
|
||||
|
||||
通过这次集成,Coze Studio 为用户提供了一个简单、高效、可靠的向量数据库解决方案,特别适合需要事务支持、部署简单、运维成本低的场景。
|
||||
|
||||
## 相关链接
|
||||
|
||||
- [OceanBase 官方文档](https://www.oceanbase.com/docs)
|
||||
- [Coze Studio 项目地址](https://github.com/coze-dev/coze-studio)
|
||||
@ -1,88 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Copyright 2025 coze-dev Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
# OceanBase Environment Configuration Script
|
||||
# Dynamically modify vector store type in environment files
|
||||
|
||||
set -e
|
||||
|
||||
# Colors
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
# Script directory
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
||||
DOCKER_DIR="$PROJECT_ROOT/../docker"
|
||||
|
||||
# Environment type
|
||||
ENV_TYPE="${1:-debug}"
|
||||
|
||||
# Validate environment type
|
||||
if [[ "$ENV_TYPE" != "debug" && "$ENV_TYPE" != "env" ]]; then
|
||||
echo -e "${RED}Error: Invalid environment type '$ENV_TYPE'${NC}"
|
||||
echo "Usage: $0 [debug|env]"
|
||||
echo " debug - Test environment (.env.debug)"
|
||||
echo " env - Production environment (.env)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Determine target environment file
|
||||
if [[ "$ENV_TYPE" == "debug" ]]; then
|
||||
TARGET_ENV_FILE="$DOCKER_DIR/.env.debug"
|
||||
else
|
||||
TARGET_ENV_FILE="$DOCKER_DIR/.env"
|
||||
fi
|
||||
|
||||
# Check if target environment file exists
|
||||
if [[ ! -f "$TARGET_ENV_FILE" ]]; then
|
||||
if [[ "$ENV_TYPE" == "debug" ]]; then
|
||||
cp "$DOCKER_DIR/.env.debug.example" "$TARGET_ENV_FILE"
|
||||
else
|
||||
cp "$DOCKER_DIR/.env.example" "$TARGET_ENV_FILE"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check if already configured for OceanBase
|
||||
if grep -q "VECTOR_STORE_TYPE.*oceanbase" "$TARGET_ENV_FILE"; then
|
||||
echo -e "${YELLOW}Already configured for OceanBase${NC}"
|
||||
else
|
||||
echo -e "${GREEN}Configuring OceanBase...${NC}"
|
||||
|
||||
# Use sed to replace VECTOR_STORE_TYPE
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
# macOS requires special handling - use temporary file to avoid .bak creation
|
||||
sed "s/export VECTOR_STORE_TYPE=\"milvus\"/export VECTOR_STORE_TYPE=\"oceanbase\"/g" "$TARGET_ENV_FILE" > "$TARGET_ENV_FILE.tmp"
|
||||
sed "s/export VECTOR_STORE_TYPE=\"vikingdb\"/export VECTOR_STORE_TYPE=\"oceanbase\"/g" "$TARGET_ENV_FILE.tmp" > "$TARGET_ENV_FILE"
|
||||
rm -f "$TARGET_ENV_FILE.tmp"
|
||||
else
|
||||
# Linux systems
|
||||
sed -i "s/export VECTOR_STORE_TYPE=\"milvus\"/export VECTOR_STORE_TYPE=\"oceanbase\"/g" "$TARGET_ENV_FILE"
|
||||
sed -i "s/export VECTOR_STORE_TYPE=\"vikingdb\"/export VECTOR_STORE_TYPE=\"oceanbase\"/g" "$TARGET_ENV_FILE"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Verify configuration
|
||||
if grep -q "VECTOR_STORE_TYPE.*oceanbase" "$TARGET_ENV_FILE"; then
|
||||
echo -e "${GREEN}✅ OceanBase configured successfully${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ Failed to configure OceanBase${NC}"
|
||||
exit 1
|
||||
fi
|
||||
18
scripts/setup/server.sh
Normal file → Executable file
18
scripts/setup/server.sh
Normal file → Executable file
@ -1,20 +1,4 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Copyright 2025 coze-dev Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
BASE_DIR="$(cd "$SCRIPT_DIR/../../" && pwd)"
|
||||
@ -28,8 +12,6 @@ ENV_FILE="$DOCKER_DIR/.env"
|
||||
|
||||
if [[ "$APP_ENV" == "debug" ]]; then
|
||||
ENV_FILE="$DOCKER_DIR/.env.debug"
|
||||
elif [[ "$APP_ENV" == "oceanbase" ]]; then
|
||||
ENV_FILE="$DOCKER_DIR/.env"
|
||||
fi
|
||||
|
||||
source "$ENV_FILE"
|
||||
|
||||
Reference in New Issue
Block a user