From 4789889b17e791730eeb963a996e7a95c605efe2 Mon Sep 17 00:00:00 2001 From: Ryo Date: Mon, 29 Sep 2025 16:11:49 +0800 Subject: [PATCH] refactor: optimize app infra component initialization (#2294) --- .../api/handler/coze/workflow_service_test.go | 11 +- backend/application/application.go | 7 + .../application/base/appinfra/app_infra.go | 673 +----------------- backend/application/base/appinfra/modelmgr.go | 2 + backend/application/workflow/init.go | 7 +- .../buildinmodel}/builtin_chat_model.go | 12 +- backend/bizpkg/fileutil/fileutil.go | 51 ++ .../{pkg/goutil => bizpkg/fileutil}/pyutil.go | 2 +- backend/crossdomain/impl/code/code.go | 31 - .../internal/agentflow/node_tool_database.go | 5 +- .../domain/knowledge/service/event_handle.go | 10 +- backend/domain/knowledge/service/knowledge.go | 7 +- backend/domain/knowledge/service/retrieve.go | 9 +- .../domain/knowledge/service/retrieve_test.go | 4 + .../memory/database/service/database_impl.go | 15 +- .../database/service/database_impl_test.go | 4 + .../internal/canvas/adaptor/canvas_test.go | 3 +- .../workflow/internal/nodes/code/code.go | 3 +- backend/infra/cache/impl/redis/redis.go | 2 + backend/infra/coderunner/code.go | 10 + .../infra/coderunner/impl/direct/runner.go | 4 +- backend/infra/coderunner/impl/impl.go | 66 ++ .../infra/coderunner/impl/sandbox/runner.go | 6 +- .../impl/builtin/messages_to_query.go | 2 +- .../impl/builtin/messages_to_query_test.go | 0 .../document/messages2query/impl/impl.go | 48 ++ .../messages2query/messages_to_query.go | 0 .../{ => document}/messages2query/options.go | 0 backend/infra/document/nl2sql/impl/impl.go | 48 ++ backend/infra/document/ocr/impl/impl.go | 55 ++ .../document/parser/impl/builtin/manager.go | 6 +- backend/infra/document/parser/impl/impl.go | 59 ++ .../parser/impl/ppstructure/manager.go | 4 +- .../progressbar/impl/progressbar/impl.go | 7 +- .../infra/document/progressbar/interface.go | 13 +- backend/infra/document/rerank/impl/impl.go | 48 ++ .../infra/document/searchstore/impl/impl.go | 433 +++++++++++ backend/infra/eventbus/eventbus.go | 2 +- backend/infra/eventbus/impl/eventbus.go | 32 + .../infra/imagex/impl/veimagex/veimagex.go | 13 + backend/infra/rdb/impl/rdb/mysql.go | 7 +- backend/infra/rdb/impl/rdb/mysql_test.go | 4 + backend/infra/sqlparser/sql_parser.go | 2 + backend/infra/storage/impl/minio/minio.go | 4 +- backend/infra/storage/impl/s3/s3.go | 2 +- backend/pkg/jsoncache/jsoncache.go | 3 +- 46 files changed, 982 insertions(+), 754 deletions(-) rename backend/{application/base/appinfra => bizpkg/buildinmodel}/builtin_chat_model.go (90%) create mode 100644 backend/bizpkg/fileutil/fileutil.go rename backend/{pkg/goutil => bizpkg/fileutil}/pyutil.go (98%) delete mode 100644 backend/crossdomain/impl/code/code.go create mode 100644 backend/infra/coderunner/impl/impl.go rename backend/infra/{ => document}/messages2query/impl/builtin/messages_to_query.go (96%) rename backend/infra/{ => document}/messages2query/impl/builtin/messages_to_query_test.go (100%) create mode 100644 backend/infra/document/messages2query/impl/impl.go rename backend/infra/{ => document}/messages2query/messages_to_query.go (100%) rename backend/infra/{ => document}/messages2query/options.go (100%) create mode 100644 backend/infra/document/nl2sql/impl/impl.go create mode 100644 backend/infra/document/ocr/impl/impl.go create mode 100644 backend/infra/document/parser/impl/impl.go create mode 100644 backend/infra/document/rerank/impl/impl.go create mode 100644 backend/infra/document/searchstore/impl/impl.go diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index 8ad5eff99..2b6fddaa5 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -83,7 +83,6 @@ import ( pluginmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/pluginmock" crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user" - "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code" pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin" agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity" conventity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity" @@ -3837,7 +3836,7 @@ func TestNodeDebugLoop(t *testing.T) { }, nil }).AnyTimes() - code.SetCodeRunner(runner) + coderunner.SetCodeRunner(runner) id := r.load("loop_with_object_input.json") exeID := r.nodeDebug(id, "122149", withNDInput(map[string]string{"input": `[{"a":"1"},{"a":"2"}]`})) @@ -4154,7 +4153,7 @@ func TestCodeExceptionBranch(t *testing.T) { id := r.load("exception/code_exception_branch.json") mockey.PatchConvey("exception branch", func() { - code.SetCodeRunner(direct.NewRunner()) + coderunner.SetCodeRunner(direct.NewRunner()) exeID := r.testRun(id, map[string]string{"input": "hello"}) e := r.getProcess(id, exeID) @@ -4167,7 +4166,7 @@ func TestCodeExceptionBranch(t *testing.T) { mockey.PatchConvey("normal branch", func() { mockCodeRunner := mockcode.NewMockRunner(r.ctrl) - mockey.Mock(code.GetCodeRunner).Return(mockCodeRunner).Build() + mockey.Mock(coderunner.GetCodeRunner).Return(mockCodeRunner).Build() mockCodeRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(&coderunner.RunResponse{ Result: map[string]any{ "key0": "value0", @@ -4891,7 +4890,7 @@ func TestHttpImplicitDependencies(t *testing.T) { }, nil }).AnyTimes() - code.SetCodeRunner(runner) + coderunner.SetCodeRunner(runner) mockey.PatchConvey("test http node implicit dependencies", func() { input := map[string]string{ @@ -6059,7 +6058,7 @@ func TestWorkflowRunWithFiles(t *testing.T) { }, nil }).AnyTimes() - mockey.Mock(code.GetCodeRunner).Return(runner).Build() + mockey.Mock(coderunner.GetCodeRunner).Return(runner).Build() idStr := r.load("workflow_wf_file_name.json") r.publish(idStr, "v0.1.1", true) diff --git a/backend/application/application.go b/backend/application/application.go index d91db0da8..09e84d611 100644 --- a/backend/application/application.go +++ b/backend/application/application.go @@ -70,8 +70,12 @@ import ( workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow" "github.com/coze-dev/coze-studio/backend/infra/chatmodel/impl/chatmodel" "github.com/coze-dev/coze-studio/backend/infra/checkpoint" + "github.com/coze-dev/coze-studio/backend/infra/document/progressbar" + progressBarImpl "github.com/coze-dev/coze-studio/backend/infra/document/progressbar/impl/progressbar" "github.com/coze-dev/coze-studio/backend/infra/eventbus" implEventbus "github.com/coze-dev/coze-studio/backend/infra/eventbus/impl" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" + sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" ) type eventbusImpl struct { @@ -116,6 +120,9 @@ func Init(ctx context.Context) (err error) { return err } + progressbar.New = progressBarImpl.NewProgressBar + sqlparser.New = sqlparserImpl.NewSQLParser + eventbus := initEventBus(infra) basicServices, err := initBasicServices(ctx, infra, eventbus) diff --git a/backend/application/base/appinfra/app_infra.go b/backend/application/base/appinfra/app_infra.go index 6c67146b3..2ea271697 100644 --- a/backend/application/base/appinfra/app_infra.go +++ b/backend/application/base/appinfra/app_infra.go @@ -18,65 +18,30 @@ package appinfra import ( "context" - "encoding/json" "fmt" - "net/http" "os" - "path/filepath" - "strconv" - "strings" - "time" - "google.golang.org/genai" "gorm.io/gorm" - "github.com/cloudwego/eino-ext/components/embedding/gemini" - "github.com/cloudwego/eino-ext/components/embedding/ollama" - "github.com/cloudwego/eino-ext/components/embedding/openai" - "github.com/cloudwego/eino/components/prompt" - "github.com/cloudwego/eino/schema" - "github.com/milvus-io/milvus/client/v2/milvusclient" - "github.com/volcengine/volc-sdk-golang/service/visual" - + "github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel" "github.com/coze-dev/coze-studio/backend/infra/cache" "github.com/coze-dev/coze-studio/backend/infra/cache/impl/redis" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" - "github.com/coze-dev/coze-studio/backend/infra/coderunner" - "github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/direct" - "github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/sandbox" - "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql" - builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql/impl/builtin" - "github.com/coze-dev/coze-studio/backend/infra/document/ocr" - "github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/ppocr" - "github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/veocr" - "github.com/coze-dev/coze-studio/backend/infra/document/parser" - "github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/builtin" - "github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/ppstructure" - "github.com/coze-dev/coze-studio/backend/infra/document/rerank" - "github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/rrf" - vikingReranker "github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/vikingdb" - "github.com/coze-dev/coze-studio/backend/infra/document/searchstore" - "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/elasticsearch" - "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/milvus" - "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/oceanbase" - "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/vikingdb" - "github.com/coze-dev/coze-studio/backend/infra/embedding" - "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark" - embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http" - "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap" + coderunner "github.com/coze-dev/coze-studio/backend/infra/coderunner/impl" + messages2query "github.com/coze-dev/coze-studio/backend/infra/document/messages2query/impl" + nl2sql "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql/impl" + ocr "github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl" + parser "github.com/coze-dev/coze-studio/backend/infra/document/parser/impl" + rerank "github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl" + searchstore "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl" "github.com/coze-dev/coze-studio/backend/infra/es/impl/es" eventbus "github.com/coze-dev/coze-studio/backend/infra/eventbus/impl" "github.com/coze-dev/coze-studio/backend/infra/idgen/impl/idgen" "github.com/coze-dev/coze-studio/backend/infra/imagex" "github.com/coze-dev/coze-studio/backend/infra/imagex/impl/veimagex" - "github.com/coze-dev/coze-studio/backend/infra/messages2query" - builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/messages2query/impl/builtin" "github.com/coze-dev/coze-studio/backend/infra/modelmgr" - oceanbaseClient "github.com/coze-dev/coze-studio/backend/infra/oceanbase" "github.com/coze-dev/coze-studio/backend/infra/orm/impl/mysql" storage "github.com/coze-dev/coze-studio/backend/infra/storage/impl" - "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" - "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/consts" ) @@ -132,29 +97,29 @@ func Init(ctx context.Context) (*AppDependencies, error) { return nil, fmt.Errorf("init imagex client failed, err=%w", err) } - deps.ResourceEventProducer, err = initResourceEventBusProducer() + deps.ResourceEventProducer, err = eventbus.InitResourceEventBusProducer() if err != nil { return nil, fmt.Errorf("init resource event bus producer failed, err=%w", err) } - deps.AppEventProducer, err = initAppEventProducer() + deps.AppEventProducer, err = eventbus.InitAppEventProducer() if err != nil { return nil, fmt.Errorf("init app event producer failed, err=%w", err) } - deps.KnowledgeEventProducer, err = initKnowledgeEventBusProducer() + deps.KnowledgeEventProducer, err = eventbus.InitKnowledgeEventBusProducer() if err != nil { return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err) } - deps.Reranker = initReranker() + deps.Reranker = rerank.New() - deps.Rewriter, err = initRewriter(ctx) + deps.Rewriter, err = messages2query.New(ctx) if err != nil { return nil, fmt.Errorf("init rewriter failed, err=%w", err) } - deps.NL2SQL, err = initNL2SQL(ctx) + deps.NL2SQL, err = nl2sql.New(ctx) if err != nil { return nil, fmt.Errorf("init nl2sql failed, err=%w", err) } @@ -164,17 +129,12 @@ func Init(ctx context.Context) (*AppDependencies, error) { return nil, fmt.Errorf("init model manager failed, err=%w", err) } - deps.CodeRunner = initCodeRunner() + deps.CodeRunner = coderunner.New() - deps.OCR = initOCR() - - imageAnnotationModel, _, err := getBuiltinChatModel(ctx, "IA_") - if err != nil { - return nil, fmt.Errorf("get builtin chat model failed, err=%w", err) - } + deps.OCR = ocr.New() var ok bool - deps.WorkflowBuildInChatModel, ok, err = getBuiltinChatModel(ctx, "WKR_") + deps.WorkflowBuildInChatModel, ok, err = buildinmodel.GetBuiltinChatModel(ctx, "WKR_") if err != nil { return nil, fmt.Errorf("get workflow builtin chat model failed, err=%w", err) } @@ -183,12 +143,12 @@ func Init(ctx context.Context) (*AppDependencies, error) { logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured") } - deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel) + deps.ParserManager, err = parser.New(ctx, deps.TOSClient, deps.OCR) if err != nil { return nil, fmt.Errorf("init parser manager failed, err=%w", err) } - deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient) + deps.SearchStoreManagers, err = searchstore.New(ctx, deps.ESClient) if err != nil { return nil, fmt.Errorf("init search store managers failed, err=%w", err) } @@ -196,602 +156,11 @@ func Init(ctx context.Context) (*AppDependencies, error) { return deps, nil } -func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.Manager, error) { - // es full text search - esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es}) - - // vector search - mgr, err := getVectorStore(ctx) - if err != nil { - return nil, fmt.Errorf("init vector store failed, err=%w", err) - } - - return []searchstore.Manager{esSearchstoreManager, mgr}, nil -} - -func initReranker() rerank.Reranker { - rerankerType := os.Getenv("RERANK_TYPE") - switch rerankerType { - case "vikingdb": - return vikingReranker.NewReranker(getVikingRerankerConfig()) - case "rrf": - return rrf.NewRRFReranker(0) - default: - return rrf.NewRRFReranker(0) - } -} -func getVikingRerankerConfig() *vikingReranker.Config { - return &vikingReranker.Config{ - AK: os.Getenv("VIKINGDB_RERANK_AK"), - SK: os.Getenv("VIKINGDB_RERANK_SK"), - Domain: os.Getenv("VIKINGDB_RERANK_HOST"), - Region: os.Getenv("VIKINGDB_RERANK_REGION"), - Model: os.Getenv("VIKINGDB_RERANK_MODEL"), - } -} -func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) { - rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_") - if err != nil { - return nil, err - } - - filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/messages_to_query_template_jinja2.json") - rewriterTemplate, err := readJinja2PromptTemplate(filePath) - if err != nil { - return nil, err - } - - rewriter, err := builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate) - if err != nil { - return nil, err - } - - return rewriter, nil -} - -func getWorkingDirectory() string { - root, err := os.Getwd() - if err != nil { - logs.Warnf("[InitConfig] Failed to get current working directory: %v", err) - root = os.Getenv("PWD") - } - return root -} - -func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) { - b, err := os.ReadFile(jsonFilePath) - if err != nil { - return nil, err - } - var m2qMessages []*schema.Message - if err = json.Unmarshal(b, &m2qMessages); err != nil { - return nil, err - } - tpl := make([]schema.MessagesTemplate, len(m2qMessages)) - for i := range m2qMessages { - tpl[i] = m2qMessages[i] - } - return prompt.FromMessages(schema.Jinja2, tpl...), nil -} - -func initNL2SQL(ctx context.Context) (nl2sql.NL2SQL, error) { - n2sChatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_") - if err != nil { - return nil, err - } - - filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/nl2sql_template_jinja2.json") - n2sTemplate, err := readJinja2PromptTemplate(filePath) - if err != nil { - return nil, err - } - - n2s, err := builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate) - if err != nil { - return nil, err - } - - return n2s, nil -} - func initImageX(ctx context.Context) (imagex.ImageX, error) { uploadComponentType := os.Getenv(consts.FileUploadComponentType) if uploadComponentType != consts.FileUploadComponentTypeImagex { return storage.NewImagex(ctx) } - return veimagex.New( - os.Getenv(consts.VeImageXAK), - os.Getenv(consts.VeImageXSK), - os.Getenv(consts.VeImageXDomain), - os.Getenv(consts.VeImageXUploadHost), - os.Getenv(consts.VeImageXTemplate), - []string{os.Getenv(consts.VeImageXServerID)}, - ) -} - -func initResourceEventBusProducer() (eventbus.Producer, error) { - nameServer := os.Getenv(consts.MQServer) - resourceEventBusProducer, err := eventbus.NewProducer(nameServer, - consts.RMQTopicResource, consts.RMQConsumeGroupResource, 1) - if err != nil { - return nil, fmt.Errorf("init resource producer failed, err=%w", err) - } - - return resourceEventBusProducer, nil -} - -func initAppEventProducer() (eventbus.Producer, error) { - nameServer := os.Getenv(consts.MQServer) - appEventProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, 1) - if err != nil { - return nil, fmt.Errorf("init app producer failed, err=%w", err) - } - - return appEventProducer, nil -} - -func initKnowledgeEventBusProducer() (eventbus.Producer, error) { - nameServer := os.Getenv(consts.MQServer) - - knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2) - if err != nil { - return nil, fmt.Errorf("init knowledge producer failed, err=%w", err) - } - - return knowledgeProducer, nil -} - -func initCodeRunner() coderunner.Runner { - switch typ := os.Getenv(consts.CodeRunnerType); typ { - case "sandbox": - getAndSplit := func(key string) []string { - v := os.Getenv(key) - if v == "" { - return nil - } - return strings.Split(v, ",") - } - config := &sandbox.Config{ - AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), - AllowRead: getAndSplit(consts.CodeRunnerAllowRead), - AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), - AllowNet: getAndSplit(consts.CodeRunnerAllowNet), - AllowRun: getAndSplit(consts.CodeRunnerAllowRun), - AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), - NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir), - TimeoutSeconds: 0, - MemoryLimitMB: 0, - } - if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil { - config.TimeoutSeconds = f - } else { - config.TimeoutSeconds = 60.0 - } - if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil { - config.MemoryLimitMB = mem - } else { - config.MemoryLimitMB = 100 - } - return sandbox.NewRunner(config) - default: - return direct.NewRunner() - } -} - -func initOCR() ocr.OCR { - var ocr ocr.OCR - switch os.Getenv(consts.OCRType) { - case "ve": - ocrAK := os.Getenv(consts.VeOCRAK) - ocrSK := os.Getenv(consts.VeOCRSK) - if ocrAK == "" || ocrSK == "" { - logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well") - } - inst := visual.NewInstance() - inst.Client.SetAccessKey(ocrAK) - inst.Client.SetSecretKey(ocrSK) - ocr = veocr.NewOCR(&veocr.Config{Client: inst}) - case "paddleocr": - url := os.Getenv(consts.PPOCRAPIURL) - client := &http.Client{} - ocr = ppocr.NewOCR(&ppocr.Config{Client: client, URL: url}) - default: - // accept ocr not configured - } - - return ocr -} - -func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) (parser.Manager, error) { - var parserManager parser.Manager - parserType := os.Getenv(consts.ParserType) - switch parserType { - case "builtin", "": - parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel) - case "paddleocr": - url := os.Getenv(consts.PPStructureAPIURL) - client := &http.Client{} - apiConfig := &ppstructure.APIConfig{ - Client: client, - URL: url, - } - parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel) - default: - return nil, fmt.Errorf("parser type %s not supported", parserType) - } - - return parserManager, nil -} - -func getVectorStore(ctx context.Context) (searchstore.Manager, error) { - vsType := os.Getenv("VECTOR_STORE_TYPE") - - switch vsType { - case "milvus": - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - var ( - milvusAddr = os.Getenv("MILVUS_ADDR") - user = os.Getenv("MILVUS_USER") - password = os.Getenv("MILVUS_PASSWORD") - milvusToken = os.Getenv("MILVUS_TOKEN") - ) - mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ - Address: milvusAddr, - Username: user, - Password: password, - APIKey: milvusToken, - }) - if err != nil { - return nil, fmt.Errorf("init milvus client failed, err=%w", err) - } - - emb, err := getEmbedding(ctx) - if err != nil { - return nil, fmt.Errorf("init milvus embedding failed, err=%w", err) - } - - mgr, err := milvus.NewManager(&milvus.ManagerConfig{ - Client: mc, - Embedding: emb, - EnableHybrid: ptr.Of(true), - }) - if err != nil { - return nil, fmt.Errorf("init milvus vector store failed, err=%w", err) - } - - return mgr, nil - case "vikingdb": - var ( - host = os.Getenv("VIKING_DB_HOST") - region = os.Getenv("VIKING_DB_REGION") - ak = os.Getenv("VIKING_DB_AK") - sk = os.Getenv("VIKING_DB_SK") - scheme = os.Getenv("VIKING_DB_SCHEME") - modelName = os.Getenv("VIKING_DB_MODEL_NAME") - ) - if ak == "" || sk == "" { - return nil, fmt.Errorf("invalid vikingdb ak / sk") - } - if host == "" { - host = "api-vikingdb.volces.com" - } - if region == "" { - region = "cn-beijing" - } - if scheme == "" { - scheme = "https" - } - - var embConfig *vikingdb.VikingEmbeddingConfig - if modelName != "" { - embName := vikingdb.VikingEmbeddingModelName(modelName) - if embName.Dimensions() == 0 { - return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName) - } - embConfig = &vikingdb.VikingEmbeddingConfig{ - UseVikingEmbedding: true, - EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse, - ModelName: embName, - ModelVersion: embName.ModelVersion(), - DenseWeight: ptr.Of(0.2), - BuiltinEmbedding: nil, - } - } else { - builtinEmbedding, err := getEmbedding(ctx) - if err != nil { - return nil, fmt.Errorf("builtint embedding init failed, err=%w", err) - } - - embConfig = &vikingdb.VikingEmbeddingConfig{ - UseVikingEmbedding: false, - EnableHybrid: false, - BuiltinEmbedding: builtinEmbedding, - } - } - - svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme) - mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{ - Service: svc, - IndexingConfig: nil, // use default config - EmbeddingConfig: embConfig, - }) - if err != nil { - return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err) - } - - return mgr, nil - - case "oceanbase": - emb, err := getEmbedding(ctx) - if err != nil { - return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err) - } - - var ( - host = os.Getenv("OCEANBASE_HOST") - port = os.Getenv("OCEANBASE_PORT") - user = os.Getenv("OCEANBASE_USER") - password = os.Getenv("OCEANBASE_PASSWORD") - database = os.Getenv("OCEANBASE_DATABASE") - ) - if host == "" || port == "" || user == "" || password == "" || database == "" { - return nil, fmt.Errorf("invalid oceanbase configuration: host, port, user, password, database are required") - } - - dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", - user, password, host, port, database) - - client, err := oceanbaseClient.NewOceanBaseClient(dsn) - if err != nil { - return nil, fmt.Errorf("init oceanbase client failed, err=%w", err) - } - - if err := client.InitDatabase(ctx); err != nil { - return nil, fmt.Errorf("init oceanbase database failed, err=%w", err) - } - - // Get configuration from environment variables with defaults - batchSize := 100 - if bs := os.Getenv("OCEANBASE_BATCH_SIZE"); bs != "" { - if bsInt, err := strconv.Atoi(bs); err == nil { - batchSize = bsInt - } - } - - enableCache := true - if ec := os.Getenv("OCEANBASE_ENABLE_CACHE"); ec != "" { - if ecBool, err := strconv.ParseBool(ec); err == nil { - enableCache = ecBool - } - } - - cacheTTL := 300 * time.Second - if ct := os.Getenv("OCEANBASE_CACHE_TTL"); ct != "" { - if ctInt, err := strconv.Atoi(ct); err == nil { - cacheTTL = time.Duration(ctInt) * time.Second - } - } - - maxConnections := 100 - if mc := os.Getenv("OCEANBASE_MAX_CONNECTIONS"); mc != "" { - if mcInt, err := strconv.Atoi(mc); err == nil { - maxConnections = mcInt - } - } - - connTimeout := 30 * time.Second - if ct := os.Getenv("OCEANBASE_CONN_TIMEOUT"); ct != "" { - if ctInt, err := strconv.Atoi(ct); err == nil { - connTimeout = time.Duration(ctInt) * time.Second - } - } - - managerConfig := &oceanbase.ManagerConfig{ - Client: client, - Embedding: emb, - BatchSize: batchSize, - EnableCache: enableCache, - CacheTTL: cacheTTL, - MaxConnections: maxConnections, - ConnTimeout: connTimeout, - } - mgr, err := oceanbase.NewManager(managerConfig) - if err != nil { - return nil, fmt.Errorf("init oceanbase vector store failed, err=%w", err) - } - return mgr, nil - - default: - return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType) - } -} - -func getEmbedding(ctx context.Context) (embedding.Embedder, error) { - var batchSize int - if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil { - logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100") - batchSize = 100 - } else { - batchSize = int(bs) - } - - var emb embedding.Embedder - - switch os.Getenv("EMBEDDING_TYPE") { - case "openai": - var ( - openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL") - openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL") - openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY") - openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE") - openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION") - openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS") - openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS") - ) - - byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure) - if err != nil { - return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err) - } - - dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err) - } - - openAICfg := &openai.EmbeddingConfig{ - APIKey: openAIEmbeddingApiKey, - ByAzure: byAzure, - BaseURL: openAIEmbeddingBaseURL, - APIVersion: openAIEmbeddingApiVersion, - Model: openAIEmbeddingModel, - // Dimensions: ptr.Of(int(dims)), - } - reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0) - if reqDims > 0 { - // some openai model not support request dims - openAICfg.Dimensions = ptr.Of(int(reqDims)) - } - - emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init openai embedding failed, err=%w", err) - } - - case "ark": - var ( - arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL") - arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL") - arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY") - // deprecated: use ARK_EMBEDDING_API_KEY instead - // ARK_EMBEDDING_AK will be removed in the future - arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK") - arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS") - arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE") - ) - - dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err) - } - - apiType := ark.APITypeText - if arkEmbeddingAPIType != "" { - if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal { - return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t) - } else { - apiType = t - } - } - - emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{ - APIKey: func() string { - if arkEmbeddingApiKey != "" { - return arkEmbeddingApiKey - } - return arkEmbeddingAK - }(), - Model: arkEmbeddingModel, - BaseURL: arkEmbeddingBaseURL, - APIType: &apiType, - }, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init ark embedding client failed, err=%w", err) - } - - case "ollama": - var ( - ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL") - ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL") - ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS") - ) - - dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err) - } - - emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{ - BaseURL: ollamaEmbeddingBaseURL, - Model: ollamaEmbeddingModel, - }, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init ollama embedding failed, err=%w", err) - } - case "gemini": - var ( - geminiEmbeddingBaseURL = os.Getenv("GEMINI_EMBEDDING_BASE_URL") - geminiEmbeddingModel = os.Getenv("GEMINI_EMBEDDING_MODEL") - geminiEmbeddingApiKey = os.Getenv("GEMINI_EMBEDDING_API_KEY") - geminiEmbeddingDims = os.Getenv("GEMINI_EMBEDDING_DIMS") - geminiEmbeddingBackend = os.Getenv("GEMINI_EMBEDDING_BACKEND") // "1" for BackendGeminiAPI / "2" for BackendVertexAI - geminiEmbeddingProject = os.Getenv("GEMINI_EMBEDDING_PROJECT") - geminiEmbeddingLocation = os.Getenv("GEMINI_EMBEDDING_LOCATION") - ) - - if len(geminiEmbeddingModel) == 0 { - return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required") - } - if len(geminiEmbeddingApiKey) == 0 { - return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required") - } - if len(geminiEmbeddingDims) == 0 { - return nil, fmt.Errorf("GEMINI_EMBEDDING_DIMS environment variable is required") - } - if len(geminiEmbeddingBackend) == 0 { - return nil, fmt.Errorf("GEMINI_EMBEDDING_BACKEND environment variable is required") - } - - dims, convErr := strconv.ParseInt(geminiEmbeddingDims, 10, 64) - if convErr != nil { - return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_DIMS value: %s, err=%w", geminiEmbeddingDims, convErr) - } - - backend, convErr := strconv.ParseInt(geminiEmbeddingBackend, 10, 64) - if convErr != nil { - return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_BACKEND value: %s, err=%w", geminiEmbeddingBackend, convErr) - } - - geminiCli, err := genai.NewClient(ctx, &genai.ClientConfig{ - APIKey: geminiEmbeddingApiKey, - Backend: genai.Backend(backend), - Project: geminiEmbeddingProject, - Location: geminiEmbeddingLocation, - HTTPOptions: genai.HTTPOptions{ - BaseURL: geminiEmbeddingBaseURL, - }, - }) - if err != nil { - return nil, fmt.Errorf("init gemini client failed, err=%w", err) - } - - emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{ - Client: geminiCli, - Model: geminiEmbeddingModel, - OutputDimensionality: ptr.Of(int32(dims)), - }, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init gemini embedding failed, err=%w", err) - } - case "http": - var ( - httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR") - httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS") - ) - dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init http embedding dims failed, err=%w", err) - } - emb, err = embeddingHttp.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init http embedding failed, err=%w", err) - } - - default: - return nil, fmt.Errorf("init knowledge embedding failed, type not configured") - } - - return emb, nil + + return veimagex.NewDefault() } diff --git a/backend/application/base/appinfra/modelmgr.go b/backend/application/base/appinfra/modelmgr.go index a14cbc3e5..af6786153 100644 --- a/backend/application/base/appinfra/modelmgr.go +++ b/backend/application/base/appinfra/modelmgr.go @@ -31,6 +31,8 @@ import ( "github.com/coze-dev/coze-studio/backend/pkg/logs" ) +// TODO(fanlv) : 模型管理移到 Infra + func initModelMgr() (modelmgr.Manager, error) { wd, err := os.Getwd() if err != nil { diff --git a/backend/application/workflow/init.go b/backend/application/workflow/init.go index 9d1f843a5..168b04939 100644 --- a/backend/application/workflow/init.go +++ b/backend/application/workflow/init.go @@ -28,8 +28,6 @@ import ( "github.com/cloudwego/eino/compose" "gorm.io/gorm" - "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code" - knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service" dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service" variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service" @@ -40,7 +38,6 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/config" "github.com/coze-dev/coze-studio/backend/domain/workflow/service" - workflowservice "github.com/coze-dev/coze-studio/backend/domain/workflow/service" "github.com/coze-dev/coze-studio/backend/infra/cache" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" "github.com/coze-dev/coze-studio/backend/infra/coderunner" @@ -101,8 +98,8 @@ func InitService(_ context.Context, components *ServiceComponents) (*Application workflowDomainSVC := service.NewWorkflowService(workflowRepo) wrapPlugin.SetOSS(components.Tos) - code.SetCodeRunner(components.CodeRunner) - callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler()) + coderunner.SetCodeRunner(components.CodeRunner) + callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler()) setEventBus(components.DomainNotifier) diff --git a/backend/application/base/appinfra/builtin_chat_model.go b/backend/bizpkg/buildinmodel/builtin_chat_model.go similarity index 90% rename from backend/application/base/appinfra/builtin_chat_model.go rename to backend/bizpkg/buildinmodel/builtin_chat_model.go index 1831f8fc2..68d89a69d 100644 --- a/backend/application/base/appinfra/builtin_chat_model.go +++ b/backend/bizpkg/buildinmodel/builtin_chat_model.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package appinfra +package buildinmodel import ( "context" @@ -22,18 +22,18 @@ import ( "os" "strconv" - ao "github.com/cloudwego/eino-ext/components/model/ark" + "github.com/cloudwego/eino-ext/components/model/ark" "github.com/cloudwego/eino-ext/components/model/deepseek" "github.com/cloudwego/eino-ext/components/model/gemini" "github.com/cloudwego/eino-ext/components/model/ollama" - mo "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino-ext/components/model/qwen" "google.golang.org/genai" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" ) -func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) { +func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) { getEnv := func(key string) string { if val := os.Getenv(envPrefix + key); val != "" { return val @@ -44,14 +44,14 @@ func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.B switch getEnv("BUILTIN_CM_TYPE") { case "openai": byAzure, _ := strconv.ParseBool(getEnv("BUILTIN_CM_OPENAI_BY_AZURE")) - bcm, err = mo.NewChatModel(ctx, &mo.ChatModelConfig{ + bcm, err = openai.NewChatModel(ctx, &openai.ChatModelConfig{ APIKey: getEnv("BUILTIN_CM_OPENAI_API_KEY"), ByAzure: byAzure, BaseURL: getEnv("BUILTIN_CM_OPENAI_BASE_URL"), Model: getEnv("BUILTIN_CM_OPENAI_MODEL"), }) case "ark": - bcm, err = ao.NewChatModel(ctx, &ao.ChatModelConfig{ + bcm, err = ark.NewChatModel(ctx, &ark.ChatModelConfig{ APIKey: getEnv("BUILTIN_CM_ARK_API_KEY"), Model: getEnv("BUILTIN_CM_ARK_MODEL"), BaseURL: getEnv("BUILTIN_CM_ARK_BASE_URL"), diff --git a/backend/bizpkg/fileutil/fileutil.go b/backend/bizpkg/fileutil/fileutil.go new file mode 100644 index 000000000..4ff5838d3 --- /dev/null +++ b/backend/bizpkg/fileutil/fileutil.go @@ -0,0 +1,51 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package fileutil + +import ( + "encoding/json" + "os" + + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/schema" + + "github.com/coze-dev/coze-studio/backend/pkg/logs" +) + +func GetWorkingDirectory() string { + root, err := os.Getwd() + if err != nil { + logs.Warnf("[InitConfig] Failed to get current working directory: %v", err) + root = os.Getenv("PWD") + } + return root +} + +func ReadJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) { + b, err := os.ReadFile(jsonFilePath) + if err != nil { + return nil, err + } + var m2qMessages []*schema.Message + if err = json.Unmarshal(b, &m2qMessages); err != nil { + return nil, err + } + tpl := make([]schema.MessagesTemplate, len(m2qMessages)) + for i := range m2qMessages { + tpl[i] = m2qMessages[i] + } + return prompt.FromMessages(schema.Jinja2, tpl...), nil +} diff --git a/backend/pkg/goutil/pyutil.go b/backend/bizpkg/fileutil/pyutil.go similarity index 98% rename from backend/pkg/goutil/pyutil.go rename to backend/bizpkg/fileutil/pyutil.go index 9423454af..b75e8a292 100644 --- a/backend/pkg/goutil/pyutil.go +++ b/backend/bizpkg/fileutil/pyutil.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package goutil +package fileutil import ( "os" diff --git a/backend/crossdomain/impl/code/code.go b/backend/crossdomain/impl/code/code.go deleted file mode 100644 index f77471cf7..000000000 --- a/backend/crossdomain/impl/code/code.go +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package code - -import ( - "github.com/coze-dev/coze-studio/backend/infra/coderunner" -) - -func GetCodeRunner() coderunner.Runner { - return runnerImpl -} - -func SetCodeRunner(runner coderunner.Runner) { - runnerImpl = runner -} - -var runnerImpl coderunner.Runner diff --git a/backend/domain/agent/singleagent/internal/agentflow/node_tool_database.go b/backend/domain/agent/singleagent/internal/agentflow/node_tool_database.go index e1020b64a..f65cbdd4a 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/node_tool_database.go +++ b/backend/domain/agent/singleagent/internal/agentflow/node_tool_database.go @@ -35,7 +35,7 @@ import ( "github.com/coze-dev/coze-studio/backend/api/model/data/database/table" crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database" "github.com/coze-dev/coze-studio/backend/domain/memory/database/service" - "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" ) @@ -79,8 +79,7 @@ func (d *databaseTool) Invoke(ctx context.Context, req ExecuteSQLRequest) (strin tableType = table.TableType_DraftTable } - // TODO(@fanlv): domain 不能依赖具体 impl - tableName, err := sqlparser.NewSQLParser().GetTableName(req.SQL) + tableName, err := sqlparser.New().GetTableName(req.SQL) if err != nil { return "", err } diff --git a/backend/domain/knowledge/service/event_handle.go b/backend/domain/knowledge/service/event_handle.go index 7203b2e98..494c3799f 100644 --- a/backend/domain/knowledge/service/event_handle.go +++ b/backend/domain/knowledge/service/event_handle.go @@ -35,8 +35,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/events" "github.com/coze-dev/coze-studio/backend/infra/document" - progressbarContract "github.com/coze-dev/coze-studio/backend/infra/document/progressbar" - "github.com/coze-dev/coze-studio/backend/infra/document/progressbar/impl/progressbar" + "github.com/coze-dev/coze-studio/backend/infra/document/progressbar" "github.com/coze-dev/coze-studio/backend/infra/document/searchstore" "github.com/coze-dev/coze-studio/backend/infra/eventbus" "github.com/coze-dev/coze-studio/backend/infra/rdb" @@ -379,9 +378,8 @@ func (k *knowledgeSVC) handleTableDocument(ctx context.Context, func (k *knowledgeSVC) processDocumentChunks(ctx context.Context, doc *entity.Document, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord) error { - // TODO(@fanlv): domain 不能依赖具体 impl batchSize := 100 - progressbar := progressbar.NewProgressBar(ctx, doc.ID, + progressbar := progressbar.New(ctx, doc.ID, int64(len(parseResult)*len(k.searchStoreManagers)), k.cacheCli, true) if err := progressbar.AddN(int(cacheRecord.LastProcessedNumber) * len(k.searchStoreManagers)); err != nil { @@ -417,7 +415,7 @@ func (k *knowledgeSVC) finalizeDocumentIndexing(ctx context.Context, knowledgeID // batchProcessSlice processes a batch of document slices func (k *knowledgeSVC) batchProcessSlice(ctx context.Context, doc *entity.Document, startIdx int, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord, - progressBar progressbarContract.ProgressBar) error { + progressBar progressbar.ProgressBar) error { collectionName := getCollectionName(doc.KnowledgeID) length := len(parseResult) @@ -640,7 +638,7 @@ func (k *knowledgeSVC) storeSlicesInDB(ctx context.Context, doc *entity.Document // indexSlicesInSearchStores indexes slices in appropriate search stores func (k *knowledgeSVC) indexSlicesInSearchStores(ctx context.Context, doc *entity.Document, collectionName string, sliceEntities []*entity.Slice, cacheRecord *indexDocCacheRecord, - progressBar progressbarContract.ProgressBar) error { + progressBar progressbar.ProgressBar) error { fields, err := k.mapSearchFields(doc) if err != nil { diff --git a/backend/domain/knowledge/service/knowledge.go b/backend/domain/knowledge/service/knowledge.go index fb16a3a48..6bef912b9 100644 --- a/backend/domain/knowledge/service/knowledge.go +++ b/backend/domain/knowledge/service/knowledge.go @@ -47,15 +47,15 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/knowledge/repository" "github.com/coze-dev/coze-studio/backend/infra/cache" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" + "github.com/coze-dev/coze-studio/backend/infra/document/messages2query" "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql" "github.com/coze-dev/coze-studio/backend/infra/document/ocr" "github.com/coze-dev/coze-studio/backend/infra/document/parser" - "github.com/coze-dev/coze-studio/backend/infra/document/progressbar/impl/progressbar" + "github.com/coze-dev/coze-studio/backend/infra/document/progressbar" "github.com/coze-dev/coze-studio/backend/infra/document/rerank" "github.com/coze-dev/coze-studio/backend/infra/document/searchstore" "github.com/coze-dev/coze-studio/backend/infra/eventbus" "github.com/coze-dev/coze-studio/backend/infra/idgen" - "github.com/coze-dev/coze-studio/backend/infra/messages2query" "github.com/coze-dev/coze-studio/backend/infra/rdb" rdbEntity "github.com/coze-dev/coze-studio/backend/infra/rdb/entity" "github.com/coze-dev/coze-studio/backend/infra/storage" @@ -552,8 +552,7 @@ func (k *knowledgeSVC) MGetDocumentProgress(ctx context.Context, request *MGetDo } func (k *knowledgeSVC) getProgressFromCache(ctx context.Context, documentProgress *DocumentProgress) (err error) { - // TODO(@fanlv) : domain 不依赖 impl - progressBar := progressbar.NewProgressBar(ctx, documentProgress.ID, 0, k.cacheCli, false) + progressBar := progressbar.New(ctx, documentProgress.ID, 0, k.cacheCli, false) percent, remainSec, errMsg := progressBar.GetProgress(ctx) documentProgress.Progress = int(percent) documentProgress.RemainingSec = int64(remainSec) diff --git a/backend/domain/knowledge/service/retrieve.go b/backend/domain/knowledge/service/retrieve.go index aff1a2838..bb78c4d0f 100644 --- a/backend/domain/knowledge/service/retrieve.go +++ b/backend/domain/knowledge/service/retrieve.go @@ -36,13 +36,13 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" "github.com/coze-dev/coze-studio/backend/infra/document" + "github.com/coze-dev/coze-studio/backend/infra/document/messages2query" "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql" "github.com/coze-dev/coze-studio/backend/infra/document/rerank" "github.com/coze-dev/coze-studio/backend/infra/document/searchstore" - "github.com/coze-dev/coze-studio/backend/infra/messages2query" "github.com/coze-dev/coze-studio/backend/infra/rdb" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/sqlparser" - "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/sets" @@ -395,8 +395,7 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum virtualColumnMap[convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)] = doc.TableInfo.Columns[i] } - // TODO(@fanlv) : domain 不依赖 impl - parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap) + parsedSQL, err := sqlparser.New().ParseAndModifySQL(sql, replaceMap) if err != nil { logs.CtxErrorf(ctx, "parse sql failed: %v", err) return nil, err @@ -466,7 +465,7 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum const pkID = "_knowledge_slice_id" func addSliceIdColumn(originalSql string) string { - sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID}) + sql, err := sqlparser.New().AddSelectFieldsToSelectSQL(originalSql, []string{pkID}) if err != nil { logs.Errorf("add slice id column failed: %v", err) return originalSql diff --git a/backend/domain/knowledge/service/retrieve_test.go b/backend/domain/knowledge/service/retrieve_test.go index 01e727ea0..c6d7b5f21 100644 --- a/backend/domain/knowledge/service/retrieve_test.go +++ b/backend/domain/knowledge/service/retrieve_test.go @@ -36,6 +36,8 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql" "github.com/coze-dev/coze-studio/backend/infra/rdb" rdb_entity "github.com/coze-dev/coze-studio/backend/infra/rdb/entity" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" + sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/nl2sql_mock" mock_db "github.com/coze-dev/coze-studio/backend/internal/mock/infra/rdb" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" @@ -59,6 +61,8 @@ func TestAddSliceIdColumn(t *testing.T) { expected: "SELECT FROM users", }, } + sqlparser.New = sqlparserImpl.NewSQLParser + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { actual := addSliceIdColumn(tt.input) diff --git a/backend/domain/memory/database/service/database_impl.go b/backend/domain/memory/database/service/database_impl.go index 3e6473adc..e18ce8a65 100644 --- a/backend/domain/memory/database/service/database_impl.go +++ b/backend/domain/memory/database/service/database_impl.go @@ -46,8 +46,8 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/idgen" "github.com/coze-dev/coze-studio/backend/infra/rdb" entity3 "github.com/coze-dev/coze-studio/backend/infra/rdb/entity" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/sqlparser" - "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" "github.com/coze-dev/coze-studio/backend/infra/storage" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" @@ -1044,8 +1044,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe return nil, fmt.Errorf("SQL is empty") } - // TODO(@fanlv) : domain 不依赖 impl - operation, err := sqlparser.NewSQLParser().GetSQLOperation(*req.SQL) + operation, err := sqlparser.New().GetSQLOperation(*req.SQL) if err != nil { return nil, err } @@ -1072,7 +1071,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe }, } - parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(*req.SQL, tableColumnMapping) + parsedSQL, err := sqlparser.New().ParseAndModifySQL(*req.SQL, tableColumnMapping) if err != nil { return nil, fmt.Errorf("parse sql failed: %v", err) } @@ -1080,7 +1079,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 { switch operation { case sqlparsercontract.OperationTypeSelect, sqlparsercontract.OperationTypeUpdate, sqlparsercontract.OperationTypeDelete: - parsedSQL, err = sqlparser.NewSQLParser().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID)) + parsedSQL, err = sqlparser.New().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID)) if err != nil { return nil, fmt.Errorf("append sql filter failed: %v", err) } @@ -1092,7 +1091,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe if req.ConnectorID != nil { cid = *req.ConnectorID } - nums, err := sqlparser.NewSQLParser().GetInsertDataNums(parsedSQL) + nums, err := sqlparser.New().GetInsertDataNums(parsedSQL) if err != nil { return nil, err } @@ -1114,7 +1113,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe for i, id := range ids { iIDs[i] = id } - parsedSQL, _, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{ + parsedSQL, _, err = sqlparser.New().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{ { ColName: database.DefaultCidColName, Value: cid, @@ -1128,7 +1127,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe return nil, fmt.Errorf("add columns to insert sql failed: %v", err) } } else if req.SQLType == database.SQLType_Parameterized { - parsedSQL, existingCols, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{ + parsedSQL, existingCols, err = sqlparser.New().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{ { ColName: database.DefaultCidColName, }, diff --git a/backend/domain/memory/database/service/database_impl_test.go b/backend/domain/memory/database/service/database_impl_test.go index 155db1097..98346a4fe 100644 --- a/backend/domain/memory/database/service/database_impl_test.go +++ b/backend/domain/memory/database/service/database_impl_test.go @@ -37,6 +37,8 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/memory/database/repository" "github.com/coze-dev/coze-studio/backend/infra/rdb" rdb2 "github.com/coze-dev/coze-studio/backend/infra/rdb/impl/rdb" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" + sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/idgen" storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/storage" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" @@ -44,6 +46,8 @@ import ( ) func setupTestEnv(t *testing.T) (*gorm.DB, rdb.RDB, *mock.MockIDGenerator, repository.DraftDAO, repository.OnlineDAO, Database) { + sqlparser.New = sqlparserImpl.NewSQLParser + dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local" if os.Getenv("CI_JOB_NAME") != "" { dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql") diff --git a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go index c6d2c7fd1..171d38d8a 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go +++ b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go @@ -45,7 +45,6 @@ import ( mockmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr/modelmock" crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/pluginmock" - "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code" userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" @@ -768,7 +767,7 @@ func TestCodeAndPluginNodes(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() mockCodeRunner := mockcode.NewMockRunner(ctrl) - mockey.Mock(code.GetCodeRunner).Return(mockCodeRunner).Build() + mockey.Mock(coderunner.GetCodeRunner).Return(mockCodeRunner).Build() mockRepo := mockWorkflow.NewMockRepository(ctrl) diff --git a/backend/domain/workflow/internal/nodes/code/code.go b/backend/domain/workflow/internal/nodes/code/code.go index ffb399293..f2deadd03 100644 --- a/backend/domain/workflow/internal/nodes/code/code.go +++ b/backend/domain/workflow/internal/nodes/code/code.go @@ -25,7 +25,6 @@ import ( "golang.org/x/exp/maps" - code2 "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code" wf "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" @@ -169,7 +168,7 @@ func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.Bui code: c.Code, language: c.Language, outputConfig: ns.OutputTypes, - runner: code2.GetCodeRunner(), + runner: coderunner.GetCodeRunner(), importError: importErr, }, nil } diff --git a/backend/infra/cache/impl/redis/redis.go b/backend/infra/cache/impl/redis/redis.go index 9b0628029..4844152e7 100644 --- a/backend/infra/cache/impl/redis/redis.go +++ b/backend/infra/cache/impl/redis/redis.go @@ -26,6 +26,8 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/cache" ) +type Cmdable = cache.Cmdable + func New() cache.Cmdable { addr := os.Getenv("REDIS_ADDR") password := os.Getenv("REDIS_PASSWORD") diff --git a/backend/infra/coderunner/code.go b/backend/infra/coderunner/code.go index c62cbb1b6..bcbdacf5e 100644 --- a/backend/infra/coderunner/code.go +++ b/backend/infra/coderunner/code.go @@ -38,3 +38,13 @@ type RunResponse struct { type Runner interface { Run(ctx context.Context, request *RunRequest) (*RunResponse, error) } + +func GetCodeRunner() Runner { + return runnerImpl +} + +func SetCodeRunner(runner Runner) { + runnerImpl = runner +} + +var runnerImpl Runner diff --git a/backend/infra/coderunner/impl/direct/runner.go b/backend/infra/coderunner/impl/direct/runner.go index e7c9ad93a..a85c38750 100644 --- a/backend/infra/coderunner/impl/direct/runner.go +++ b/backend/infra/coderunner/impl/direct/runner.go @@ -22,8 +22,8 @@ import ( "fmt" "os/exec" + "github.com/coze-dev/coze-studio/backend/bizpkg/fileutil" "github.com/coze-dev/coze-studio/backend/infra/coderunner" - "github.com/coze-dev/coze-studio/backend/pkg/goutil" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -78,7 +78,7 @@ func (r *runner) pythonCmdRun(_ context.Context, code string, params map[string] if err != nil { return nil, fmt.Errorf("failed to marshal params to json, err: %w", err) } - cmd := exec.Command(goutil.GetPython3Path(), "-c", fmt.Sprintf(pythonCode, code), string(bs)) // ignore_security_alert RCE + cmd := exec.Command(fileutil.GetPython3Path(), "-c", fmt.Sprintf(pythonCode, code), string(bs)) // ignore_security_alert RCE stdout := new(bytes.Buffer) stderr := new(bytes.Buffer) cmd.Stdout = stdout diff --git a/backend/infra/coderunner/impl/impl.go b/backend/infra/coderunner/impl/impl.go new file mode 100644 index 000000000..b53cf5242 --- /dev/null +++ b/backend/infra/coderunner/impl/impl.go @@ -0,0 +1,66 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package impl + +import ( + "os" + "strconv" + "strings" + + "github.com/coze-dev/coze-studio/backend/infra/coderunner" + "github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/direct" + "github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/sandbox" + "github.com/coze-dev/coze-studio/backend/types/consts" +) + +type Runner = coderunner.Runner + +func New() Runner { + switch typ := os.Getenv(consts.CodeRunnerType); typ { + case "sandbox": + getAndSplit := func(key string) []string { + v := os.Getenv(key) + if v == "" { + return nil + } + return strings.Split(v, ",") + } + config := &sandbox.Config{ + AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), + AllowRead: getAndSplit(consts.CodeRunnerAllowRead), + AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), + AllowNet: getAndSplit(consts.CodeRunnerAllowNet), + AllowRun: getAndSplit(consts.CodeRunnerAllowRun), + AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), + NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir), + TimeoutSeconds: 0, + MemoryLimitMB: 0, + } + if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil { + config.TimeoutSeconds = f + } else { + config.TimeoutSeconds = 60.0 + } + if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil { + config.MemoryLimitMB = mem + } else { + config.MemoryLimitMB = 100 + } + return sandbox.NewRunner(config) + default: + return direct.NewRunner() + } +} diff --git a/backend/infra/coderunner/impl/sandbox/runner.go b/backend/infra/coderunner/impl/sandbox/runner.go index 4b71ce704..3395f46a9 100644 --- a/backend/infra/coderunner/impl/sandbox/runner.go +++ b/backend/infra/coderunner/impl/sandbox/runner.go @@ -23,15 +23,15 @@ import ( "os" "os/exec" + "github.com/coze-dev/coze-studio/backend/bizpkg/fileutil" "github.com/coze-dev/coze-studio/backend/infra/coderunner" - "github.com/coze-dev/coze-studio/backend/pkg/goutil" "github.com/coze-dev/coze-studio/backend/pkg/logs" ) func NewRunner(config *Config) coderunner.Runner { return &runner{ - pyPath: goutil.GetPython3Path(), - scriptPath: goutil.GetPythonFilePath("sandbox.py"), + pyPath: fileutil.GetPython3Path(), + scriptPath: fileutil.GetPythonFilePath("sandbox.py"), config: config, } } diff --git a/backend/infra/messages2query/impl/builtin/messages_to_query.go b/backend/infra/document/messages2query/impl/builtin/messages_to_query.go similarity index 96% rename from backend/infra/messages2query/impl/builtin/messages_to_query.go rename to backend/infra/document/messages2query/impl/builtin/messages_to_query.go index 2bae5dd4f..40664df49 100644 --- a/backend/infra/messages2query/impl/builtin/messages_to_query.go +++ b/backend/infra/document/messages2query/impl/builtin/messages_to_query.go @@ -26,7 +26,7 @@ import ( "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" - "github.com/coze-dev/coze-studio/backend/infra/messages2query" + "github.com/coze-dev/coze-studio/backend/infra/document/messages2query" ) func NewMessagesToQuery(_ context.Context, model chatmodel.BaseChatModel, template prompt.ChatTemplate) (messages2query.MessagesToQuery, error) { diff --git a/backend/infra/messages2query/impl/builtin/messages_to_query_test.go b/backend/infra/document/messages2query/impl/builtin/messages_to_query_test.go similarity index 100% rename from backend/infra/messages2query/impl/builtin/messages_to_query_test.go rename to backend/infra/document/messages2query/impl/builtin/messages_to_query_test.go diff --git a/backend/infra/document/messages2query/impl/impl.go b/backend/infra/document/messages2query/impl/impl.go new file mode 100644 index 000000000..8e6a392f1 --- /dev/null +++ b/backend/infra/document/messages2query/impl/impl.go @@ -0,0 +1,48 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package impl + +import ( + "context" + "path/filepath" + + "github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel" + "github.com/coze-dev/coze-studio/backend/bizpkg/fileutil" + "github.com/coze-dev/coze-studio/backend/infra/document/messages2query" + "github.com/coze-dev/coze-studio/backend/infra/document/messages2query/impl/builtin" +) + +type MessagesToQuery = messages2query.MessagesToQuery + +func New(ctx context.Context) (MessagesToQuery, error) { + rewriterChatModel, _, err := buildinmodel.GetBuiltinChatModel(ctx, "M2Q_") + if err != nil { + return nil, err + } + + filePath := filepath.Join(fileutil.GetWorkingDirectory(), "resources/conf/prompt/messages_to_query_template_jinja2.json") + rewriterTemplate, err := fileutil.ReadJinja2PromptTemplate(filePath) + if err != nil { + return nil, err + } + + rewriter, err := builtin.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate) + if err != nil { + return nil, err + } + + return rewriter, nil +} diff --git a/backend/infra/messages2query/messages_to_query.go b/backend/infra/document/messages2query/messages_to_query.go similarity index 100% rename from backend/infra/messages2query/messages_to_query.go rename to backend/infra/document/messages2query/messages_to_query.go diff --git a/backend/infra/messages2query/options.go b/backend/infra/document/messages2query/options.go similarity index 100% rename from backend/infra/messages2query/options.go rename to backend/infra/document/messages2query/options.go diff --git a/backend/infra/document/nl2sql/impl/impl.go b/backend/infra/document/nl2sql/impl/impl.go new file mode 100644 index 000000000..99cea41af --- /dev/null +++ b/backend/infra/document/nl2sql/impl/impl.go @@ -0,0 +1,48 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package impl + +import ( + "context" + "path/filepath" + + "github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel" + "github.com/coze-dev/coze-studio/backend/bizpkg/fileutil" + "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql" + "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql/impl/builtin" +) + +type NL2SQL = nl2sql.NL2SQL + +func New(ctx context.Context) (nl2sql.NL2SQL, error) { + n2sChatModel, _, err := buildinmodel.GetBuiltinChatModel(ctx, "NL2SQL_") + if err != nil { + return nil, err + } + + filePath := filepath.Join(fileutil.GetWorkingDirectory(), "resources/conf/prompt/nl2sql_template_jinja2.json") + n2sTemplate, err := fileutil.ReadJinja2PromptTemplate(filePath) + if err != nil { + return nil, err + } + + n2s, err := builtin.NewNL2SQL(ctx, n2sChatModel, n2sTemplate) + if err != nil { + return nil, err + } + + return n2s, nil +} diff --git a/backend/infra/document/ocr/impl/impl.go b/backend/infra/document/ocr/impl/impl.go new file mode 100644 index 000000000..0c3dbccf1 --- /dev/null +++ b/backend/infra/document/ocr/impl/impl.go @@ -0,0 +1,55 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package impl + +import ( + "net/http" + "os" + + "github.com/volcengine/volc-sdk-golang/service/visual" + + "github.com/coze-dev/coze-studio/backend/infra/document/ocr" + "github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/ppocr" + "github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/veocr" + "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/types/consts" +) + +type OCR = ocr.OCR + +func New() ocr.OCR { + var ocr ocr.OCR + switch os.Getenv(consts.OCRType) { + case "ve": + ocrAK := os.Getenv(consts.VeOCRAK) + ocrSK := os.Getenv(consts.VeOCRSK) + if ocrAK == "" || ocrSK == "" { + logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well") + } + inst := visual.NewInstance() + inst.Client.SetAccessKey(ocrAK) + inst.Client.SetSecretKey(ocrSK) + ocr = veocr.NewOCR(&veocr.Config{Client: inst}) + case "paddleocr": + url := os.Getenv(consts.PPOCRAPIURL) + client := &http.Client{} + ocr = ppocr.NewOCR(&ppocr.Config{Client: client, URL: url}) + default: + // accept ocr not configured + } + + return ocr +} diff --git a/backend/infra/document/parser/impl/builtin/manager.go b/backend/infra/document/parser/impl/builtin/manager.go index 42de67556..a8450ba8d 100644 --- a/backend/infra/document/parser/impl/builtin/manager.go +++ b/backend/infra/document/parser/impl/builtin/manager.go @@ -19,11 +19,11 @@ package builtin import ( "fmt" + "github.com/coze-dev/coze-studio/backend/bizpkg/fileutil" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" "github.com/coze-dev/coze-studio/backend/infra/document/ocr" "github.com/coze-dev/coze-studio/backend/infra/document/parser" "github.com/coze-dev/coze-studio/backend/infra/storage" - "github.com/coze-dev/coze-studio/backend/pkg/goutil" ) func NewManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager { @@ -52,13 +52,13 @@ func (m *manager) GetParser(config *parser.Config) (parser.Parser, error) { switch config.FileExtension { case parser.FileExtensionPDF: - pFn = ParseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_pdf.py")) + pFn = ParseByPython(config, m.storage, m.ocr, fileutil.GetPython3Path(), fileutil.GetPythonFilePath("parse_pdf.py")) case parser.FileExtensionTXT: pFn = ParseText(config) case parser.FileExtensionMarkdown: pFn = ParseMarkdown(config, m.storage, m.ocr) case parser.FileExtensionDocx: - pFn = ParseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_docx.py")) + pFn = ParseByPython(config, m.storage, m.ocr, fileutil.GetPython3Path(), fileutil.GetPythonFilePath("parse_docx.py")) case parser.FileExtensionCSV: pFn = ParseCSV(config) case parser.FileExtensionXLSX: diff --git a/backend/infra/document/parser/impl/impl.go b/backend/infra/document/parser/impl/impl.go new file mode 100644 index 000000000..5f9ac7ac7 --- /dev/null +++ b/backend/infra/document/parser/impl/impl.go @@ -0,0 +1,59 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package impl + +import ( + "context" + "fmt" + "net/http" + "os" + + "github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel" + "github.com/coze-dev/coze-studio/backend/infra/document/ocr" + "github.com/coze-dev/coze-studio/backend/infra/document/parser" + "github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/builtin" + "github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/ppstructure" + "github.com/coze-dev/coze-studio/backend/infra/storage" + "github.com/coze-dev/coze-studio/backend/types/consts" +) + +type Manager = parser.Manager + +func New(ctx context.Context, storage storage.Storage, ocr ocr.OCR) (Manager, error) { + imageAnnotationModel, _, err := buildinmodel.GetBuiltinChatModel(ctx, "IA_") + if err != nil { + return nil, fmt.Errorf("get builtin chat model failed, err=%w", err) + } + + var parserManager parser.Manager + parserType := os.Getenv(consts.ParserType) + switch parserType { + case "builtin", "": + parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel) + case "paddleocr": + url := os.Getenv(consts.PPStructureAPIURL) + client := &http.Client{} + apiConfig := &ppstructure.APIConfig{ + Client: client, + URL: url, + } + parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel) + default: + return nil, fmt.Errorf("parser type %s not supported", parserType) + } + + return parserManager, nil +} diff --git a/backend/infra/document/parser/impl/ppstructure/manager.go b/backend/infra/document/parser/impl/ppstructure/manager.go index 26d1e10d9..02a112c23 100644 --- a/backend/infra/document/parser/impl/ppstructure/manager.go +++ b/backend/infra/document/parser/impl/ppstructure/manager.go @@ -19,12 +19,12 @@ package ppstructure import ( "fmt" + "github.com/coze-dev/coze-studio/backend/bizpkg/fileutil" "github.com/coze-dev/coze-studio/backend/infra/chatmodel" "github.com/coze-dev/coze-studio/backend/infra/document/ocr" "github.com/coze-dev/coze-studio/backend/infra/document/parser" "github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/builtin" "github.com/coze-dev/coze-studio/backend/infra/storage" - "github.com/coze-dev/coze-studio/backend/pkg/goutil" ) func NewManager(apiConfig *APIConfig, ocr ocr.OCR, storage storage.Storage, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager { @@ -64,7 +64,7 @@ func (m *manager) GetParser(config *parser.Config) (parser.Parser, error) { pFn = builtin.ParseMarkdown(config, m.storage, m.ocr) return &builtin.Parser{ParseFn: pFn}, nil case parser.FileExtensionDocx: - pFn = builtin.ParseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_docx.py")) + pFn = builtin.ParseByPython(config, m.storage, m.ocr, fileutil.GetPython3Path(), fileutil.GetPythonFilePath("parse_docx.py")) return &builtin.Parser{ParseFn: pFn}, nil case parser.FileExtensionCSV: pFn = builtin.ParseCSV(config) diff --git a/backend/infra/document/progressbar/impl/progressbar/impl.go b/backend/infra/document/progressbar/impl/progressbar/impl.go index 9ca6fab53..ff66b8d91 100644 --- a/backend/infra/document/progressbar/impl/progressbar/impl.go +++ b/backend/infra/document/progressbar/impl/progressbar/impl.go @@ -42,8 +42,8 @@ const ( ProgressBarTotalNumRedisKey = "RedisBiz.Knowledge_ProgressBar_TotalNum_%d" ProgressBarProcessedNumRedisKey = "RedisBiz.Knowledge_ProgressBar_ProcessedNum_%d" DefaultProcessTime = 300 - ProcessDone = 100 - ProcessInit = 0 + ProcessDone = progressbar.ProcessDone + ProcessInit = progressbar.ProcessInit ) func NewProgressBar(ctx context.Context, pkID int64, total int64, CacheCli cache.Cmdable, needInit bool) progressbar.ProgressBar { @@ -142,6 +142,9 @@ func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainS if ptr.From(startTime) == 0 { remainSec = DefaultProcessTime } else { + if ptr.From(processedNum) == 0 { + return + } usedSec := time.Now().Unix() - ptr.From(startTime) remainSec = int(float64(ptr.From(totalNum)-ptr.From(processedNum)) / float64(ptr.From(processedNum)) * float64(usedSec)) } diff --git a/backend/infra/document/progressbar/interface.go b/backend/infra/document/progressbar/interface.go index 645600d3a..dd3cf9e0c 100644 --- a/backend/infra/document/progressbar/interface.go +++ b/backend/infra/document/progressbar/interface.go @@ -16,7 +16,11 @@ package progressbar -import "context" +import ( + "context" + + "github.com/coze-dev/coze-studio/backend/infra/cache" +) // ProgressBar is the interface for the progress bar. type ProgressBar interface { @@ -24,3 +28,10 @@ type ProgressBar interface { ReportError(err error) error GetProgress(ctx context.Context) (percent int, remainSec int, errMsg string) } + +var New func(ctx context.Context, pkID int64, total int64, CacheCli cache.Cmdable, needInit bool) ProgressBar + +const ( + ProcessDone = 100 + ProcessInit = 0 +) diff --git a/backend/infra/document/rerank/impl/impl.go b/backend/infra/document/rerank/impl/impl.go new file mode 100644 index 000000000..7ede5d1db --- /dev/null +++ b/backend/infra/document/rerank/impl/impl.go @@ -0,0 +1,48 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package impl + +import ( + "os" + + "github.com/coze-dev/coze-studio/backend/infra/document/rerank" + "github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/rrf" + "github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/vikingdb" +) + +type Reranker = rerank.Reranker + +func New() Reranker { + rerankerType := os.Getenv("RERANK_TYPE") + switch rerankerType { + case "vikingdb": + return vikingdb.NewReranker(getVikingRerankerConfig()) + case "rrf": + return rrf.NewRRFReranker(0) + default: + return rrf.NewRRFReranker(0) + } +} + +func getVikingRerankerConfig() *vikingdb.Config { + return &vikingdb.Config{ + AK: os.Getenv("VIKINGDB_RERANK_AK"), + SK: os.Getenv("VIKINGDB_RERANK_SK"), + Domain: os.Getenv("VIKINGDB_RERANK_HOST"), + Region: os.Getenv("VIKINGDB_RERANK_REGION"), + Model: os.Getenv("VIKINGDB_RERANK_MODEL"), + } +} diff --git a/backend/infra/document/searchstore/impl/impl.go b/backend/infra/document/searchstore/impl/impl.go new file mode 100644 index 000000000..28271238b --- /dev/null +++ b/backend/infra/document/searchstore/impl/impl.go @@ -0,0 +1,433 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package impl + +import ( + "context" + "fmt" + "os" + "strconv" + "time" + + "github.com/cloudwego/eino-ext/components/embedding/gemini" + "github.com/cloudwego/eino-ext/components/embedding/ollama" + "github.com/cloudwego/eino-ext/components/embedding/openai" + "github.com/milvus-io/milvus/client/v2/milvusclient" + "google.golang.org/genai" + + "github.com/coze-dev/coze-studio/backend/infra/document/searchstore" + "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/elasticsearch" + "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/milvus" + searchstoreOceanbase "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/oceanbase" + "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/vikingdb" + "github.com/coze-dev/coze-studio/backend/infra/embedding" + "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark" + "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http" + "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap" + "github.com/coze-dev/coze-studio/backend/infra/es/impl/es" + "github.com/coze-dev/coze-studio/backend/infra/oceanbase" + "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/logs" +) + +type Manager = searchstore.Manager + +func New(ctx context.Context, es es.Client) ([]Manager, error) { + // es full text search + esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es}) + + // vector search + mgr, err := getVectorStore(ctx) + if err != nil { + return nil, fmt.Errorf("init vector store failed, err=%w", err) + } + + return []searchstore.Manager{esSearchstoreManager, mgr}, nil +} + +func getVectorStore(ctx context.Context) (searchstore.Manager, error) { + vsType := os.Getenv("VECTOR_STORE_TYPE") + + switch vsType { + case "milvus": + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + var ( + milvusAddr = os.Getenv("MILVUS_ADDR") + user = os.Getenv("MILVUS_USER") + password = os.Getenv("MILVUS_PASSWORD") + milvusToken = os.Getenv("MILVUS_TOKEN") + ) + mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + Username: user, + Password: password, + APIKey: milvusToken, + }) + if err != nil { + return nil, fmt.Errorf("init milvus client failed, err=%w", err) + } + + emb, err := getEmbedding(ctx) + if err != nil { + return nil, fmt.Errorf("init milvus embedding failed, err=%w", err) + } + + mgr, err := milvus.NewManager(&milvus.ManagerConfig{ + Client: mc, + Embedding: emb, + EnableHybrid: ptr.Of(true), + }) + if err != nil { + return nil, fmt.Errorf("init milvus vector store failed, err=%w", err) + } + + return mgr, nil + case "vikingdb": + var ( + host = os.Getenv("VIKING_DB_HOST") + region = os.Getenv("VIKING_DB_REGION") + ak = os.Getenv("VIKING_DB_AK") + sk = os.Getenv("VIKING_DB_SK") + scheme = os.Getenv("VIKING_DB_SCHEME") + modelName = os.Getenv("VIKING_DB_MODEL_NAME") + ) + if ak == "" || sk == "" { + return nil, fmt.Errorf("invalid vikingdb ak / sk") + } + if host == "" { + host = "api-vikingdb.volces.com" + } + if region == "" { + region = "cn-beijing" + } + if scheme == "" { + scheme = "https" + } + + var embConfig *vikingdb.VikingEmbeddingConfig + if modelName != "" { + embName := vikingdb.VikingEmbeddingModelName(modelName) + if embName.Dimensions() == 0 { + return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName) + } + embConfig = &vikingdb.VikingEmbeddingConfig{ + UseVikingEmbedding: true, + EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse, + ModelName: embName, + ModelVersion: embName.ModelVersion(), + DenseWeight: ptr.Of(0.2), + BuiltinEmbedding: nil, + } + } else { + builtinEmbedding, err := getEmbedding(ctx) + if err != nil { + return nil, fmt.Errorf("builtint embedding init failed, err=%w", err) + } + + embConfig = &vikingdb.VikingEmbeddingConfig{ + UseVikingEmbedding: false, + EnableHybrid: false, + BuiltinEmbedding: builtinEmbedding, + } + } + + svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme) + mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{ + Service: svc, + IndexingConfig: nil, // use default config + EmbeddingConfig: embConfig, + }) + if err != nil { + return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err) + } + + return mgr, nil + + case "oceanbase": + emb, err := getEmbedding(ctx) + if err != nil { + return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err) + } + + var ( + host = os.Getenv("OCEANBASE_HOST") + port = os.Getenv("OCEANBASE_PORT") + user = os.Getenv("OCEANBASE_USER") + password = os.Getenv("OCEANBASE_PASSWORD") + database = os.Getenv("OCEANBASE_DATABASE") + ) + if host == "" || port == "" || user == "" || password == "" || database == "" { + return nil, fmt.Errorf("invalid oceanbase configuration: host, port, user, password, database are required") + } + + dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", + user, password, host, port, database) + + client, err := oceanbase.NewOceanBaseClient(dsn) + if err != nil { + return nil, fmt.Errorf("init oceanbase client failed, err=%w", err) + } + + if err := client.InitDatabase(ctx); err != nil { + return nil, fmt.Errorf("init oceanbase database failed, err=%w", err) + } + + // Get configuration from environment variables with defaults + batchSize := 100 + if bs := os.Getenv("OCEANBASE_BATCH_SIZE"); bs != "" { + if bsInt, err := strconv.Atoi(bs); err == nil { + batchSize = bsInt + } + } + + enableCache := true + if ec := os.Getenv("OCEANBASE_ENABLE_CACHE"); ec != "" { + if ecBool, err := strconv.ParseBool(ec); err == nil { + enableCache = ecBool + } + } + + cacheTTL := 300 * time.Second + if ct := os.Getenv("OCEANBASE_CACHE_TTL"); ct != "" { + if ctInt, err := strconv.Atoi(ct); err == nil { + cacheTTL = time.Duration(ctInt) * time.Second + } + } + + maxConnections := 100 + if mc := os.Getenv("OCEANBASE_MAX_CONNECTIONS"); mc != "" { + if mcInt, err := strconv.Atoi(mc); err == nil { + maxConnections = mcInt + } + } + + connTimeout := 30 * time.Second + if ct := os.Getenv("OCEANBASE_CONN_TIMEOUT"); ct != "" { + if ctInt, err := strconv.Atoi(ct); err == nil { + connTimeout = time.Duration(ctInt) * time.Second + } + } + + managerConfig := &searchstoreOceanbase.ManagerConfig{ + Client: client, + Embedding: emb, + BatchSize: batchSize, + EnableCache: enableCache, + CacheTTL: cacheTTL, + MaxConnections: maxConnections, + ConnTimeout: connTimeout, + } + mgr, err := searchstoreOceanbase.NewManager(managerConfig) + if err != nil { + return nil, fmt.Errorf("init oceanbase vector store failed, err=%w", err) + } + return mgr, nil + + default: + return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType) + } +} + +func getEmbedding(ctx context.Context) (embedding.Embedder, error) { + var batchSize int + if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil { + logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100") + batchSize = 100 + } else { + batchSize = int(bs) + } + + var emb embedding.Embedder + + switch os.Getenv("EMBEDDING_TYPE") { + case "openai": + var ( + openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL") + openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL") + openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY") + openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE") + openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION") + openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS") + openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS") + ) + + byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure) + if err != nil { + return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err) + } + + dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err) + } + + openAICfg := &openai.EmbeddingConfig{ + APIKey: openAIEmbeddingApiKey, + ByAzure: byAzure, + BaseURL: openAIEmbeddingBaseURL, + APIVersion: openAIEmbeddingApiVersion, + Model: openAIEmbeddingModel, + // Dimensions: ptr.Of(int(dims)), + } + reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0) + if reqDims > 0 { + // some openai model not support request dims + openAICfg.Dimensions = ptr.Of(int(reqDims)) + } + + emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init openai embedding failed, err=%w", err) + } + + case "ark": + var ( + arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL") + arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL") + arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY") + // deprecated: use ARK_EMBEDDING_API_KEY instead + // ARK_EMBEDDING_AK will be removed in the future + arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK") + arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS") + arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE") + ) + + dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err) + } + + apiType := ark.APITypeText + if arkEmbeddingAPIType != "" { + if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal { + return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t) + } else { + apiType = t + } + } + + emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{ + APIKey: func() string { + if arkEmbeddingApiKey != "" { + return arkEmbeddingApiKey + } + return arkEmbeddingAK + }(), + Model: arkEmbeddingModel, + BaseURL: arkEmbeddingBaseURL, + APIType: &apiType, + }, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init ark embedding client failed, err=%w", err) + } + + case "ollama": + var ( + ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL") + ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL") + ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS") + ) + + dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err) + } + + emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{ + BaseURL: ollamaEmbeddingBaseURL, + Model: ollamaEmbeddingModel, + }, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init ollama embedding failed, err=%w", err) + } + case "gemini": + var ( + geminiEmbeddingBaseURL = os.Getenv("GEMINI_EMBEDDING_BASE_URL") + geminiEmbeddingModel = os.Getenv("GEMINI_EMBEDDING_MODEL") + geminiEmbeddingApiKey = os.Getenv("GEMINI_EMBEDDING_API_KEY") + geminiEmbeddingDims = os.Getenv("GEMINI_EMBEDDING_DIMS") + geminiEmbeddingBackend = os.Getenv("GEMINI_EMBEDDING_BACKEND") // "1" for BackendGeminiAPI / "2" for BackendVertexAI + geminiEmbeddingProject = os.Getenv("GEMINI_EMBEDDING_PROJECT") + geminiEmbeddingLocation = os.Getenv("GEMINI_EMBEDDING_LOCATION") + ) + + if len(geminiEmbeddingModel) == 0 { + return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required") + } + if len(geminiEmbeddingApiKey) == 0 { + return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required") + } + if len(geminiEmbeddingDims) == 0 { + return nil, fmt.Errorf("GEMINI_EMBEDDING_DIMS environment variable is required") + } + if len(geminiEmbeddingBackend) == 0 { + return nil, fmt.Errorf("GEMINI_EMBEDDING_BACKEND environment variable is required") + } + + dims, convErr := strconv.ParseInt(geminiEmbeddingDims, 10, 64) + if convErr != nil { + return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_DIMS value: %s, err=%w", geminiEmbeddingDims, convErr) + } + + backend, convErr := strconv.ParseInt(geminiEmbeddingBackend, 10, 64) + if convErr != nil { + return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_BACKEND value: %s, err=%w", geminiEmbeddingBackend, convErr) + } + + geminiCli, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: geminiEmbeddingApiKey, + Backend: genai.Backend(backend), + Project: geminiEmbeddingProject, + Location: geminiEmbeddingLocation, + HTTPOptions: genai.HTTPOptions{ + BaseURL: geminiEmbeddingBaseURL, + }, + }) + if err != nil { + return nil, fmt.Errorf("init gemini client failed, err=%w", err) + } + + emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{ + Client: geminiCli, + Model: geminiEmbeddingModel, + OutputDimensionality: ptr.Of(int32(dims)), + }, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init gemini embedding failed, err=%w", err) + } + case "http": + var ( + httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR") + httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS") + ) + dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init http embedding dims failed, err=%w", err) + } + emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init http embedding failed, err=%w", err) + } + + default: + return nil, fmt.Errorf("init knowledge embedding failed, type not configured") + } + + return emb, nil +} diff --git a/backend/infra/eventbus/eventbus.go b/backend/infra/eventbus/eventbus.go index e966baf84..1d1e284ad 100644 --- a/backend/infra/eventbus/eventbus.go +++ b/backend/infra/eventbus/eventbus.go @@ -18,7 +18,7 @@ package eventbus import "context" -//go:generate mockgen -destination ../../../internal/mock/infra/eventbus/eventbus_mock.go -package mock -source eventbus.go Factory +//go:generate mockgen -destination ../../internal/mock/infra/eventbus/eventbus_mock.go -package mock -source eventbus.go Factory type Producer interface { Send(ctx context.Context, body []byte, opts ...SendOpt) error BatchSend(ctx context.Context, bodyArr [][]byte, opts ...SendOpt) error diff --git a/backend/infra/eventbus/impl/eventbus.go b/backend/infra/eventbus/impl/eventbus.go index 3bc7cd38a..5061c79b7 100644 --- a/backend/infra/eventbus/impl/eventbus.go +++ b/backend/infra/eventbus/impl/eventbus.go @@ -77,3 +77,35 @@ func NewProducer(nameServer, topic, group string, retries int) (eventbus.Produce return nil, fmt.Errorf("invalid mq type: %s , only support nsq, kafka, rmq, pulsar", tp) } + +func InitResourceEventBusProducer() (eventbus.Producer, error) { + nameServer := os.Getenv(consts.MQServer) + resourceEventBusProducer, err := NewProducer(nameServer, + consts.RMQTopicResource, consts.RMQConsumeGroupResource, 1) + if err != nil { + return nil, fmt.Errorf("init resource producer failed, err=%w", err) + } + + return resourceEventBusProducer, nil +} + +func InitAppEventProducer() (eventbus.Producer, error) { + nameServer := os.Getenv(consts.MQServer) + appEventProducer, err := NewProducer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, 1) + if err != nil { + return nil, fmt.Errorf("init app producer failed, err=%w", err) + } + + return appEventProducer, nil +} + +func InitKnowledgeEventBusProducer() (eventbus.Producer, error) { + nameServer := os.Getenv(consts.MQServer) + + knowledgeProducer, err := NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2) + if err != nil { + return nil, fmt.Errorf("init knowledge producer failed, err=%w", err) + } + + return knowledgeProducer, nil +} diff --git a/backend/infra/imagex/impl/veimagex/veimagex.go b/backend/infra/imagex/impl/veimagex/veimagex.go index c3b613e4a..4fc83edbc 100644 --- a/backend/infra/imagex/impl/veimagex/veimagex.go +++ b/backend/infra/imagex/impl/veimagex/veimagex.go @@ -19,6 +19,7 @@ package veimagex import ( "context" "errors" + "os" "time" "github.com/volcengine/volc-sdk-golang/base" @@ -26,8 +27,20 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/imagex" "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/types/consts" ) +func NewDefault() (imagex.ImageX, error) { + return New( + os.Getenv(consts.VeImageXAK), + os.Getenv(consts.VeImageXSK), + os.Getenv(consts.VeImageXDomain), + os.Getenv(consts.VeImageXUploadHost), + os.Getenv(consts.VeImageXTemplate), + []string{os.Getenv(consts.VeImageXServerID)}, + ) +} + func New(ak, sk, domain, uploadHost, template string, serverIDs []string) (imagex.ImageX, error) { instance := veimagex.DefaultInstance instance.SetCredential(base.Credentials{ diff --git a/backend/infra/rdb/impl/rdb/mysql.go b/backend/infra/rdb/impl/rdb/mysql.go index 44cb1ab1b..fd5b03d78 100644 --- a/backend/infra/rdb/impl/rdb/mysql.go +++ b/backend/infra/rdb/impl/rdb/mysql.go @@ -27,8 +27,7 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/idgen" "github.com/coze-dev/coze-studio/backend/infra/rdb" "github.com/coze-dev/coze-studio/backend/infra/rdb/entity" - sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/sqlparser" - "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" "github.com/coze-dev/coze-studio/backend/pkg/logs" ) @@ -657,12 +656,12 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques } } - operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL) + operation, err := sqlparser.New().GetSQLOperation(processedSQL) if err != nil { return nil, err } - if operation != sqlparsercontract.OperationTypeSelect { + if operation != sqlparser.OperationTypeSelect { result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...) if result.Error != nil { return nil, fmt.Errorf("failed to execute SQL: %v", result.Error) diff --git a/backend/infra/rdb/impl/rdb/mysql_test.go b/backend/infra/rdb/impl/rdb/mysql_test.go index 094c1a025..d45d14f71 100644 --- a/backend/infra/rdb/impl/rdb/mysql_test.go +++ b/backend/infra/rdb/impl/rdb/mysql_test.go @@ -30,6 +30,8 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/rdb" entity2 "github.com/coze-dev/coze-studio/backend/infra/rdb/entity" + "github.com/coze-dev/coze-studio/backend/infra/sqlparser" + sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser" mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/idgen" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) @@ -515,6 +517,8 @@ func TestSelectData(t *testing.T) { func TestExecuteSQL(t *testing.T) { t.Run("success", func(t *testing.T) { + sqlparser.New = sqlparserImpl.NewSQLParser + db, svc := setupTestDB(t) defer cleanupTestDB(t, db, "test_sql_table") diff --git a/backend/infra/sqlparser/sql_parser.go b/backend/infra/sqlparser/sql_parser.go index 796c1c250..3a7ebc4d3 100644 --- a/backend/infra/sqlparser/sql_parser.go +++ b/backend/infra/sqlparser/sql_parser.go @@ -78,3 +78,5 @@ type SQLParser interface { // AddSelectFieldsToSelectSQL add select fields to select sql AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error) } + +var New func() SQLParser diff --git a/backend/infra/storage/impl/minio/minio.go b/backend/infra/storage/impl/minio/minio.go index 8295c50d0..d18c50689 100644 --- a/backend/infra/storage/impl/minio/minio.go +++ b/backend/infra/storage/impl/minio/minio.go @@ -50,7 +50,7 @@ func New(ctx context.Context, endpoint, accessKeyID, secretAccessKey, bucketName return m, nil } -func getMinioClient(_ context.Context, endpoint, accessKeyID, secretAccessKey, bucketName string, useSSL bool) (*minioClient, error) { +func getMinioClient(ctx context.Context, endpoint, accessKeyID, secretAccessKey, bucketName string, useSSL bool) (*minioClient, error) { client, err := minio.New(endpoint, &minio.Options{ Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""), Secure: useSSL, @@ -67,7 +67,7 @@ func getMinioClient(_ context.Context, endpoint, accessKeyID, secretAccessKey, b endpoint: endpoint, } - err = m.createBucketIfNeed(context.Background(), client, bucketName, "cn-north-1") + err = m.createBucketIfNeed(ctx, client, bucketName, "cn-north-1") if err != nil { return nil, fmt.Errorf("init minio client failed %v", err) } diff --git a/backend/infra/storage/impl/s3/s3.go b/backend/infra/storage/impl/s3/s3.go index bb0ba2529..d4109d887 100644 --- a/backend/infra/storage/impl/s3/s3.go +++ b/backend/infra/storage/impl/s3/s3.go @@ -62,7 +62,7 @@ func getS3Client(ctx context.Context, ak, sk, bucketName, endpoint, region strin }, nil }) cfg, err := config.LoadDefaultConfig( - context.TODO(), + ctx, config.WithCredentialsProvider(creds), config.WithEndpointResolverWithOptions(customResolver), config.WithRegion("auto"), diff --git a/backend/pkg/jsoncache/jsoncache.go b/backend/pkg/jsoncache/jsoncache.go index 63e8c6a9d..85e1aae48 100644 --- a/backend/pkg/jsoncache/jsoncache.go +++ b/backend/pkg/jsoncache/jsoncache.go @@ -74,7 +74,8 @@ func (g *JsonCache[T]) Get(ctx context.Context, k string) (*T, error) { } func (g *JsonCache[T]) Delete(ctx context.Context, k string) error { - if err := g.cache.Del(ctx, k).Err(); err != nil { + key := g.prefix + k + if err := g.cache.Del(ctx, key).Err(); err != nil { return fmt.Errorf("failed to delete key %s: %w", k, err) } return nil