Files
coze-studio/backend/infra/embedding/impl/impl.go
Ryo b3d6357bbd feat: model configuration and embedding configuration optimization. (#2414)
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-10-30 06:05:23 +00:00

135 lines
4.3 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package impl
import (
"context"
"fmt"
"github.com/cloudwego/eino-ext/components/embedding/gemini"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"google.golang.org/genai"
"github.com/coze-dev/coze-studio/backend/api/model/admin/config"
"github.com/coze-dev/coze-studio/backend/infra/embedding"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func GetEmbedding(ctx context.Context, cfg *config.EmbeddingConfig) (embedding.Embedder, error) {
var (
emb embedding.Embedder
err error
connInfo = cfg.Connection.BaseConnInfo
embeddingInfo = cfg.Connection.EmbeddingInfo
)
switch cfg.Type {
case config.EmbeddingType_OpenAI:
openaiConnCfg := cfg.Connection.Openai
openAICfg := &openai.EmbeddingConfig{
APIKey: connInfo.APIKey,
BaseURL: connInfo.BaseURL,
Model: connInfo.Model,
ByAzure: openaiConnCfg.ByAzure,
APIVersion: openaiConnCfg.APIVersion,
}
if embeddingInfo.Dims > 0 {
// some openai model not support request dims
openAICfg.Dimensions = ptr.Of(int(embeddingInfo.Dims))
}
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
if err != nil {
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
}
case config.EmbeddingType_Ark:
arkCfg := cfg.Connection.Ark
apiType := ark.APITypeText
if ark.APIType(arkCfg.APIType) == ark.APITypeMultiModal {
apiType = ark.APITypeMultiModal
}
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
APIKey: connInfo.APIKey,
Model: connInfo.Model,
BaseURL: connInfo.BaseURL,
APIType: &apiType,
}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
case config.EmbeddingType_Ollama:
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
BaseURL: connInfo.BaseURL,
Model: connInfo.Model,
}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
if err != nil {
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
}
case config.EmbeddingType_Gemini:
geminiCfg := cfg.Connection.Gemini
if len(connInfo.Model) == 0 {
return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
}
if len(connInfo.APIKey) == 0 {
return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
}
geminiCli, err1 := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: connInfo.APIKey,
Backend: genai.Backend(geminiCfg.Backend),
Project: geminiCfg.Project,
Location: geminiCfg.Location,
HTTPOptions: genai.HTTPOptions{
BaseURL: connInfo.BaseURL,
},
})
if err1 != nil {
return nil, fmt.Errorf("init gemini client failed, err=%w", err)
}
emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
Client: geminiCli,
Model: connInfo.Model,
OutputDimensionality: ptr.Of(int32(embeddingInfo.Dims)),
}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
if err != nil {
return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
}
case config.EmbeddingType_HTTP:
httpCfg := cfg.Connection.HTTP
emb, err = http.NewEmbedding(httpCfg.Address, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
if err != nil {
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
}
default:
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
}
return emb, nil
}