refactor: optimize app infra component initialization (#2294)
This commit is contained in:
@ -83,7 +83,6 @@ import (
|
||||
pluginmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/model"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/pluginmock"
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
|
||||
agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
conventity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
@ -3837,7 +3836,7 @@ func TestNodeDebugLoop(t *testing.T) {
|
||||
}, nil
|
||||
}).AnyTimes()
|
||||
|
||||
code.SetCodeRunner(runner)
|
||||
coderunner.SetCodeRunner(runner)
|
||||
id := r.load("loop_with_object_input.json")
|
||||
exeID := r.nodeDebug(id, "122149",
|
||||
withNDInput(map[string]string{"input": `[{"a":"1"},{"a":"2"}]`}))
|
||||
@ -4154,7 +4153,7 @@ func TestCodeExceptionBranch(t *testing.T) {
|
||||
id := r.load("exception/code_exception_branch.json")
|
||||
|
||||
mockey.PatchConvey("exception branch", func() {
|
||||
code.SetCodeRunner(direct.NewRunner())
|
||||
coderunner.SetCodeRunner(direct.NewRunner())
|
||||
|
||||
exeID := r.testRun(id, map[string]string{"input": "hello"})
|
||||
e := r.getProcess(id, exeID)
|
||||
@ -4167,7 +4166,7 @@ func TestCodeExceptionBranch(t *testing.T) {
|
||||
|
||||
mockey.PatchConvey("normal branch", func() {
|
||||
mockCodeRunner := mockcode.NewMockRunner(r.ctrl)
|
||||
mockey.Mock(code.GetCodeRunner).Return(mockCodeRunner).Build()
|
||||
mockey.Mock(coderunner.GetCodeRunner).Return(mockCodeRunner).Build()
|
||||
mockCodeRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(&coderunner.RunResponse{
|
||||
Result: map[string]any{
|
||||
"key0": "value0",
|
||||
@ -4891,7 +4890,7 @@ func TestHttpImplicitDependencies(t *testing.T) {
|
||||
}, nil
|
||||
}).AnyTimes()
|
||||
|
||||
code.SetCodeRunner(runner)
|
||||
coderunner.SetCodeRunner(runner)
|
||||
|
||||
mockey.PatchConvey("test http node implicit dependencies", func() {
|
||||
input := map[string]string{
|
||||
@ -6059,7 +6058,7 @@ func TestWorkflowRunWithFiles(t *testing.T) {
|
||||
}, nil
|
||||
}).AnyTimes()
|
||||
|
||||
mockey.Mock(code.GetCodeRunner).Return(runner).Build()
|
||||
mockey.Mock(coderunner.GetCodeRunner).Return(runner).Build()
|
||||
|
||||
idStr := r.load("workflow_wf_file_name.json")
|
||||
r.publish(idStr, "v0.1.1", true)
|
||||
|
||||
@ -70,8 +70,12 @@ import (
|
||||
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel/impl/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/checkpoint"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/progressbar"
|
||||
progressBarImpl "github.com/coze-dev/coze-studio/backend/infra/document/progressbar/impl/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus"
|
||||
implEventbus "github.com/coze-dev/coze-studio/backend/infra/eventbus/impl"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
)
|
||||
|
||||
type eventbusImpl struct {
|
||||
@ -116,6 +120,9 @@ func Init(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
progressbar.New = progressBarImpl.NewProgressBar
|
||||
sqlparser.New = sqlparserImpl.NewSQLParser
|
||||
|
||||
eventbus := initEventBus(infra)
|
||||
|
||||
basicServices, err := initBasicServices(ctx, infra, eventbus)
|
||||
|
||||
@ -18,65 +18,30 @@ package appinfra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/genai"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/cache/impl/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/direct"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/sandbox"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/nl2sql"
|
||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql/impl/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/ppocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/veocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/ppstructure"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/rrf"
|
||||
vikingReranker "github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/vikingdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/elasticsearch"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/milvus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/vikingdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
|
||||
embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
|
||||
coderunner "github.com/coze-dev/coze-studio/backend/infra/coderunner/impl"
|
||||
messages2query "github.com/coze-dev/coze-studio/backend/infra/document/messages2query/impl"
|
||||
nl2sql "github.com/coze-dev/coze-studio/backend/infra/document/nl2sql/impl"
|
||||
ocr "github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl"
|
||||
parser "github.com/coze-dev/coze-studio/backend/infra/document/parser/impl"
|
||||
rerank "github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl"
|
||||
searchstore "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/es/impl/es"
|
||||
eventbus "github.com/coze-dev/coze-studio/backend/infra/eventbus/impl"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/idgen/impl/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/imagex/impl/veimagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/messages2query"
|
||||
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/messages2query/impl/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/modelmgr"
|
||||
oceanbaseClient "github.com/coze-dev/coze-studio/backend/infra/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/orm/impl/mysql"
|
||||
storage "github.com/coze-dev/coze-studio/backend/infra/storage/impl"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
@ -132,29 +97,29 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
return nil, fmt.Errorf("init imagex client failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.ResourceEventProducer, err = initResourceEventBusProducer()
|
||||
deps.ResourceEventProducer, err = eventbus.InitResourceEventBusProducer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init resource event bus producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.AppEventProducer, err = initAppEventProducer()
|
||||
deps.AppEventProducer, err = eventbus.InitAppEventProducer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init app event producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.KnowledgeEventProducer, err = initKnowledgeEventBusProducer()
|
||||
deps.KnowledgeEventProducer, err = eventbus.InitKnowledgeEventBusProducer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.Reranker = initReranker()
|
||||
deps.Reranker = rerank.New()
|
||||
|
||||
deps.Rewriter, err = initRewriter(ctx)
|
||||
deps.Rewriter, err = messages2query.New(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init rewriter failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.NL2SQL, err = initNL2SQL(ctx)
|
||||
deps.NL2SQL, err = nl2sql.New(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init nl2sql failed, err=%w", err)
|
||||
}
|
||||
@ -164,17 +129,12 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
return nil, fmt.Errorf("init model manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.CodeRunner = initCodeRunner()
|
||||
deps.CodeRunner = coderunner.New()
|
||||
|
||||
deps.OCR = initOCR()
|
||||
|
||||
imageAnnotationModel, _, err := getBuiltinChatModel(ctx, "IA_")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get builtin chat model failed, err=%w", err)
|
||||
}
|
||||
deps.OCR = ocr.New()
|
||||
|
||||
var ok bool
|
||||
deps.WorkflowBuildInChatModel, ok, err = getBuiltinChatModel(ctx, "WKR_")
|
||||
deps.WorkflowBuildInChatModel, ok, err = buildinmodel.GetBuiltinChatModel(ctx, "WKR_")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get workflow builtin chat model failed, err=%w", err)
|
||||
}
|
||||
@ -183,12 +143,12 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
|
||||
deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
|
||||
deps.ParserManager, err = parser.New(ctx, deps.TOSClient, deps.OCR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init parser manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient)
|
||||
deps.SearchStoreManagers, err = searchstore.New(ctx, deps.ESClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init search store managers failed, err=%w", err)
|
||||
}
|
||||
@ -196,602 +156,11 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
return deps, nil
|
||||
}
|
||||
|
||||
func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.Manager, error) {
|
||||
// es full text search
|
||||
esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es})
|
||||
|
||||
// vector search
|
||||
mgr, err := getVectorStore(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
||||
}
|
||||
|
||||
func initReranker() rerank.Reranker {
|
||||
rerankerType := os.Getenv("RERANK_TYPE")
|
||||
switch rerankerType {
|
||||
case "vikingdb":
|
||||
return vikingReranker.NewReranker(getVikingRerankerConfig())
|
||||
case "rrf":
|
||||
return rrf.NewRRFReranker(0)
|
||||
default:
|
||||
return rrf.NewRRFReranker(0)
|
||||
}
|
||||
}
|
||||
func getVikingRerankerConfig() *vikingReranker.Config {
|
||||
return &vikingReranker.Config{
|
||||
AK: os.Getenv("VIKINGDB_RERANK_AK"),
|
||||
SK: os.Getenv("VIKINGDB_RERANK_SK"),
|
||||
Domain: os.Getenv("VIKINGDB_RERANK_HOST"),
|
||||
Region: os.Getenv("VIKINGDB_RERANK_REGION"),
|
||||
Model: os.Getenv("VIKINGDB_RERANK_MODEL"),
|
||||
}
|
||||
}
|
||||
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
|
||||
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/messages_to_query_template_jinja2.json")
|
||||
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rewriter, err := builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rewriter, nil
|
||||
}
|
||||
|
||||
func getWorkingDirectory() string {
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
|
||||
root = os.Getenv("PWD")
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
|
||||
b, err := os.ReadFile(jsonFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m2qMessages []*schema.Message
|
||||
if err = json.Unmarshal(b, &m2qMessages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
|
||||
for i := range m2qMessages {
|
||||
tpl[i] = m2qMessages[i]
|
||||
}
|
||||
return prompt.FromMessages(schema.Jinja2, tpl...), nil
|
||||
}
|
||||
|
||||
func initNL2SQL(ctx context.Context) (nl2sql.NL2SQL, error) {
|
||||
n2sChatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/nl2sql_template_jinja2.json")
|
||||
n2sTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n2s, err := builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return n2s, nil
|
||||
}
|
||||
|
||||
func initImageX(ctx context.Context) (imagex.ImageX, error) {
|
||||
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
|
||||
if uploadComponentType != consts.FileUploadComponentTypeImagex {
|
||||
return storage.NewImagex(ctx)
|
||||
}
|
||||
return veimagex.New(
|
||||
os.Getenv(consts.VeImageXAK),
|
||||
os.Getenv(consts.VeImageXSK),
|
||||
os.Getenv(consts.VeImageXDomain),
|
||||
os.Getenv(consts.VeImageXUploadHost),
|
||||
os.Getenv(consts.VeImageXTemplate),
|
||||
[]string{os.Getenv(consts.VeImageXServerID)},
|
||||
)
|
||||
}
|
||||
|
||||
func initResourceEventBusProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
resourceEventBusProducer, err := eventbus.NewProducer(nameServer,
|
||||
consts.RMQTopicResource, consts.RMQConsumeGroupResource, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init resource producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return resourceEventBusProducer, nil
|
||||
}
|
||||
|
||||
func initAppEventProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
appEventProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init app producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return appEventProducer, nil
|
||||
}
|
||||
|
||||
func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
|
||||
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return knowledgeProducer, nil
|
||||
}
|
||||
|
||||
func initCodeRunner() coderunner.Runner {
|
||||
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
||||
case "sandbox":
|
||||
getAndSplit := func(key string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(v, ",")
|
||||
}
|
||||
config := &sandbox.Config{
|
||||
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv),
|
||||
AllowRead: getAndSplit(consts.CodeRunnerAllowRead),
|
||||
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite),
|
||||
AllowNet: getAndSplit(consts.CodeRunnerAllowNet),
|
||||
AllowRun: getAndSplit(consts.CodeRunnerAllowRun),
|
||||
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI),
|
||||
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
||||
TimeoutSeconds: 0,
|
||||
MemoryLimitMB: 0,
|
||||
}
|
||||
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil {
|
||||
config.TimeoutSeconds = f
|
||||
} else {
|
||||
config.TimeoutSeconds = 60.0
|
||||
}
|
||||
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil {
|
||||
config.MemoryLimitMB = mem
|
||||
} else {
|
||||
config.MemoryLimitMB = 100
|
||||
}
|
||||
return sandbox.NewRunner(config)
|
||||
default:
|
||||
return direct.NewRunner()
|
||||
}
|
||||
}
|
||||
|
||||
func initOCR() ocr.OCR {
|
||||
var ocr ocr.OCR
|
||||
switch os.Getenv(consts.OCRType) {
|
||||
case "ve":
|
||||
ocrAK := os.Getenv(consts.VeOCRAK)
|
||||
ocrSK := os.Getenv(consts.VeOCRSK)
|
||||
if ocrAK == "" || ocrSK == "" {
|
||||
logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well")
|
||||
}
|
||||
inst := visual.NewInstance()
|
||||
inst.Client.SetAccessKey(ocrAK)
|
||||
inst.Client.SetSecretKey(ocrSK)
|
||||
ocr = veocr.NewOCR(&veocr.Config{Client: inst})
|
||||
case "paddleocr":
|
||||
url := os.Getenv(consts.PPOCRAPIURL)
|
||||
client := &http.Client{}
|
||||
ocr = ppocr.NewOCR(&ppocr.Config{Client: client, URL: url})
|
||||
default:
|
||||
// accept ocr not configured
|
||||
}
|
||||
|
||||
return ocr
|
||||
}
|
||||
|
||||
func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) (parser.Manager, error) {
|
||||
var parserManager parser.Manager
|
||||
parserType := os.Getenv(consts.ParserType)
|
||||
switch parserType {
|
||||
case "builtin", "":
|
||||
parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel)
|
||||
case "paddleocr":
|
||||
url := os.Getenv(consts.PPStructureAPIURL)
|
||||
client := &http.Client{}
|
||||
apiConfig := &ppstructure.APIConfig{
|
||||
Client: client,
|
||||
URL: url,
|
||||
}
|
||||
parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel)
|
||||
default:
|
||||
return nil, fmt.Errorf("parser type %s not supported", parserType)
|
||||
}
|
||||
|
||||
return parserManager, nil
|
||||
}
|
||||
|
||||
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
|
||||
vsType := os.Getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
switch vsType {
|
||||
case "milvus":
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
milvusAddr = os.Getenv("MILVUS_ADDR")
|
||||
user = os.Getenv("MILVUS_USER")
|
||||
password = os.Getenv("MILVUS_PASSWORD")
|
||||
milvusToken = os.Getenv("MILVUS_TOKEN")
|
||||
)
|
||||
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
|
||||
Address: milvusAddr,
|
||||
Username: user,
|
||||
Password: password,
|
||||
APIKey: milvusToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
mgr, err := milvus.NewManager(&milvus.ManagerConfig{
|
||||
Client: mc,
|
||||
Embedding: emb,
|
||||
EnableHybrid: ptr.Of(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
case "vikingdb":
|
||||
var (
|
||||
host = os.Getenv("VIKING_DB_HOST")
|
||||
region = os.Getenv("VIKING_DB_REGION")
|
||||
ak = os.Getenv("VIKING_DB_AK")
|
||||
sk = os.Getenv("VIKING_DB_SK")
|
||||
scheme = os.Getenv("VIKING_DB_SCHEME")
|
||||
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
|
||||
)
|
||||
if ak == "" || sk == "" {
|
||||
return nil, fmt.Errorf("invalid vikingdb ak / sk")
|
||||
}
|
||||
if host == "" {
|
||||
host = "api-vikingdb.volces.com"
|
||||
}
|
||||
if region == "" {
|
||||
region = "cn-beijing"
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
var embConfig *vikingdb.VikingEmbeddingConfig
|
||||
if modelName != "" {
|
||||
embName := vikingdb.VikingEmbeddingModelName(modelName)
|
||||
if embName.Dimensions() == 0 {
|
||||
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
|
||||
}
|
||||
embConfig = &vikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: true,
|
||||
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
|
||||
ModelName: embName,
|
||||
ModelVersion: embName.ModelVersion(),
|
||||
DenseWeight: ptr.Of(0.2),
|
||||
BuiltinEmbedding: nil,
|
||||
}
|
||||
} else {
|
||||
builtinEmbedding, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
|
||||
}
|
||||
|
||||
embConfig = &vikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: false,
|
||||
EnableHybrid: false,
|
||||
BuiltinEmbedding: builtinEmbedding,
|
||||
}
|
||||
}
|
||||
|
||||
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
|
||||
mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{
|
||||
Service: svc,
|
||||
IndexingConfig: nil, // use default config
|
||||
EmbeddingConfig: embConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
|
||||
case "oceanbase":
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
host = os.Getenv("OCEANBASE_HOST")
|
||||
port = os.Getenv("OCEANBASE_PORT")
|
||||
user = os.Getenv("OCEANBASE_USER")
|
||||
password = os.Getenv("OCEANBASE_PASSWORD")
|
||||
database = os.Getenv("OCEANBASE_DATABASE")
|
||||
)
|
||||
if host == "" || port == "" || user == "" || password == "" || database == "" {
|
||||
return nil, fmt.Errorf("invalid oceanbase configuration: host, port, user, password, database are required")
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
user, password, host, port, database)
|
||||
|
||||
client, err := oceanbaseClient.NewOceanBaseClient(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase client failed, err=%w", err)
|
||||
}
|
||||
|
||||
if err := client.InitDatabase(ctx); err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase database failed, err=%w", err)
|
||||
}
|
||||
|
||||
// Get configuration from environment variables with defaults
|
||||
batchSize := 100
|
||||
if bs := os.Getenv("OCEANBASE_BATCH_SIZE"); bs != "" {
|
||||
if bsInt, err := strconv.Atoi(bs); err == nil {
|
||||
batchSize = bsInt
|
||||
}
|
||||
}
|
||||
|
||||
enableCache := true
|
||||
if ec := os.Getenv("OCEANBASE_ENABLE_CACHE"); ec != "" {
|
||||
if ecBool, err := strconv.ParseBool(ec); err == nil {
|
||||
enableCache = ecBool
|
||||
}
|
||||
}
|
||||
|
||||
cacheTTL := 300 * time.Second
|
||||
if ct := os.Getenv("OCEANBASE_CACHE_TTL"); ct != "" {
|
||||
if ctInt, err := strconv.Atoi(ct); err == nil {
|
||||
cacheTTL = time.Duration(ctInt) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
maxConnections := 100
|
||||
if mc := os.Getenv("OCEANBASE_MAX_CONNECTIONS"); mc != "" {
|
||||
if mcInt, err := strconv.Atoi(mc); err == nil {
|
||||
maxConnections = mcInt
|
||||
}
|
||||
}
|
||||
|
||||
connTimeout := 30 * time.Second
|
||||
if ct := os.Getenv("OCEANBASE_CONN_TIMEOUT"); ct != "" {
|
||||
if ctInt, err := strconv.Atoi(ct); err == nil {
|
||||
connTimeout = time.Duration(ctInt) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
managerConfig := &oceanbase.ManagerConfig{
|
||||
Client: client,
|
||||
Embedding: emb,
|
||||
BatchSize: batchSize,
|
||||
EnableCache: enableCache,
|
||||
CacheTTL: cacheTTL,
|
||||
MaxConnections: maxConnections,
|
||||
ConnTimeout: connTimeout,
|
||||
}
|
||||
mgr, err := oceanbase.NewManager(managerConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase vector store failed, err=%w", err)
|
||||
}
|
||||
return mgr, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
|
||||
}
|
||||
}
|
||||
|
||||
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
var batchSize int
|
||||
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
|
||||
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
|
||||
batchSize = 100
|
||||
} else {
|
||||
batchSize = int(bs)
|
||||
}
|
||||
|
||||
var emb embedding.Embedder
|
||||
|
||||
switch os.Getenv("EMBEDDING_TYPE") {
|
||||
case "openai":
|
||||
var (
|
||||
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
|
||||
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
|
||||
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
|
||||
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
|
||||
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
|
||||
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
|
||||
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
|
||||
)
|
||||
|
||||
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
|
||||
}
|
||||
|
||||
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
openAICfg := &openai.EmbeddingConfig{
|
||||
APIKey: openAIEmbeddingApiKey,
|
||||
ByAzure: byAzure,
|
||||
BaseURL: openAIEmbeddingBaseURL,
|
||||
APIVersion: openAIEmbeddingApiVersion,
|
||||
Model: openAIEmbeddingModel,
|
||||
// Dimensions: ptr.Of(int(dims)),
|
||||
}
|
||||
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
|
||||
if reqDims > 0 {
|
||||
// some openai model not support request dims
|
||||
openAICfg.Dimensions = ptr.Of(int(reqDims))
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ark":
|
||||
var (
|
||||
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
|
||||
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
|
||||
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
|
||||
// deprecated: use ARK_EMBEDDING_API_KEY instead
|
||||
// ARK_EMBEDDING_AK will be removed in the future
|
||||
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
|
||||
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
|
||||
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
apiType := ark.APITypeText
|
||||
if arkEmbeddingAPIType != "" {
|
||||
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
|
||||
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
|
||||
} else {
|
||||
apiType = t
|
||||
}
|
||||
}
|
||||
|
||||
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
|
||||
APIKey: func() string {
|
||||
if arkEmbeddingApiKey != "" {
|
||||
return arkEmbeddingApiKey
|
||||
}
|
||||
return arkEmbeddingAK
|
||||
}(),
|
||||
Model: arkEmbeddingModel,
|
||||
BaseURL: arkEmbeddingBaseURL,
|
||||
APIType: &apiType,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ollama":
|
||||
var (
|
||||
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
|
||||
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
|
||||
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
|
||||
BaseURL: ollamaEmbeddingBaseURL,
|
||||
Model: ollamaEmbeddingModel,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
|
||||
}
|
||||
case "gemini":
|
||||
var (
|
||||
geminiEmbeddingBaseURL = os.Getenv("GEMINI_EMBEDDING_BASE_URL")
|
||||
geminiEmbeddingModel = os.Getenv("GEMINI_EMBEDDING_MODEL")
|
||||
geminiEmbeddingApiKey = os.Getenv("GEMINI_EMBEDDING_API_KEY")
|
||||
geminiEmbeddingDims = os.Getenv("GEMINI_EMBEDDING_DIMS")
|
||||
geminiEmbeddingBackend = os.Getenv("GEMINI_EMBEDDING_BACKEND") // "1" for BackendGeminiAPI / "2" for BackendVertexAI
|
||||
geminiEmbeddingProject = os.Getenv("GEMINI_EMBEDDING_PROJECT")
|
||||
geminiEmbeddingLocation = os.Getenv("GEMINI_EMBEDDING_LOCATION")
|
||||
)
|
||||
|
||||
if len(geminiEmbeddingModel) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingApiKey) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingDims) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_DIMS environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingBackend) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_BACKEND environment variable is required")
|
||||
}
|
||||
|
||||
dims, convErr := strconv.ParseInt(geminiEmbeddingDims, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_DIMS value: %s, err=%w", geminiEmbeddingDims, convErr)
|
||||
}
|
||||
|
||||
backend, convErr := strconv.ParseInt(geminiEmbeddingBackend, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_BACKEND value: %s, err=%w", geminiEmbeddingBackend, convErr)
|
||||
}
|
||||
|
||||
geminiCli, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: geminiEmbeddingApiKey,
|
||||
Backend: genai.Backend(backend),
|
||||
Project: geminiEmbeddingProject,
|
||||
Location: geminiEmbeddingLocation,
|
||||
HTTPOptions: genai.HTTPOptions{
|
||||
BaseURL: geminiEmbeddingBaseURL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
|
||||
Client: geminiCli,
|
||||
Model: geminiEmbeddingModel,
|
||||
OutputDimensionality: ptr.Of(int32(dims)),
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
|
||||
}
|
||||
case "http":
|
||||
var (
|
||||
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
|
||||
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
|
||||
)
|
||||
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
|
||||
}
|
||||
emb, err = embeddingHttp.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
return veimagex.NewDefault()
|
||||
}
|
||||
|
||||
@ -31,6 +31,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
// TODO(fanlv): move model management into Infra
|
||||
|
||||
func initModelMgr() (modelmgr.Manager, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
|
||||
@ -28,8 +28,6 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
@ -40,7 +38,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
workflowservice "github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner"
|
||||
@ -101,8 +98,8 @@ func InitService(_ context.Context, components *ServiceComponents) (*Application
|
||||
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
|
||||
wrapPlugin.SetOSS(components.Tos)
|
||||
|
||||
code.SetCodeRunner(components.CodeRunner)
|
||||
callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler())
|
||||
coderunner.SetCodeRunner(components.CodeRunner)
|
||||
callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler())
|
||||
|
||||
setEventBus(components.DomainNotifier)
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package appinfra
|
||||
package buildinmodel
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -22,18 +22,18 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
ao "github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino-ext/components/model/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/model/ollama"
|
||||
mo "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
)
|
||||
|
||||
func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
|
||||
func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
|
||||
getEnv := func(key string) string {
|
||||
if val := os.Getenv(envPrefix + key); val != "" {
|
||||
return val
|
||||
@ -44,14 +44,14 @@ func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.B
|
||||
switch getEnv("BUILTIN_CM_TYPE") {
|
||||
case "openai":
|
||||
byAzure, _ := strconv.ParseBool(getEnv("BUILTIN_CM_OPENAI_BY_AZURE"))
|
||||
bcm, err = mo.NewChatModel(ctx, &mo.ChatModelConfig{
|
||||
bcm, err = openai.NewChatModel(ctx, &openai.ChatModelConfig{
|
||||
APIKey: getEnv("BUILTIN_CM_OPENAI_API_KEY"),
|
||||
ByAzure: byAzure,
|
||||
BaseURL: getEnv("BUILTIN_CM_OPENAI_BASE_URL"),
|
||||
Model: getEnv("BUILTIN_CM_OPENAI_MODEL"),
|
||||
})
|
||||
case "ark":
|
||||
bcm, err = ao.NewChatModel(ctx, &ao.ChatModelConfig{
|
||||
bcm, err = ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
||||
APIKey: getEnv("BUILTIN_CM_ARK_API_KEY"),
|
||||
Model: getEnv("BUILTIN_CM_ARK_MODEL"),
|
||||
BaseURL: getEnv("BUILTIN_CM_ARK_BASE_URL"),
|
||||
51
backend/bizpkg/fileutil/fileutil.go
Normal file
51
backend/bizpkg/fileutil/fileutil.go
Normal file
@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func GetWorkingDirectory() string {
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
|
||||
root = os.Getenv("PWD")
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func ReadJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
|
||||
b, err := os.ReadFile(jsonFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m2qMessages []*schema.Message
|
||||
if err = json.Unmarshal(b, &m2qMessages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
|
||||
for i := range m2qMessages {
|
||||
tpl[i] = m2qMessages[i]
|
||||
}
|
||||
return prompt.FromMessages(schema.Jinja2, tpl...), nil
|
||||
}
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package goutil
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
@ -1,31 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package code
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner"
|
||||
)
|
||||
|
||||
func GetCodeRunner() coderunner.Runner {
|
||||
return runnerImpl
|
||||
}
|
||||
|
||||
func SetCodeRunner(runner coderunner.Runner) {
|
||||
runnerImpl = runner
|
||||
}
|
||||
|
||||
var runnerImpl coderunner.Runner
|
||||
@ -35,7 +35,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
@ -79,8 +79,7 @@ func (d *databaseTool) Invoke(ctx context.Context, req ExecuteSQLRequest) (strin
|
||||
tableType = table.TableType_DraftTable
|
||||
}
|
||||
|
||||
// TODO(@fanlv): the domain layer must not depend on a concrete impl
|
||||
tableName, err := sqlparser.NewSQLParser().GetTableName(req.SQL)
|
||||
tableName, err := sqlparser.New().GetTableName(req.SQL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@ -35,8 +35,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/events"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document"
|
||||
progressbarContract "github.com/coze-dev/coze-studio/backend/infra/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/progressbar/impl/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
@ -379,9 +378,8 @@ func (k *knowledgeSVC) handleTableDocument(ctx context.Context,
|
||||
func (k *knowledgeSVC) processDocumentChunks(ctx context.Context,
|
||||
doc *entity.Document, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord) error {
|
||||
|
||||
// TODO(@fanlv): the domain layer must not depend on a concrete impl
|
||||
batchSize := 100
|
||||
progressbar := progressbar.NewProgressBar(ctx, doc.ID,
|
||||
progressbar := progressbar.New(ctx, doc.ID,
|
||||
int64(len(parseResult)*len(k.searchStoreManagers)), k.cacheCli, true)
|
||||
|
||||
if err := progressbar.AddN(int(cacheRecord.LastProcessedNumber) * len(k.searchStoreManagers)); err != nil {
|
||||
@ -417,7 +415,7 @@ func (k *knowledgeSVC) finalizeDocumentIndexing(ctx context.Context, knowledgeID
|
||||
// batchProcessSlice processes a batch of document slices
|
||||
func (k *knowledgeSVC) batchProcessSlice(ctx context.Context, doc *entity.Document,
|
||||
startIdx int, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord,
|
||||
progressBar progressbarContract.ProgressBar) error {
|
||||
progressBar progressbar.ProgressBar) error {
|
||||
|
||||
collectionName := getCollectionName(doc.KnowledgeID)
|
||||
length := len(parseResult)
|
||||
@ -640,7 +638,7 @@ func (k *knowledgeSVC) storeSlicesInDB(ctx context.Context, doc *entity.Document
|
||||
// indexSlicesInSearchStores indexes slices in appropriate search stores
|
||||
func (k *knowledgeSVC) indexSlicesInSearchStores(ctx context.Context, doc *entity.Document,
|
||||
collectionName string, sliceEntities []*entity.Slice, cacheRecord *indexDocCacheRecord,
|
||||
progressBar progressbarContract.ProgressBar) error {
|
||||
progressBar progressbar.ProgressBar) error {
|
||||
|
||||
fields, err := k.mapSearchFields(doc)
|
||||
if err != nil {
|
||||
|
||||
@ -47,15 +47,15 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/progressbar/impl/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
rdbEntity "github.com/coze-dev/coze-studio/backend/infra/rdb/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/storage"
|
||||
@ -552,8 +552,7 @@ func (k *knowledgeSVC) MGetDocumentProgress(ctx context.Context, request *MGetDo
|
||||
}
|
||||
|
||||
func (k *knowledgeSVC) getProgressFromCache(ctx context.Context, documentProgress *DocumentProgress) (err error) {
|
||||
// TODO(@fanlv): the domain layer should not depend on impl
|
||||
progressBar := progressbar.NewProgressBar(ctx, documentProgress.ID, 0, k.cacheCli, false)
|
||||
progressBar := progressbar.New(ctx, documentProgress.ID, 0, k.cacheCli, false)
|
||||
percent, remainSec, errMsg := progressBar.GetProgress(ctx)
|
||||
documentProgress.Progress = int(percent)
|
||||
documentProgress.RemainingSec = int64(remainSec)
|
||||
|
||||
@ -36,13 +36,13 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
@ -395,8 +395,7 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||
virtualColumnMap[convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)] = doc.TableInfo.Columns[i]
|
||||
}
|
||||
|
||||
// TODO(@fanlv): the domain layer should not depend on impl
|
||||
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
|
||||
parsedSQL, err := sqlparser.New().ParseAndModifySQL(sql, replaceMap)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
|
||||
return nil, err
|
||||
@ -466,7 +465,7 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||
const pkID = "_knowledge_slice_id"
|
||||
|
||||
func addSliceIdColumn(originalSql string) string {
|
||||
sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID})
|
||||
sql, err := sqlparser.New().AddSelectFieldsToSelectSQL(originalSql, []string{pkID})
|
||||
if err != nil {
|
||||
logs.Errorf("add slice id column failed: %v", err)
|
||||
return originalSql
|
||||
|
||||
@ -36,6 +36,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
rdb_entity "github.com/coze-dev/coze-studio/backend/infra/rdb/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/nl2sql_mock"
|
||||
mock_db "github.com/coze-dev/coze-studio/backend/internal/mock/infra/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
@ -59,6 +61,8 @@ func TestAddSliceIdColumn(t *testing.T) {
|
||||
expected: "SELECT FROM users",
|
||||
},
|
||||
}
|
||||
sqlparser.New = sqlparserImpl.NewSQLParser
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := addSliceIdColumn(tt.input)
|
||||
|
||||
@ -46,8 +46,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
entity3 "github.com/coze-dev/coze-studio/backend/infra/rdb/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
@ -1044,8 +1044,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
|
||||
return nil, fmt.Errorf("SQL is empty")
|
||||
}
|
||||
|
||||
// TODO(@fanlv): the domain layer should not depend on impl
|
||||
operation, err := sqlparser.NewSQLParser().GetSQLOperation(*req.SQL)
|
||||
operation, err := sqlparser.New().GetSQLOperation(*req.SQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1072,7 +1071,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
|
||||
},
|
||||
}
|
||||
|
||||
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(*req.SQL, tableColumnMapping)
|
||||
parsedSQL, err := sqlparser.New().ParseAndModifySQL(*req.SQL, tableColumnMapping)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse sql failed: %v", err)
|
||||
}
|
||||
@ -1080,7 +1079,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
|
||||
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 {
|
||||
switch operation {
|
||||
case sqlparsercontract.OperationTypeSelect, sqlparsercontract.OperationTypeUpdate, sqlparsercontract.OperationTypeDelete:
|
||||
parsedSQL, err = sqlparser.NewSQLParser().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID))
|
||||
parsedSQL, err = sqlparser.New().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("append sql filter failed: %v", err)
|
||||
}
|
||||
@ -1092,7 +1091,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
|
||||
if req.ConnectorID != nil {
|
||||
cid = *req.ConnectorID
|
||||
}
|
||||
nums, err := sqlparser.NewSQLParser().GetInsertDataNums(parsedSQL)
|
||||
nums, err := sqlparser.New().GetInsertDataNums(parsedSQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1114,7 +1113,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
|
||||
for i, id := range ids {
|
||||
iIDs[i] = id
|
||||
}
|
||||
parsedSQL, _, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
|
||||
parsedSQL, _, err = sqlparser.New().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
|
||||
{
|
||||
ColName: database.DefaultCidColName,
|
||||
Value: cid,
|
||||
@ -1128,7 +1127,7 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
|
||||
return nil, fmt.Errorf("add columns to insert sql failed: %v", err)
|
||||
}
|
||||
} else if req.SQLType == database.SQLType_Parameterized {
|
||||
parsedSQL, existingCols, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
|
||||
parsedSQL, existingCols, err = sqlparser.New().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
|
||||
{
|
||||
ColName: database.DefaultCidColName,
|
||||
},
|
||||
|
||||
@ -37,6 +37,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
rdb2 "github.com/coze-dev/coze-studio/backend/infra/rdb/impl/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/idgen"
|
||||
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
@ -44,6 +46,8 @@ import (
|
||||
)
|
||||
|
||||
func setupTestEnv(t *testing.T) (*gorm.DB, rdb.RDB, *mock.MockIDGenerator, repository.DraftDAO, repository.OnlineDAO, Database) {
|
||||
sqlparser.New = sqlparserImpl.NewSQLParser
|
||||
|
||||
dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
if os.Getenv("CI_JOB_NAME") != "" {
|
||||
dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
|
||||
|
||||
@ -45,7 +45,6 @@ import (
|
||||
mockmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr/modelmock"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/pluginmock"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
@ -768,7 +767,7 @@ func TestCodeAndPluginNodes(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCodeRunner := mockcode.NewMockRunner(ctrl)
|
||||
mockey.Mock(code.GetCodeRunner).Return(mockCodeRunner).Build()
|
||||
mockey.Mock(coderunner.GetCodeRunner).Return(mockCodeRunner).Build()
|
||||
|
||||
mockRepo := mockWorkflow.NewMockRepository(ctrl)
|
||||
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
code2 "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
@ -169,7 +168,7 @@ func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.Bui
|
||||
code: c.Code,
|
||||
language: c.Language,
|
||||
outputConfig: ns.OutputTypes,
|
||||
runner: code2.GetCodeRunner(),
|
||||
runner: coderunner.GetCodeRunner(),
|
||||
importError: importErr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
2
backend/infra/cache/impl/redis/redis.go
vendored
2
backend/infra/cache/impl/redis/redis.go
vendored
@ -26,6 +26,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/cache"
|
||||
)
|
||||
|
||||
type Cmdable = cache.Cmdable
|
||||
|
||||
func New() cache.Cmdable {
|
||||
addr := os.Getenv("REDIS_ADDR")
|
||||
password := os.Getenv("REDIS_PASSWORD")
|
||||
|
||||
@ -38,3 +38,13 @@ type RunResponse struct {
|
||||
type Runner interface {
|
||||
Run(ctx context.Context, request *RunRequest) (*RunResponse, error)
|
||||
}
|
||||
|
||||
func GetCodeRunner() Runner {
|
||||
return runnerImpl
|
||||
}
|
||||
|
||||
func SetCodeRunner(runner Runner) {
|
||||
runnerImpl = runner
|
||||
}
|
||||
|
||||
var runnerImpl Runner
|
||||
|
||||
@ -22,8 +22,8 @@ import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
@ -78,7 +78,7 @@ func (r *runner) pythonCmdRun(_ context.Context, code string, params map[string]
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal params to json, err: %w", err)
|
||||
}
|
||||
cmd := exec.Command(goutil.GetPython3Path(), "-c", fmt.Sprintf(pythonCode, code), string(bs)) // ignore_security_alert RCE
|
||||
cmd := exec.Command(fileutil.GetPython3Path(), "-c", fmt.Sprintf(pythonCode, code), string(bs)) // ignore_security_alert RCE
|
||||
stdout := new(bytes.Buffer)
|
||||
stderr := new(bytes.Buffer)
|
||||
cmd.Stdout = stdout
|
||||
|
||||
66
backend/infra/coderunner/impl/impl.go
Normal file
66
backend/infra/coderunner/impl/impl.go
Normal file
@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package impl
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/direct"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner/impl/sandbox"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type Runner = coderunner.Runner
|
||||
|
||||
func New() Runner {
|
||||
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
||||
case "sandbox":
|
||||
getAndSplit := func(key string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(v, ",")
|
||||
}
|
||||
config := &sandbox.Config{
|
||||
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv),
|
||||
AllowRead: getAndSplit(consts.CodeRunnerAllowRead),
|
||||
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite),
|
||||
AllowNet: getAndSplit(consts.CodeRunnerAllowNet),
|
||||
AllowRun: getAndSplit(consts.CodeRunnerAllowRun),
|
||||
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI),
|
||||
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
||||
TimeoutSeconds: 0,
|
||||
MemoryLimitMB: 0,
|
||||
}
|
||||
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil {
|
||||
config.TimeoutSeconds = f
|
||||
} else {
|
||||
config.TimeoutSeconds = 60.0
|
||||
}
|
||||
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil {
|
||||
config.MemoryLimitMB = mem
|
||||
} else {
|
||||
config.MemoryLimitMB = 100
|
||||
}
|
||||
return sandbox.NewRunner(config)
|
||||
default:
|
||||
return direct.NewRunner()
|
||||
}
|
||||
}
|
||||
@ -23,15 +23,15 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func NewRunner(config *Config) coderunner.Runner {
|
||||
return &runner{
|
||||
pyPath: goutil.GetPython3Path(),
|
||||
scriptPath: goutil.GetPythonFilePath("sandbox.py"),
|
||||
pyPath: fileutil.GetPython3Path(),
|
||||
scriptPath: fileutil.GetPythonFilePath("sandbox.py"),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,7 +26,7 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/messages2query"
|
||||
)
|
||||
|
||||
func NewMessagesToQuery(_ context.Context, model chatmodel.BaseChatModel, template prompt.ChatTemplate) (messages2query.MessagesToQuery, error) {
|
||||
48
backend/infra/document/messages2query/impl/impl.go
Normal file
48
backend/infra/document/messages2query/impl/impl.go
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package impl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/messages2query/impl/builtin"
|
||||
)
|
||||
|
||||
type MessagesToQuery = messages2query.MessagesToQuery
|
||||
|
||||
func New(ctx context.Context) (MessagesToQuery, error) {
|
||||
rewriterChatModel, _, err := buildinmodel.GetBuiltinChatModel(ctx, "M2Q_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(fileutil.GetWorkingDirectory(), "resources/conf/prompt/messages_to_query_template_jinja2.json")
|
||||
rewriterTemplate, err := fileutil.ReadJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rewriter, err := builtin.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rewriter, nil
|
||||
}
|
||||
48
backend/infra/document/nl2sql/impl/impl.go
Normal file
48
backend/infra/document/nl2sql/impl/impl.go
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package impl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/nl2sql/impl/builtin"
|
||||
)
|
||||
|
||||
type NL2SQL = nl2sql.NL2SQL
|
||||
|
||||
func New(ctx context.Context) (nl2sql.NL2SQL, error) {
|
||||
n2sChatModel, _, err := buildinmodel.GetBuiltinChatModel(ctx, "NL2SQL_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(fileutil.GetWorkingDirectory(), "resources/conf/prompt/nl2sql_template_jinja2.json")
|
||||
n2sTemplate, err := fileutil.ReadJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n2s, err := builtin.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return n2s, nil
|
||||
}
|
||||
55
backend/infra/document/ocr/impl/impl.go
Normal file
55
backend/infra/document/ocr/impl/impl.go
Normal file
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package impl
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/ppocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr/impl/veocr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
// OCR is an alias of ocr.OCR, so the contract type can be named through this
// package as well.
type OCR = ocr.OCR
|
||||
|
||||
func New() ocr.OCR {
|
||||
var ocr ocr.OCR
|
||||
switch os.Getenv(consts.OCRType) {
|
||||
case "ve":
|
||||
ocrAK := os.Getenv(consts.VeOCRAK)
|
||||
ocrSK := os.Getenv(consts.VeOCRSK)
|
||||
if ocrAK == "" || ocrSK == "" {
|
||||
logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well")
|
||||
}
|
||||
inst := visual.NewInstance()
|
||||
inst.Client.SetAccessKey(ocrAK)
|
||||
inst.Client.SetSecretKey(ocrSK)
|
||||
ocr = veocr.NewOCR(&veocr.Config{Client: inst})
|
||||
case "paddleocr":
|
||||
url := os.Getenv(consts.PPOCRAPIURL)
|
||||
client := &http.Client{}
|
||||
ocr = ppocr.NewOCR(&ppocr.Config{Client: client, URL: url})
|
||||
default:
|
||||
// accept ocr not configured
|
||||
}
|
||||
|
||||
return ocr
|
||||
}
|
||||
@ -19,11 +19,11 @@ package builtin
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
)
|
||||
|
||||
func NewManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager {
|
||||
@ -52,13 +52,13 @@ func (m *manager) GetParser(config *parser.Config) (parser.Parser, error) {
|
||||
|
||||
switch config.FileExtension {
|
||||
case parser.FileExtensionPDF:
|
||||
pFn = ParseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_pdf.py"))
|
||||
pFn = ParseByPython(config, m.storage, m.ocr, fileutil.GetPython3Path(), fileutil.GetPythonFilePath("parse_pdf.py"))
|
||||
case parser.FileExtensionTXT:
|
||||
pFn = ParseText(config)
|
||||
case parser.FileExtensionMarkdown:
|
||||
pFn = ParseMarkdown(config, m.storage, m.ocr)
|
||||
case parser.FileExtensionDocx:
|
||||
pFn = ParseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_docx.py"))
|
||||
pFn = ParseByPython(config, m.storage, m.ocr, fileutil.GetPython3Path(), fileutil.GetPythonFilePath("parse_docx.py"))
|
||||
case parser.FileExtensionCSV:
|
||||
pFn = ParseCSV(config)
|
||||
case parser.FileExtensionXLSX:
|
||||
|
||||
59
backend/infra/document/parser/impl/impl.go
Normal file
59
backend/infra/document/parser/impl/impl.go
Normal file
@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package impl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/buildinmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/ppstructure"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
// Manager is an alias of parser.Manager, so the contract type can be named
// through this package as well.
type Manager = parser.Manager
|
||||
|
||||
func New(ctx context.Context, storage storage.Storage, ocr ocr.OCR) (Manager, error) {
|
||||
imageAnnotationModel, _, err := buildinmodel.GetBuiltinChatModel(ctx, "IA_")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get builtin chat model failed, err=%w", err)
|
||||
}
|
||||
|
||||
var parserManager parser.Manager
|
||||
parserType := os.Getenv(consts.ParserType)
|
||||
switch parserType {
|
||||
case "builtin", "":
|
||||
parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel)
|
||||
case "paddleocr":
|
||||
url := os.Getenv(consts.PPStructureAPIURL)
|
||||
client := &http.Client{}
|
||||
apiConfig := &ppstructure.APIConfig{
|
||||
Client: client,
|
||||
URL: url,
|
||||
}
|
||||
parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel)
|
||||
default:
|
||||
return nil, fmt.Errorf("parser type %s not supported", parserType)
|
||||
}
|
||||
|
||||
return parserManager, nil
|
||||
}
|
||||
@ -19,12 +19,12 @@ package ppstructure
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/bizpkg/fileutil"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/parser/impl/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
)
|
||||
|
||||
func NewManager(apiConfig *APIConfig, ocr ocr.OCR, storage storage.Storage, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager {
|
||||
@ -64,7 +64,7 @@ func (m *manager) GetParser(config *parser.Config) (parser.Parser, error) {
|
||||
pFn = builtin.ParseMarkdown(config, m.storage, m.ocr)
|
||||
return &builtin.Parser{ParseFn: pFn}, nil
|
||||
case parser.FileExtensionDocx:
|
||||
pFn = builtin.ParseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_docx.py"))
|
||||
pFn = builtin.ParseByPython(config, m.storage, m.ocr, fileutil.GetPython3Path(), fileutil.GetPythonFilePath("parse_docx.py"))
|
||||
return &builtin.Parser{ParseFn: pFn}, nil
|
||||
case parser.FileExtensionCSV:
|
||||
pFn = builtin.ParseCSV(config)
|
||||
|
||||
@ -42,8 +42,8 @@ const (
|
||||
ProgressBarTotalNumRedisKey = "RedisBiz.Knowledge_ProgressBar_TotalNum_%d"
|
||||
ProgressBarProcessedNumRedisKey = "RedisBiz.Knowledge_ProgressBar_ProcessedNum_%d"
|
||||
DefaultProcessTime = 300
|
||||
ProcessDone = 100
|
||||
ProcessInit = 0
|
||||
ProcessDone = progressbar.ProcessDone
|
||||
ProcessInit = progressbar.ProcessInit
|
||||
)
|
||||
|
||||
func NewProgressBar(ctx context.Context, pkID int64, total int64, CacheCli cache.Cmdable, needInit bool) progressbar.ProgressBar {
|
||||
@ -142,6 +142,9 @@ func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainS
|
||||
if ptr.From(startTime) == 0 {
|
||||
remainSec = DefaultProcessTime
|
||||
} else {
|
||||
if ptr.From(processedNum) == 0 {
|
||||
return
|
||||
}
|
||||
usedSec := time.Now().Unix() - ptr.From(startTime)
|
||||
remainSec = int(float64(ptr.From(totalNum)-ptr.From(processedNum)) / float64(ptr.From(processedNum)) * float64(usedSec))
|
||||
}
|
||||
|
||||
@ -16,7 +16,11 @@
|
||||
|
||||
package progressbar
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/cache"
|
||||
)
|
||||
|
||||
// ProgressBar is the interface for the progress bar.
|
||||
type ProgressBar interface {
|
||||
@ -24,3 +28,10 @@ type ProgressBar interface {
|
||||
ReportError(err error) error
|
||||
GetProgress(ctx context.Context) (percent int, remainSec int, errMsg string)
|
||||
}
|
||||
|
||||
// New is a package-level constructor hook for ProgressBar. It is declared
// without a value, so it is nil until something assigns it.
// NOTE(review): the assigning site is not visible here — confirm it is set
// before the first call.
var New func(ctx context.Context, pkID int64, total int64, CacheCli cache.Cmdable, needInit bool) ProgressBar
|
||||
|
||||
// Progress boundary values.
// NOTE(review): presumably the percent reported by GetProgress when work has
// finished (ProcessDone) and when it has not started (ProcessInit) — confirm
// against the implementation.
const (
|
||||
ProcessDone = 100
|
||||
ProcessInit = 0
|
||||
)
|
||||
|
||||
48
backend/infra/document/rerank/impl/impl.go
Normal file
48
backend/infra/document/rerank/impl/impl.go
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package impl
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/rrf"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/rerank/impl/vikingdb"
|
||||
)
|
||||
|
||||
// Reranker is an alias of rerank.Reranker, so the contract type can be named
// through this package as well.
type Reranker = rerank.Reranker
|
||||
|
||||
func New() Reranker {
|
||||
rerankerType := os.Getenv("RERANK_TYPE")
|
||||
switch rerankerType {
|
||||
case "vikingdb":
|
||||
return vikingdb.NewReranker(getVikingRerankerConfig())
|
||||
case "rrf":
|
||||
return rrf.NewRRFReranker(0)
|
||||
default:
|
||||
return rrf.NewRRFReranker(0)
|
||||
}
|
||||
}
|
||||
|
||||
func getVikingRerankerConfig() *vikingdb.Config {
|
||||
return &vikingdb.Config{
|
||||
AK: os.Getenv("VIKINGDB_RERANK_AK"),
|
||||
SK: os.Getenv("VIKINGDB_RERANK_SK"),
|
||||
Domain: os.Getenv("VIKINGDB_RERANK_HOST"),
|
||||
Region: os.Getenv("VIKINGDB_RERANK_REGION"),
|
||||
Model: os.Getenv("VIKINGDB_RERANK_MODEL"),
|
||||
}
|
||||
}
|
||||
433
backend/infra/document/searchstore/impl/impl.go
Normal file
433
backend/infra/document/searchstore/impl/impl.go
Normal file
@ -0,0 +1,433 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package impl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/elasticsearch"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/milvus"
|
||||
searchstoreOceanbase "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/vikingdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/es/impl/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/oceanbase"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
// Manager is an alias of searchstore.Manager, so the contract type can be
// named through this package as well.
type Manager = searchstore.Manager
|
||||
|
||||
func New(ctx context.Context, es es.Client) ([]Manager, error) {
|
||||
// es full text search
|
||||
esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es})
|
||||
|
||||
// vector search
|
||||
mgr, err := getVectorStore(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
||||
}
|
||||
|
||||
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
|
||||
vsType := os.Getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
switch vsType {
|
||||
case "milvus":
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
milvusAddr = os.Getenv("MILVUS_ADDR")
|
||||
user = os.Getenv("MILVUS_USER")
|
||||
password = os.Getenv("MILVUS_PASSWORD")
|
||||
milvusToken = os.Getenv("MILVUS_TOKEN")
|
||||
)
|
||||
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
|
||||
Address: milvusAddr,
|
||||
Username: user,
|
||||
Password: password,
|
||||
APIKey: milvusToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
mgr, err := milvus.NewManager(&milvus.ManagerConfig{
|
||||
Client: mc,
|
||||
Embedding: emb,
|
||||
EnableHybrid: ptr.Of(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
case "vikingdb":
|
||||
var (
|
||||
host = os.Getenv("VIKING_DB_HOST")
|
||||
region = os.Getenv("VIKING_DB_REGION")
|
||||
ak = os.Getenv("VIKING_DB_AK")
|
||||
sk = os.Getenv("VIKING_DB_SK")
|
||||
scheme = os.Getenv("VIKING_DB_SCHEME")
|
||||
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
|
||||
)
|
||||
if ak == "" || sk == "" {
|
||||
return nil, fmt.Errorf("invalid vikingdb ak / sk")
|
||||
}
|
||||
if host == "" {
|
||||
host = "api-vikingdb.volces.com"
|
||||
}
|
||||
if region == "" {
|
||||
region = "cn-beijing"
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
var embConfig *vikingdb.VikingEmbeddingConfig
|
||||
if modelName != "" {
|
||||
embName := vikingdb.VikingEmbeddingModelName(modelName)
|
||||
if embName.Dimensions() == 0 {
|
||||
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
|
||||
}
|
||||
embConfig = &vikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: true,
|
||||
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
|
||||
ModelName: embName,
|
||||
ModelVersion: embName.ModelVersion(),
|
||||
DenseWeight: ptr.Of(0.2),
|
||||
BuiltinEmbedding: nil,
|
||||
}
|
||||
} else {
|
||||
builtinEmbedding, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
|
||||
}
|
||||
|
||||
embConfig = &vikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: false,
|
||||
EnableHybrid: false,
|
||||
BuiltinEmbedding: builtinEmbedding,
|
||||
}
|
||||
}
|
||||
|
||||
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
|
||||
mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{
|
||||
Service: svc,
|
||||
IndexingConfig: nil, // use default config
|
||||
EmbeddingConfig: embConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
|
||||
case "oceanbase":
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
host = os.Getenv("OCEANBASE_HOST")
|
||||
port = os.Getenv("OCEANBASE_PORT")
|
||||
user = os.Getenv("OCEANBASE_USER")
|
||||
password = os.Getenv("OCEANBASE_PASSWORD")
|
||||
database = os.Getenv("OCEANBASE_DATABASE")
|
||||
)
|
||||
if host == "" || port == "" || user == "" || password == "" || database == "" {
|
||||
return nil, fmt.Errorf("invalid oceanbase configuration: host, port, user, password, database are required")
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
user, password, host, port, database)
|
||||
|
||||
client, err := oceanbase.NewOceanBaseClient(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase client failed, err=%w", err)
|
||||
}
|
||||
|
||||
if err := client.InitDatabase(ctx); err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase database failed, err=%w", err)
|
||||
}
|
||||
|
||||
// Get configuration from environment variables with defaults
|
||||
batchSize := 100
|
||||
if bs := os.Getenv("OCEANBASE_BATCH_SIZE"); bs != "" {
|
||||
if bsInt, err := strconv.Atoi(bs); err == nil {
|
||||
batchSize = bsInt
|
||||
}
|
||||
}
|
||||
|
||||
enableCache := true
|
||||
if ec := os.Getenv("OCEANBASE_ENABLE_CACHE"); ec != "" {
|
||||
if ecBool, err := strconv.ParseBool(ec); err == nil {
|
||||
enableCache = ecBool
|
||||
}
|
||||
}
|
||||
|
||||
cacheTTL := 300 * time.Second
|
||||
if ct := os.Getenv("OCEANBASE_CACHE_TTL"); ct != "" {
|
||||
if ctInt, err := strconv.Atoi(ct); err == nil {
|
||||
cacheTTL = time.Duration(ctInt) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
maxConnections := 100
|
||||
if mc := os.Getenv("OCEANBASE_MAX_CONNECTIONS"); mc != "" {
|
||||
if mcInt, err := strconv.Atoi(mc); err == nil {
|
||||
maxConnections = mcInt
|
||||
}
|
||||
}
|
||||
|
||||
connTimeout := 30 * time.Second
|
||||
if ct := os.Getenv("OCEANBASE_CONN_TIMEOUT"); ct != "" {
|
||||
if ctInt, err := strconv.Atoi(ct); err == nil {
|
||||
connTimeout = time.Duration(ctInt) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
managerConfig := &searchstoreOceanbase.ManagerConfig{
|
||||
Client: client,
|
||||
Embedding: emb,
|
||||
BatchSize: batchSize,
|
||||
EnableCache: enableCache,
|
||||
CacheTTL: cacheTTL,
|
||||
MaxConnections: maxConnections,
|
||||
ConnTimeout: connTimeout,
|
||||
}
|
||||
mgr, err := searchstoreOceanbase.NewManager(managerConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init oceanbase vector store failed, err=%w", err)
|
||||
}
|
||||
return mgr, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
|
||||
}
|
||||
}
|
||||
|
||||
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
var batchSize int
|
||||
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
|
||||
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
|
||||
batchSize = 100
|
||||
} else {
|
||||
batchSize = int(bs)
|
||||
}
|
||||
|
||||
var emb embedding.Embedder
|
||||
|
||||
switch os.Getenv("EMBEDDING_TYPE") {
|
||||
case "openai":
|
||||
var (
|
||||
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
|
||||
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
|
||||
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
|
||||
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
|
||||
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
|
||||
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
|
||||
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
|
||||
)
|
||||
|
||||
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
|
||||
}
|
||||
|
||||
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
openAICfg := &openai.EmbeddingConfig{
|
||||
APIKey: openAIEmbeddingApiKey,
|
||||
ByAzure: byAzure,
|
||||
BaseURL: openAIEmbeddingBaseURL,
|
||||
APIVersion: openAIEmbeddingApiVersion,
|
||||
Model: openAIEmbeddingModel,
|
||||
// Dimensions: ptr.Of(int(dims)),
|
||||
}
|
||||
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
|
||||
if reqDims > 0 {
|
||||
// some openai model not support request dims
|
||||
openAICfg.Dimensions = ptr.Of(int(reqDims))
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ark":
|
||||
var (
|
||||
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
|
||||
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
|
||||
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
|
||||
// deprecated: use ARK_EMBEDDING_API_KEY instead
|
||||
// ARK_EMBEDDING_AK will be removed in the future
|
||||
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
|
||||
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
|
||||
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
apiType := ark.APITypeText
|
||||
if arkEmbeddingAPIType != "" {
|
||||
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
|
||||
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
|
||||
} else {
|
||||
apiType = t
|
||||
}
|
||||
}
|
||||
|
||||
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
|
||||
APIKey: func() string {
|
||||
if arkEmbeddingApiKey != "" {
|
||||
return arkEmbeddingApiKey
|
||||
}
|
||||
return arkEmbeddingAK
|
||||
}(),
|
||||
Model: arkEmbeddingModel,
|
||||
BaseURL: arkEmbeddingBaseURL,
|
||||
APIType: &apiType,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ollama":
|
||||
var (
|
||||
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
|
||||
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
|
||||
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
|
||||
BaseURL: ollamaEmbeddingBaseURL,
|
||||
Model: ollamaEmbeddingModel,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
|
||||
}
|
||||
case "gemini":
|
||||
var (
|
||||
geminiEmbeddingBaseURL = os.Getenv("GEMINI_EMBEDDING_BASE_URL")
|
||||
geminiEmbeddingModel = os.Getenv("GEMINI_EMBEDDING_MODEL")
|
||||
geminiEmbeddingApiKey = os.Getenv("GEMINI_EMBEDDING_API_KEY")
|
||||
geminiEmbeddingDims = os.Getenv("GEMINI_EMBEDDING_DIMS")
|
||||
geminiEmbeddingBackend = os.Getenv("GEMINI_EMBEDDING_BACKEND") // "1" for BackendGeminiAPI / "2" for BackendVertexAI
|
||||
geminiEmbeddingProject = os.Getenv("GEMINI_EMBEDDING_PROJECT")
|
||||
geminiEmbeddingLocation = os.Getenv("GEMINI_EMBEDDING_LOCATION")
|
||||
)
|
||||
|
||||
if len(geminiEmbeddingModel) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingApiKey) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingDims) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_DIMS environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingBackend) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_BACKEND environment variable is required")
|
||||
}
|
||||
|
||||
dims, convErr := strconv.ParseInt(geminiEmbeddingDims, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_DIMS value: %s, err=%w", geminiEmbeddingDims, convErr)
|
||||
}
|
||||
|
||||
backend, convErr := strconv.ParseInt(geminiEmbeddingBackend, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_BACKEND value: %s, err=%w", geminiEmbeddingBackend, convErr)
|
||||
}
|
||||
|
||||
geminiCli, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: geminiEmbeddingApiKey,
|
||||
Backend: genai.Backend(backend),
|
||||
Project: geminiEmbeddingProject,
|
||||
Location: geminiEmbeddingLocation,
|
||||
HTTPOptions: genai.HTTPOptions{
|
||||
BaseURL: geminiEmbeddingBaseURL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
|
||||
Client: geminiCli,
|
||||
Model: geminiEmbeddingModel,
|
||||
OutputDimensionality: ptr.Of(int32(dims)),
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
|
||||
}
|
||||
case "http":
|
||||
var (
|
||||
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
|
||||
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
|
||||
)
|
||||
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
|
||||
}
|
||||
emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
}
|
||||
@ -18,7 +18,7 @@ package eventbus
|
||||
|
||||
import "context"
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/eventbus/eventbus_mock.go -package mock -source eventbus.go Factory
|
||||
//go:generate mockgen -destination ../../internal/mock/infra/eventbus/eventbus_mock.go -package mock -source eventbus.go Factory
|
||||
type Producer interface {
|
||||
Send(ctx context.Context, body []byte, opts ...SendOpt) error
|
||||
BatchSend(ctx context.Context, bodyArr [][]byte, opts ...SendOpt) error
|
||||
|
||||
@ -77,3 +77,35 @@ func NewProducer(nameServer, topic, group string, retries int) (eventbus.Produce
|
||||
|
||||
return nil, fmt.Errorf("invalid mq type: %s , only support nsq, kafka, rmq, pulsar", tp)
|
||||
}
|
||||
|
||||
func InitResourceEventBusProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
resourceEventBusProducer, err := NewProducer(nameServer,
|
||||
consts.RMQTopicResource, consts.RMQConsumeGroupResource, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init resource producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return resourceEventBusProducer, nil
|
||||
}
|
||||
|
||||
func InitAppEventProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
appEventProducer, err := NewProducer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init app producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return appEventProducer, nil
|
||||
}
|
||||
|
||||
func InitKnowledgeEventBusProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
|
||||
knowledgeProducer, err := NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return knowledgeProducer, nil
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package veimagex
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
@ -26,8 +27,20 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func NewDefault() (imagex.ImageX, error) {
|
||||
return New(
|
||||
os.Getenv(consts.VeImageXAK),
|
||||
os.Getenv(consts.VeImageXSK),
|
||||
os.Getenv(consts.VeImageXDomain),
|
||||
os.Getenv(consts.VeImageXUploadHost),
|
||||
os.Getenv(consts.VeImageXTemplate),
|
||||
[]string{os.Getenv(consts.VeImageXServerID)},
|
||||
)
|
||||
}
|
||||
|
||||
func New(ak, sk, domain, uploadHost, template string, serverIDs []string) (imagex.ImageX, error) {
|
||||
instance := veimagex.DefaultInstance
|
||||
instance.SetCredential(base.Credentials{
|
||||
|
||||
@ -27,8 +27,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb/entity"
|
||||
sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
@ -657,12 +656,12 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques
|
||||
}
|
||||
}
|
||||
|
||||
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
|
||||
operation, err := sqlparser.New().GetSQLOperation(processedSQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if operation != sqlparsercontract.OperationTypeSelect {
|
||||
if operation != sqlparser.OperationTypeSelect {
|
||||
result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL: %v", result.Error)
|
||||
|
||||
@ -30,6 +30,8 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/rdb"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/infra/rdb/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/sqlparser"
|
||||
sqlparserImpl "github.com/coze-dev/coze-studio/backend/infra/sqlparser/impl/sqlparser"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
@ -515,6 +517,8 @@ func TestSelectData(t *testing.T) {
|
||||
|
||||
func TestExecuteSQL(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
sqlparser.New = sqlparserImpl.NewSQLParser
|
||||
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_sql_table")
|
||||
|
||||
|
||||
@ -78,3 +78,5 @@ type SQLParser interface {
|
||||
// AddSelectFieldsToSelectSQL add select fields to select sql
|
||||
AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error)
|
||||
}
|
||||
|
||||
// New is a package-level constructor hook that returns a SQLParser.
// It is declared without an initializer, so it is nil until a concrete
// constructor is assigned to it; invoking it while nil panics.
var New func() SQLParser
|
||||
|
||||
@ -50,7 +50,7 @@ func New(ctx context.Context, endpoint, accessKeyID, secretAccessKey, bucketName
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func getMinioClient(_ context.Context, endpoint, accessKeyID, secretAccessKey, bucketName string, useSSL bool) (*minioClient, error) {
|
||||
func getMinioClient(ctx context.Context, endpoint, accessKeyID, secretAccessKey, bucketName string, useSSL bool) (*minioClient, error) {
|
||||
client, err := minio.New(endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
|
||||
Secure: useSSL,
|
||||
@ -67,7 +67,7 @@ func getMinioClient(_ context.Context, endpoint, accessKeyID, secretAccessKey, b
|
||||
endpoint: endpoint,
|
||||
}
|
||||
|
||||
err = m.createBucketIfNeed(context.Background(), client, bucketName, "cn-north-1")
|
||||
err = m.createBucketIfNeed(ctx, client, bucketName, "cn-north-1")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init minio client failed %v", err)
|
||||
}
|
||||
|
||||
@ -62,7 +62,7 @@ func getS3Client(ctx context.Context, ak, sk, bucketName, endpoint, region strin
|
||||
}, nil
|
||||
})
|
||||
cfg, err := config.LoadDefaultConfig(
|
||||
context.TODO(),
|
||||
ctx,
|
||||
config.WithCredentialsProvider(creds),
|
||||
config.WithEndpointResolverWithOptions(customResolver),
|
||||
config.WithRegion("auto"),
|
||||
|
||||
@ -74,7 +74,8 @@ func (g *JsonCache[T]) Get(ctx context.Context, k string) (*T, error) {
|
||||
}
|
||||
|
||||
func (g *JsonCache[T]) Delete(ctx context.Context, k string) error {
|
||||
if err := g.cache.Del(ctx, k).Err(); err != nil {
|
||||
key := g.prefix + k
|
||||
if err := g.cache.Del(ctx, key).Err(); err != nil {
|
||||
return fmt.Errorf("failed to delete key %s: %w", k, err)
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user