Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
135 lines
4.3 KiB
Go
135 lines
4.3 KiB
Go
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
package impl
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/cloudwego/eino-ext/components/embedding/gemini"
|
|
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
|
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
|
"google.golang.org/genai"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/admin/config"
|
|
"github.com/coze-dev/coze-studio/backend/infra/embedding"
|
|
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
|
|
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
|
|
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
)
|
|
|
|
func GetEmbedding(ctx context.Context, cfg *config.EmbeddingConfig) (embedding.Embedder, error) {
|
|
var (
|
|
emb embedding.Embedder
|
|
err error
|
|
connInfo = cfg.Connection.BaseConnInfo
|
|
embeddingInfo = cfg.Connection.EmbeddingInfo
|
|
)
|
|
|
|
switch cfg.Type {
|
|
case config.EmbeddingType_OpenAI:
|
|
openaiConnCfg := cfg.Connection.Openai
|
|
openAICfg := &openai.EmbeddingConfig{
|
|
APIKey: connInfo.APIKey,
|
|
BaseURL: connInfo.BaseURL,
|
|
Model: connInfo.Model,
|
|
ByAzure: openaiConnCfg.ByAzure,
|
|
APIVersion: openaiConnCfg.APIVersion,
|
|
}
|
|
|
|
if embeddingInfo.Dims > 0 {
|
|
// some openai model not support request dims
|
|
openAICfg.Dimensions = ptr.Of(int(embeddingInfo.Dims))
|
|
}
|
|
|
|
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
|
|
}
|
|
case config.EmbeddingType_Ark:
|
|
arkCfg := cfg.Connection.Ark
|
|
|
|
apiType := ark.APITypeText
|
|
if ark.APIType(arkCfg.APIType) == ark.APITypeMultiModal {
|
|
apiType = ark.APITypeMultiModal
|
|
}
|
|
|
|
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
|
|
APIKey: connInfo.APIKey,
|
|
Model: connInfo.Model,
|
|
BaseURL: connInfo.BaseURL,
|
|
APIType: &apiType,
|
|
}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
|
|
}
|
|
|
|
case config.EmbeddingType_Ollama:
|
|
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
|
|
BaseURL: connInfo.BaseURL,
|
|
Model: connInfo.Model,
|
|
}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
|
|
}
|
|
case config.EmbeddingType_Gemini:
|
|
geminiCfg := cfg.Connection.Gemini
|
|
|
|
if len(connInfo.Model) == 0 {
|
|
return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
|
|
}
|
|
if len(connInfo.APIKey) == 0 {
|
|
return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
|
|
}
|
|
|
|
geminiCli, err1 := genai.NewClient(ctx, &genai.ClientConfig{
|
|
APIKey: connInfo.APIKey,
|
|
Backend: genai.Backend(geminiCfg.Backend),
|
|
Project: geminiCfg.Project,
|
|
Location: geminiCfg.Location,
|
|
HTTPOptions: genai.HTTPOptions{
|
|
BaseURL: connInfo.BaseURL,
|
|
},
|
|
})
|
|
if err1 != nil {
|
|
return nil, fmt.Errorf("init gemini client failed, err=%w", err)
|
|
}
|
|
|
|
emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
|
|
Client: geminiCli,
|
|
Model: connInfo.Model,
|
|
OutputDimensionality: ptr.Of(int32(embeddingInfo.Dims)),
|
|
}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
|
|
}
|
|
case config.EmbeddingType_HTTP:
|
|
httpCfg := cfg.Connection.HTTP
|
|
|
|
emb, err = http.NewEmbedding(httpCfg.Address, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
|
|
}
|
|
|
|
default:
|
|
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
|
|
}
|
|
|
|
return emb, nil
|
|
}
|