feat: optimize embedding configuration to support omitting model dims (#2412)

This commit is contained in:
Ryo
2025-10-29 20:31:03 +08:00
committed by GitHub
parent bfd69ac17d
commit 7335e0a24c
7 changed files with 322 additions and 153 deletions

View File

@ -18,6 +18,7 @@ import (
bizConf "github.com/coze-dev/coze-studio/backend/bizpkg/config"
"github.com/coze-dev/coze-studio/backend/bizpkg/config/modelmgr"
"github.com/coze-dev/coze-studio/backend/bizpkg/llm/modelbuilder"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
@ -131,6 +132,41 @@ func UpdateKnowledgeConfig(ctx context.Context, c *app.RequestContext) {
return
}
embedding, err := impl.GetEmbedding(ctx, req.KnowledgeConfig.EmbeddingConfig)
if err != nil {
invalidParamRequestResponse(c, fmt.Sprintf("get embedding failed: %v", err))
return
}
if req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims == 0 {
req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims = int32(embedding.Dimensions())
embedding, err = impl.GetEmbedding(ctx, req.KnowledgeConfig.EmbeddingConfig)
if err != nil {
invalidParamRequestResponse(c, fmt.Sprintf("get embedding failed: %v", err))
return
}
}
denseEmbeddings, err := embedding.EmbedStrings(ctx, []string{"test"})
if err != nil {
invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: %v", err))
return
}
if len(denseEmbeddings) == 0 {
invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: %v", err))
return
}
logs.CtxDebugf(ctx, "embed test string result: %d, expect %d",
len(denseEmbeddings[0]), req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims)
if len(denseEmbeddings[0]) != int(req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims) {
invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: dims not match, expect %d, got %d",
req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims, len(denseEmbeddings[0])))
return
}
err = bizConf.Knowledge().SaveKnowledgeConfig(ctx, req.KnowledgeConfig)
if err != nil {
internalServerErrorResponse(ctx, c, fmt.Errorf("save knowledge config failed: %w", err))

View File

@ -1876,7 +1876,7 @@
return;
}
if (data && Number(data.code) === 0) {
showSuccess(t('success.saved_restart_required'));
showSuccess(t('success.saved'));
} else {
const msg = (data && (data.msg || data.message)) ? (data.msg || data.message) : t('error.save_failed');
showError(msg);
@ -2453,31 +2453,31 @@
const initial = {
embedding_type: embedding?.type ?? '',
embedding_max_batch_size: Number(embedding?.max_batch_size ?? 100),
ark_base_url: conn?.ark?.base_url || baseConn?.base_url || '',
ark_api_key: conn?.ark?.api_key || baseConn?.api_key || '',
ark_base_url: baseConn?.base_url || '',
ark_api_key: baseConn?.api_key || '',
ark_region: conn?.ark?.region || '',
ark_model: conn?.ark?.model || baseConn?.model || '',
ark_embedding_dims: conn?.ark?.embedding_dims ?? embeddingInfo?.dims ?? '',
ark_model: baseConn?.model || '',
ark_embedding_dims: embeddingInfo?.dims ?? '0',
ark_embedding_api_type: conn?.ark?.embedding_api_type || conn?.ark?.api_type || '',
openai_base_url: conn?.openai?.base_url || '',
openai_api_key: conn?.openai?.api_key || '',
openai_model: conn?.openai?.model || '',
openai_embedding_dims: conn?.openai?.embedding_dims ?? '',
openai_base_url: baseConn?.base_url || '',
openai_api_key: baseConn?.api_key || '',
openai_model: baseConn?.model || '',
openai_embedding_dims: embeddingInfo?.dims ?? '0',
openai_embedding_request_dims: conn?.openai?.embedding_request_dims ?? conn?.openai?.request_dims ?? '',
openai_embedding_by_azure: Boolean(conn?.openai?.embedding_by_azure ?? conn?.openai?.by_azure),
openai_embedding_api_version: conn?.openai?.embedding_api_version ?? conn?.openai?.api_version ?? '',
ollama_base_url: conn?.ollama?.base_url || '',
ollama_model: conn?.ollama?.model || '',
ollama_embedding_dims: conn?.ollama?.embedding_dims ?? '',
gemini_base_url: conn?.gemini?.base_url || '',
gemini_api_key: conn?.gemini?.api_key || '',
gemini_model: conn?.gemini?.model || '',
gemini_embedding_dims: conn?.gemini?.embedding_dims ?? '',
ollama_base_url: baseConn?.base_url || '',
ollama_model: baseConn?.model || '',
ollama_embedding_dims: embeddingInfo?.dims ?? '0',
gemini_base_url: baseConn?.base_url || '',
gemini_api_key: baseConn?.api_key || '',
gemini_model: baseConn?.model || '',
gemini_embedding_dims: embeddingInfo?.dims ?? '0',
gemini_embedding_backend: conn?.gemini?.embedding_backend ?? conn?.gemini?.backend ?? '',
gemini_embedding_project: conn?.gemini?.embedding_project ?? conn?.gemini?.project ?? '',
gemini_embedding_location: conn?.gemini?.embedding_location ?? conn?.gemini?.location ?? '',
http_address: conn?.http?.address || '',
http_dims: conn?.http?.dims ?? embeddingInfo?.dims ?? '',
http_dims: embeddingInfo?.dims ?? '0',
rerank_type: rerank?.type ?? '',
vik_ak: rerank?.vikingdb_config?.ak || '',
vik_sk: rerank?.vikingdb_config?.sk || '',
@ -2493,6 +2493,84 @@
builtin_model_id: builtinModelId ?? ''
};
if (initial.embedding_type==0) {
initial.openai_api_key ='';
initial.openai_base_url ='';
initial.openai_model ='';
initial.openai_embedding_dims ='';
initial.ollama_base_url ='';
initial.ollama_model ='';
initial.ollama_embedding_dims ='';
initial.gemini_base_url ='';
initial.gemini_api_key ='';
initial.gemini_model ='';
initial.gemini_embedding_dims ='';
initial.http_address ='';
initial.http_dims ='';
}else if (initial.embedding_type==1) {
initial.ark_base_url ='';
initial.ark_api_key ='';
initial.ark_region ='';
initial.ark_model ='';
initial.ark_embedding_dims ='';
initial.ollama_base_url ='';
initial.ollama_model ='';
initial.ollama_embedding_dims ='';
initial.gemini_base_url ='';
initial.gemini_api_key ='';
initial.gemini_model ='';
initial.gemini_embedding_dims ='';
initial.http_address ='';
initial.http_dims ='';
} else if (initial.embedding_type==2) {
initial.openai_api_key ='';
initial.openai_base_url ='';
initial.openai_model ='';
initial.openai_embedding_dims ='';
initial.ark_base_url ='';
initial.ark_api_key ='';
initial.ark_region ='';
initial.ark_model ='';
initial.ark_embedding_dims ='';
initial.gemini_base_url ='';
initial.gemini_api_key ='';
initial.gemini_model ='';
initial.http_address ='';
initial.http_dims ='';
} else if (initial.embedding_type==3) {
initial.openai_api_key ='';
initial.openai_base_url ='';
initial.openai_model ='';
initial.openai_embedding_dims ='';
initial.ark_base_url ='';
initial.ark_api_key ='';
initial.ark_region ='';
initial.ark_model ='';
initial.ark_embedding_dims ='';
initial.ollama_base_url ='';
initial.ollama_model ='';
initial.ollama_embedding_dims ='';
initial.http_address ='';
initial.http_dims ='';
} else if (initial.embedding_type==4) {
initial.openai_api_key ='';
initial.openai_base_url ='';
initial.openai_model ='';
initial.openai_embedding_dims ='';
initial.ark_base_url ='';
initial.ark_api_key ='';
initial.ark_region ='';
initial.ark_model ='';
initial.ark_embedding_dims ='';
initial.ollama_base_url ='';
initial.ollama_model ='';
initial.ollama_embedding_dims ='';
initial.gemini_base_url ='';
initial.gemini_api_key ='';
initial.gemini_model ='';
initial.gemini_embedding_dims ='';
}
const inputStyle = "width:100%;padding:8px;border:1px solid #e1e8ed;border-radius:6px;font-size:12px;";
const labelStyle = "font-size:12px;color:#374151;margin-bottom:6px;";
const groupGrid = "display:grid;grid-template-columns:repeat(auto-fit,minmax(200px,1fr));gap:12px;";
@ -2544,9 +2622,9 @@
<input id="kb_ark_model" type="text" style="${inputStyle}" value="${initial.ark_model}" />
</div>
<div>
<div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
<div style="${labelStyle}">Dims</div>
<input id="kb_ark_embedding_dims" type="number" style="${inputStyle}" value="${initial.ark_embedding_dims}" />
</div>
</div>
<div>
<div style="${labelStyle}">API Type</div>
<select id="kb_ark_embedding_api_type" style="${inputStyle}">
@ -2574,9 +2652,9 @@
<input id="kb_openai_api_key" type="text" style="${inputStyle}" value="${initial.openai_api_key}" />
</div>
<div>
<div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
<div style="${labelStyle}">Dims</div>
<input id="kb_openai_embedding_dims" type="number" style="${inputStyle}" value="${initial.openai_embedding_dims}" />
</div>
</div>
<!-- 选填字段靠后显示By Azure、API Version、Request Dims -->
<div>
<div style="${labelStyle}">By Azure</div>
@ -2605,9 +2683,9 @@
<input id="kb_ollama_model" type="text" style="${inputStyle}" value="${initial.ollama_model}" />
</div>
<div>
<div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
<div style="${labelStyle}">Dims</div>
<input id="kb_ollama_embedding_dims" type="number" style="${inputStyle}" value="${initial.ollama_embedding_dims}" />
</div>
</div>
</div>
</div>
@ -2626,14 +2704,15 @@
<div style="${labelStyle}">Model<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
<input id="kb_gemini_model" type="text" style="${inputStyle}" value="${initial.gemini_model}" />
</div>
<div>
<div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
<input id="kb_gemini_embedding_dims" type="number" style="${inputStyle}" value="${initial.gemini_embedding_dims}" />
</div>
<div>
<div style="${labelStyle}">Backend<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
<input id="kb_gemini_embedding_backend" type="number" style="${inputStyle}" value="${initial.gemini_embedding_backend}" />
</div>
<div>
<div style="${labelStyle}">Dims</div>
<input id="kb_gemini_embedding_dims" type="number" style="${inputStyle}" value="${initial.gemini_embedding_dims}" />
</div>
<div>
<div style="${labelStyle}">Project</div>
<input id="kb_gemini_embedding_project" type="text" style="${inputStyle}" value="${initial.gemini_embedding_project}" />
@ -2653,7 +2732,7 @@
<input id="kb_http_address" type="text" style="${inputStyle}" value="${initial.http_address}" />
</div>
<div>
<div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
<div style="${labelStyle}">Dims</div>
<input id="kb_http_dims" type="number" style="${inputStyle}" value="${initial.http_dims}" />
</div>
</div>
@ -3041,7 +3120,7 @@
if (!get('kb_ark_base_url')) missing.push('Base URL');
if (!get('kb_ark_api_key')) missing.push('API Key');
if (!get('kb_ark_model')) missing.push('Model');
if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
if (missing.length > 0) {
showError(t('error.missing_required', { list: missing.join(', ') }));
return;
@ -3053,7 +3132,6 @@
if (!get('kb_openai_base_url')) missing.push('Base URL');
if (!get('kb_openai_model')) missing.push('Model');
if (!get('kb_openai_api_key')) missing.push('API Key');
if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
if (missing.length > 0) {
showError(t('error.missing_required', { list: missing.join(', ') }));
return;
@ -3064,7 +3142,6 @@
const missing = [];
if (!get('kb_ollama_base_url')) missing.push('Base URL');
if (!get('kb_ollama_model')) missing.push('Model');
if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
if (missing.length > 0) {
showError(t('error.missing_required', { list: missing.join(', ') }));
return;
@ -3077,7 +3154,6 @@
if (!get('kb_gemini_base_url')) missing.push('Base URL');
if (!get('kb_gemini_api_key')) missing.push('API Key');
if (!get('kb_gemini_model')) missing.push('Model');
if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
if (!backendStr) missing.push('Backend');
if (missing.length > 0) {
showError(t('error.missing_required', { list: missing.join(', ') }));
@ -3088,7 +3164,6 @@
const dimsVal = Number(document.getElementById('kb_http_dims')?.value);
const missing = [];
if (!get('kb_http_address')) missing.push('Address');
if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
if (missing.length > 0) {
showError(t('error.missing_required', { list: missing.join(', ') }));
return;
@ -3109,7 +3184,7 @@
}
} catch (_) {}
showOverlay(t('overlay.saving'));
// showOverlay(t('overlay.saving'));
fetch('/api/admin/config/knowledge/save', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },

View File

@ -21,11 +21,7 @@ import (
"os"
"time"
"github.com/cloudwego/eino-ext/components/embedding/gemini"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"google.golang.org/genai"
"github.com/coze-dev/coze-studio/backend/api/model/admin/config"
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
@ -34,9 +30,7 @@ import (
searchstoreOceanbase "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/oceanbase"
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/embedding"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl"
"github.com/coze-dev/coze-studio/backend/infra/es/impl/es"
"github.com/coze-dev/coze-studio/backend/infra/oceanbase"
"github.com/coze-dev/coze-studio/backend/pkg/envkey"
@ -82,7 +76,7 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
}
emb, err := getEmbedding(ctx, conf.EmbeddingConfig)
emb, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
if err != nil {
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
}
@ -134,7 +128,7 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
BuiltinEmbedding: nil,
}
} else {
builtinEmbedding, err := getEmbedding(ctx, conf.EmbeddingConfig)
builtinEmbedding, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
if err != nil {
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
}
@ -159,7 +153,7 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
return mgr, nil
case "oceanbase":
emb, err := getEmbedding(ctx, conf.EmbeddingConfig)
emb, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
if err != nil {
return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err)
}
@ -213,103 +207,3 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
}
}
// getEmbedding builds an embedding.Embedder from the admin knowledge
// configuration. It dispatches on cfg.Type to the matching provider
// (OpenAI, Ark, Ollama, Gemini, or a generic HTTP endpoint), passing the
// shared base connection info plus the configured dense dimension and max
// batch size to the provider wrapper.
//
// embeddingInfo.Dims == 0 means the dimension was omitted by the operator;
// it is forwarded as-is so the wrappers can detect the real width lazily
// via their Dimensions() probe.
func getEmbedding(ctx context.Context, cfg *config.EmbeddingConfig) (embedding.Embedder, error) {
	var (
		emb           embedding.Embedder
		err           error
		connInfo      = cfg.Connection.BaseConnInfo
		embeddingInfo = cfg.Connection.EmbeddingInfo
	)

	switch cfg.Type {
	case config.EmbeddingType_OpenAI:
		openaiConnCfg := cfg.Connection.Openai
		openAICfg := &openai.EmbeddingConfig{
			APIKey:     connInfo.APIKey,
			BaseURL:    connInfo.BaseURL,
			Model:      connInfo.Model,
			ByAzure:    openaiConnCfg.ByAzure,
			APIVersion: openaiConnCfg.APIVersion,
		}
		if openaiConnCfg.RequestDims > 0 {
			// Some OpenAI models do not support requesting dims, so only
			// set Dimensions when explicitly configured.
			openAICfg.Dimensions = ptr.Of(int(openaiConnCfg.RequestDims))
		}
		emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
		}
	case config.EmbeddingType_Ark:
		arkCfg := cfg.Connection.Ark
		apiType := ark.APITypeText
		if ark.APIType(arkCfg.APIType) == ark.APITypeMultiModal {
			apiType = ark.APITypeMultiModal
		}
		emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
			APIKey:  connInfo.APIKey,
			Model:   connInfo.Model,
			BaseURL: connInfo.BaseURL,
			APIType: &apiType,
		}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
		}
	case config.EmbeddingType_Ollama:
		emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
			BaseURL: connInfo.BaseURL,
			Model:   connInfo.Model,
		}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
		}
	case config.EmbeddingType_Gemini:
		geminiCfg := cfg.Connection.Gemini
		if len(connInfo.Model) == 0 {
			return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
		}
		if len(connInfo.APIKey) == 0 {
			return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
		}
		geminiCli, err1 := genai.NewClient(ctx, &genai.ClientConfig{
			APIKey:   connInfo.APIKey,
			Backend:  genai.Backend(geminiCfg.Backend),
			Project:  geminiCfg.Project,
			Location: geminiCfg.Location,
			HTTPOptions: genai.HTTPOptions{
				BaseURL: connInfo.BaseURL,
			},
		})
		if err1 != nil {
			// Wrap err1 (the actual client error); the outer err is still nil here.
			return nil, fmt.Errorf("init gemini client failed, err=%w", err1)
		}
		geminiEmbCfg := &gemini.EmbeddingConfig{
			Client: geminiCli,
			Model:  connInfo.Model,
		}
		if embeddingInfo.Dims > 0 {
			// Only request an explicit output dimensionality when one was
			// configured; 0 means "use the model default" and must not be sent.
			geminiEmbCfg.OutputDimensionality = ptr.Of(int32(embeddingInfo.Dims))
		}
		emb, err = wrap.NewGeminiEmbedder(ctx, geminiEmbCfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
		}
	case config.EmbeddingType_HTTP:
		httpCfg := cfg.Connection.HTTP
		emb, err = http.NewEmbedding(httpCfg.Address, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init http embedding failed, err=%w", err)
		}
	default:
		return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
	}

	return emb, nil
}

View File

@ -57,7 +57,7 @@ type embWrap struct {
embedding.Embedder
}
func (d embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
func (d *embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
resp := make([][]float64, 0, len(texts))
for _, part := range slices.Chunks(texts, d.batchSize) {
partResult, err := d.Embedder.EmbedStrings(ctx, part, opts...)
@ -87,19 +87,29 @@ func (d embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embed
return resp, nil
}
func (d embWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
// EmbedStringsHybrid reports that hybrid (dense + sparse) embedding is not
// supported by this batching wrapper; it always returns an error.
func (d *embWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
	return nil, nil, fmt.Errorf("[arkEmbedder] EmbedStringsHybrid not support")
}
func (d embWrap) Dimensions() int64 {
// Dimensions returns the dense vector width. A configured value of zero
// means the dimension was omitted; in that case the wrapped embedder is
// probed once with a throwaway string and the detected width is cached on
// d.dims. A failed probe (error or empty result) reports 0 without caching,
// so a later call retries.
// NOTE(review): the lazy write to d.dims is unsynchronized — confirm that
// callers only invoke this during single-goroutine initialization.
func (d *embWrap) Dimensions() int64 {
	if d.dims != 0 {
		return d.dims
	}
	vecs, err := d.Embedder.EmbedStrings(context.Background(), []string{"test"})
	if err != nil || len(vecs) == 0 {
		return 0
	}
	d.dims = int64(len(vecs[0]))
	return d.dims
}
func (d embWrap) SupportStatus() contract.SupportStatus {
// SupportStatus reports that this wrapper produces dense vectors only.
func (d *embWrap) SupportStatus() contract.SupportStatus {
	return contract.SupportDense
}
func (d embWrap) slicedNormL2(vectors [][]float64) ([][]float64, error) {
func (d *embWrap) slicedNormL2(vectors [][]float64) ([][]float64, error) {
if len(vectors) == 0 {
return vectors, nil
}

View File

@ -195,6 +195,16 @@ func (e *embedder) do(req *http.Request) (*embedResp, error) {
}
// Dimensions returns the dense vector width of this HTTP embedder. When the
// configured width is zero (dims omitted), the remote service is probed once
// with a throwaway string and the detected width is cached on e.dim; a failed
// probe (error or empty response) reports 0 and leaves the cache unset so a
// later call retries.
// NOTE(review): the lazy write to e.dim is unsynchronized — confirm callers
// only hit this during single-goroutine initialization.
func (e *embedder) Dimensions() int64 {
	if e.dim != 0 {
		return e.dim
	}
	vecs, err := e.EmbedStrings(context.Background(), []string{"test"})
	if err != nil || len(vecs) == 0 {
		return 0
	}
	e.dim = int64(len(vecs[0]))
	return e.dim
}

View File

@ -0,0 +1,134 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package impl
import (
"context"
"fmt"
"github.com/cloudwego/eino-ext/components/embedding/gemini"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"google.golang.org/genai"
"github.com/coze-dev/coze-studio/backend/api/model/admin/config"
"github.com/coze-dev/coze-studio/backend/infra/embedding"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// GetEmbedding builds an embedding.Embedder from the admin knowledge
// configuration. It dispatches on cfg.Type to the matching provider
// (OpenAI, Ark, Ollama, Gemini, or a generic HTTP endpoint), passing the
// shared base connection info plus the configured dense dimension and max
// batch size to the provider wrapper.
//
// embeddingInfo.Dims == 0 means the dimension was omitted by the operator;
// it is forwarded as-is so the wrappers can detect the real width lazily
// via their Dimensions() probe.
func GetEmbedding(ctx context.Context, cfg *config.EmbeddingConfig) (embedding.Embedder, error) {
	var (
		emb           embedding.Embedder
		err           error
		connInfo      = cfg.Connection.BaseConnInfo
		embeddingInfo = cfg.Connection.EmbeddingInfo
	)

	switch cfg.Type {
	case config.EmbeddingType_OpenAI:
		openaiConnCfg := cfg.Connection.Openai
		openAICfg := &openai.EmbeddingConfig{
			APIKey:     connInfo.APIKey,
			BaseURL:    connInfo.BaseURL,
			Model:      connInfo.Model,
			ByAzure:    openaiConnCfg.ByAzure,
			APIVersion: openaiConnCfg.APIVersion,
		}
		if openaiConnCfg.RequestDims > 0 {
			// Some OpenAI models do not support requesting dims, so only
			// set Dimensions when explicitly configured.
			openAICfg.Dimensions = ptr.Of(int(openaiConnCfg.RequestDims))
		}
		emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
		}
	case config.EmbeddingType_Ark:
		arkCfg := cfg.Connection.Ark
		apiType := ark.APITypeText
		if ark.APIType(arkCfg.APIType) == ark.APITypeMultiModal {
			apiType = ark.APITypeMultiModal
		}
		emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
			APIKey:  connInfo.APIKey,
			Model:   connInfo.Model,
			BaseURL: connInfo.BaseURL,
			APIType: &apiType,
		}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
		}
	case config.EmbeddingType_Ollama:
		emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
			BaseURL: connInfo.BaseURL,
			Model:   connInfo.Model,
		}, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
		}
	case config.EmbeddingType_Gemini:
		geminiCfg := cfg.Connection.Gemini
		if len(connInfo.Model) == 0 {
			return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
		}
		if len(connInfo.APIKey) == 0 {
			return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
		}
		geminiCli, err1 := genai.NewClient(ctx, &genai.ClientConfig{
			APIKey:   connInfo.APIKey,
			Backend:  genai.Backend(geminiCfg.Backend),
			Project:  geminiCfg.Project,
			Location: geminiCfg.Location,
			HTTPOptions: genai.HTTPOptions{
				BaseURL: connInfo.BaseURL,
			},
		})
		if err1 != nil {
			// Wrap err1 (the actual client error); the outer err is still nil here.
			return nil, fmt.Errorf("init gemini client failed, err=%w", err1)
		}
		geminiEmbCfg := &gemini.EmbeddingConfig{
			Client: geminiCli,
			Model:  connInfo.Model,
		}
		if embeddingInfo.Dims > 0 {
			// Only request an explicit output dimensionality when one was
			// configured; 0 means "use the model default" and must not be sent.
			geminiEmbCfg.OutputDimensionality = ptr.Of(int32(embeddingInfo.Dims))
		}
		emb, err = wrap.NewGeminiEmbedder(ctx, geminiEmbCfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
		}
	case config.EmbeddingType_HTTP:
		httpCfg := cfg.Connection.HTTP
		emb, err = http.NewEmbedding(httpCfg.Address, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
		if err != nil {
			return nil, fmt.Errorf("init http embedding failed, err=%w", err)
		}
	default:
		return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
	}

	return emb, nil
}

View File

@ -32,7 +32,7 @@ type denseOnlyWrap struct {
embedding.Embedder
}
func (d denseOnlyWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
func (d *denseOnlyWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
resp := make([][]float64, 0, len(texts))
for _, part := range slices.Chunks(texts, d.batchSize) {
partResult, err := d.Embedder.EmbedStrings(ctx, part, opts...)
@ -44,14 +44,24 @@ func (d denseOnlyWrap) EmbedStrings(ctx context.Context, texts []string, opts ..
return resp, nil
}
func (d denseOnlyWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
// EmbedStringsHybrid reports that hybrid (dense + sparse) embedding is not
// supported by this dense-only wrapper; it always returns an error.
func (d *denseOnlyWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
	return nil, nil, fmt.Errorf("[denseOnlyWrap] EmbedStringsHybrid not support")
}
func (d denseOnlyWrap) Dimensions() int64 {
// Dimensions returns the dense vector width. A configured value of zero
// means the dimension was omitted; in that case the wrapped embedder is
// probed once with a throwaway string and the detected width is cached on
// d.dims. A failed probe (error or empty result) reports 0 without caching,
// so a later call retries.
// NOTE(review): the lazy write to d.dims is unsynchronized — confirm that
// callers only invoke this during single-goroutine initialization.
func (d *denseOnlyWrap) Dimensions() int64 {
	if d.dims != 0 {
		return d.dims
	}
	vecs, err := d.Embedder.EmbedStrings(context.Background(), []string{"test"})
	if err != nil || len(vecs) == 0 {
		return 0
	}
	d.dims = int64(len(vecs[0]))
	return d.dims
}
func (d denseOnlyWrap) SupportStatus() contract.SupportStatus {
// SupportStatus reports that this wrapper produces dense vectors only.
func (d *denseOnlyWrap) SupportStatus() contract.SupportStatus {
	return contract.SupportDense
}