feat: optimize embedding configuration to support omitting model dims (#2412)
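With this change the embedding Dims value may be left unset (0) in the knowledge config. The admin handler builds the embedder, asks it for its dimensionality (the wrappers now probe the model with a test string when no dims were configured), writes the detected value back into the config, rebuilds the embedder, and checks the configured dims against an actual test embedding before saving. Below is a minimal, self-contained sketch of that probe-and-fill flow; the types and helper names are hypothetical stand-ins, the real logic lives in UpdateKnowledgeConfig and the embedder wrappers shown in the diff.

    package main

    import (
        "context"
        "fmt"
    )

    // Embedder mirrors the minimal surface the handler relies on; the real
    // interface lives in backend/infra/embedding.
    type Embedder interface {
        EmbedStrings(ctx context.Context, texts []string) ([][]float64, error)
        Dimensions() int64
    }

    // fakeEmbedder stands in for a provider client; it returns 4-dim vectors.
    type fakeEmbedder struct{ dims int64 }

    func (f *fakeEmbedder) EmbedStrings(ctx context.Context, texts []string) ([][]float64, error) {
        out := make([][]float64, len(texts))
        for i := range texts {
            out[i] = make([]float64, 4)
        }
        return out, nil
    }

    func (f *fakeEmbedder) Dimensions() int64 {
        if f.dims == 0 { // lazily probe with a test string, as the new wrappers do
            vecs, err := f.EmbedStrings(context.Background(), []string{"test"})
            if err != nil || len(vecs) == 0 {
                return 0
            }
            f.dims = int64(len(vecs[0]))
        }
        return f.dims
    }

    func main() {
        cfgDims := int32(0) // operator omitted Dims in the config

        emb := &fakeEmbedder{}
        if cfgDims == 0 {
            cfgDims = int32(emb.Dimensions()) // fill in the detected value
        }

        // Validate: the configured dims must match what the model actually returns.
        vecs, err := emb.EmbedStrings(context.Background(), []string{"test"})
        if err != nil || len(vecs) == 0 || len(vecs[0]) != int(cfgDims) {
            fmt.Println("dims not match, refuse to save config")
            return
        }
        fmt.Println("config validated, dims =", cfgDims)
    }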
@@ -18,6 +18,7 @@ import (
    bizConf "github.com/coze-dev/coze-studio/backend/bizpkg/config"
    "github.com/coze-dev/coze-studio/backend/bizpkg/config/modelmgr"
    "github.com/coze-dev/coze-studio/backend/bizpkg/llm/modelbuilder"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl"
    "github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
    "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
    "github.com/coze-dev/coze-studio/backend/pkg/logs"
@@ -131,6 +132,41 @@ func UpdateKnowledgeConfig(ctx context.Context, c *app.RequestContext) {
        return
    }

    embedding, err := impl.GetEmbedding(ctx, req.KnowledgeConfig.EmbeddingConfig)
    if err != nil {
        invalidParamRequestResponse(c, fmt.Sprintf("get embedding failed: %v", err))
        return
    }

    if req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims == 0 {
        req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims = int32(embedding.Dimensions())

        embedding, err = impl.GetEmbedding(ctx, req.KnowledgeConfig.EmbeddingConfig)
        if err != nil {
            invalidParamRequestResponse(c, fmt.Sprintf("get embedding failed: %v", err))
            return
        }
    }

    denseEmbeddings, err := embedding.EmbedStrings(ctx, []string{"test"})
    if err != nil {
        invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: %v", err))
        return
    }

    if len(denseEmbeddings) == 0 {
        invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: %v", err))
        return
    }

    logs.CtxDebugf(ctx, "embed test string result: %d, expect %d",
        len(denseEmbeddings[0]), req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims)
    if len(denseEmbeddings[0]) != int(req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims) {
        invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: dims not match, expect %d, got %d",
            req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims, len(denseEmbeddings[0])))
        return
    }

    err = bizConf.Knowledge().SaveKnowledgeConfig(ctx, req.KnowledgeConfig)
    if err != nil {
        internalServerErrorResponse(ctx, c, fmt.Errorf("save knowledge config failed: %w", err))
@@ -1876,7 +1876,7 @@
      return;
    }
    if (data && Number(data.code) === 0) {
      showSuccess(t('success.saved_restart_required'));
      showSuccess(t('success.saved'));
    } else {
      const msg = (data && (data.msg || data.message)) ? (data.msg || data.message) : t('error.save_failed');
      showError(msg);
@@ -2453,31 +2453,31 @@
    const initial = {
      embedding_type: embedding?.type ?? '',
      embedding_max_batch_size: Number(embedding?.max_batch_size ?? 100),
      ark_base_url: conn?.ark?.base_url || baseConn?.base_url || '',
      ark_api_key: conn?.ark?.api_key || baseConn?.api_key || '',
      ark_base_url: baseConn?.base_url || '',
      ark_api_key: baseConn?.api_key || '',
      ark_region: conn?.ark?.region || '',
      ark_model: conn?.ark?.model || baseConn?.model || '',
      ark_embedding_dims: conn?.ark?.embedding_dims ?? embeddingInfo?.dims ?? '',
      ark_model: baseConn?.model || '',
      ark_embedding_dims: embeddingInfo?.dims ?? '0',
      ark_embedding_api_type: conn?.ark?.embedding_api_type || conn?.ark?.api_type || '',
      openai_base_url: conn?.openai?.base_url || '',
      openai_api_key: conn?.openai?.api_key || '',
      openai_model: conn?.openai?.model || '',
      openai_embedding_dims: conn?.openai?.embedding_dims ?? '',
      openai_base_url: baseConn?.base_url || '',
      openai_api_key: baseConn?.api_key || '',
      openai_model: baseConn?.model || '',
      openai_embedding_dims: embeddingInfo?.dims ?? '0',
      openai_embedding_request_dims: conn?.openai?.embedding_request_dims ?? conn?.openai?.request_dims ?? '',
      openai_embedding_by_azure: Boolean(conn?.openai?.embedding_by_azure ?? conn?.openai?.by_azure),
      openai_embedding_api_version: conn?.openai?.embedding_api_version ?? conn?.openai?.api_version ?? '',
      ollama_base_url: conn?.ollama?.base_url || '',
      ollama_model: conn?.ollama?.model || '',
      ollama_embedding_dims: conn?.ollama?.embedding_dims ?? '',
      gemini_base_url: conn?.gemini?.base_url || '',
      gemini_api_key: conn?.gemini?.api_key || '',
      gemini_model: conn?.gemini?.model || '',
      gemini_embedding_dims: conn?.gemini?.embedding_dims ?? '',
      ollama_base_url: baseConn?.base_url || '',
      ollama_model: baseConn?.model || '',
      ollama_embedding_dims: embeddingInfo?.dims ?? '0',
      gemini_base_url: baseConn?.base_url || '',
      gemini_api_key: baseConn?.api_key || '',
      gemini_model: baseConn?.model || '',
      gemini_embedding_dims: embeddingInfo?.dims ?? '0',
      gemini_embedding_backend: conn?.gemini?.embedding_backend ?? conn?.gemini?.backend ?? '',
      gemini_embedding_project: conn?.gemini?.embedding_project ?? conn?.gemini?.project ?? '',
      gemini_embedding_location: conn?.gemini?.embedding_location ?? conn?.gemini?.location ?? '',
      http_address: conn?.http?.address || '',
      http_dims: conn?.http?.dims ?? embeddingInfo?.dims ?? '',
      http_dims: embeddingInfo?.dims ?? '0',
      rerank_type: rerank?.type ?? '',
      vik_ak: rerank?.vikingdb_config?.ak || '',
      vik_sk: rerank?.vikingdb_config?.sk || '',
@@ -2493,6 +2493,84 @@
      builtin_model_id: builtinModelId ?? ''
    };

    if (initial.embedding_type==0) {
      initial.openai_api_key ='';
      initial.openai_base_url ='';
      initial.openai_model ='';
      initial.openai_embedding_dims ='';
      initial.ollama_base_url ='';
      initial.ollama_model ='';
      initial.ollama_embedding_dims ='';
      initial.gemini_base_url ='';
      initial.gemini_api_key ='';
      initial.gemini_model ='';
      initial.gemini_embedding_dims ='';
      initial.http_address ='';
      initial.http_dims ='';
    }else if (initial.embedding_type==1) {
      initial.ark_base_url ='';
      initial.ark_api_key ='';
      initial.ark_region ='';
      initial.ark_model ='';
      initial.ark_embedding_dims ='';
      initial.ollama_base_url ='';
      initial.ollama_model ='';
      initial.ollama_embedding_dims ='';
      initial.gemini_base_url ='';
      initial.gemini_api_key ='';
      initial.gemini_model ='';
      initial.gemini_embedding_dims ='';
      initial.http_address ='';
      initial.http_dims ='';
    } else if (initial.embedding_type==2) {
      initial.openai_api_key ='';
      initial.openai_base_url ='';
      initial.openai_model ='';
      initial.openai_embedding_dims ='';
      initial.ark_base_url ='';
      initial.ark_api_key ='';
      initial.ark_region ='';
      initial.ark_model ='';
      initial.ark_embedding_dims ='';
      initial.gemini_base_url ='';
      initial.gemini_api_key ='';
      initial.gemini_model ='';
      initial.http_address ='';
      initial.http_dims ='';
    } else if (initial.embedding_type==3) {
      initial.openai_api_key ='';
      initial.openai_base_url ='';
      initial.openai_model ='';
      initial.openai_embedding_dims ='';
      initial.ark_base_url ='';
      initial.ark_api_key ='';
      initial.ark_region ='';
      initial.ark_model ='';
      initial.ark_embedding_dims ='';
      initial.ollama_base_url ='';
      initial.ollama_model ='';
      initial.ollama_embedding_dims ='';
      initial.http_address ='';
      initial.http_dims ='';
    } else if (initial.embedding_type==4) {
      initial.openai_api_key ='';
      initial.openai_base_url ='';
      initial.openai_model ='';
      initial.openai_embedding_dims ='';
      initial.ark_base_url ='';
      initial.ark_api_key ='';
      initial.ark_region ='';
      initial.ark_model ='';
      initial.ark_embedding_dims ='';
      initial.ollama_base_url ='';
      initial.ollama_model ='';
      initial.ollama_embedding_dims ='';
      initial.gemini_base_url ='';
      initial.gemini_api_key ='';
      initial.gemini_model ='';
      initial.gemini_embedding_dims ='';
    }

    const inputStyle = "width:100%;padding:8px;border:1px solid #e1e8ed;border-radius:6px;font-size:12px;";
    const labelStyle = "font-size:12px;color:#374151;margin-bottom:6px;";
    const groupGrid = "display:grid;grid-template-columns:repeat(auto-fit,minmax(200px,1fr));gap:12px;";
@@ -2544,9 +2622,9 @@
            <input id="kb_ark_model" type="text" style="${inputStyle}" value="${initial.ark_model}" />
          </div>
          <div>
            <div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
            <div style="${labelStyle}">Dims</div>
            <input id="kb_ark_embedding_dims" type="number" style="${inputStyle}" value="${initial.ark_embedding_dims}" />
          </div>
        </div>
        <div>
          <div style="${labelStyle}">API Type</div>
          <select id="kb_ark_embedding_api_type" style="${inputStyle}">
@@ -2574,9 +2652,9 @@
            <input id="kb_openai_api_key" type="text" style="${inputStyle}" value="${initial.openai_api_key}" />
          </div>
          <div>
            <div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
            <div style="${labelStyle}">Dims</div>
            <input id="kb_openai_embedding_dims" type="number" style="${inputStyle}" value="${initial.openai_embedding_dims}" />
          </div>
        </div>
        <!-- Optional fields are shown last: By Azure, API Version, Request Dims -->
        <div>
          <div style="${labelStyle}">By Azure</div>
@@ -2605,9 +2683,9 @@
            <input id="kb_ollama_model" type="text" style="${inputStyle}" value="${initial.ollama_model}" />
          </div>
          <div>
            <div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
            <div style="${labelStyle}">Dims</div>
            <input id="kb_ollama_embedding_dims" type="number" style="${inputStyle}" value="${initial.ollama_embedding_dims}" />
          </div>
        </div>
      </div>
    </div>
@@ -2626,14 +2704,15 @@
            <div style="${labelStyle}">Model<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
            <input id="kb_gemini_model" type="text" style="${inputStyle}" value="${initial.gemini_model}" />
          </div>
          <div>
            <div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
            <input id="kb_gemini_embedding_dims" type="number" style="${inputStyle}" value="${initial.gemini_embedding_dims}" />
          </div>

          <div>
            <div style="${labelStyle}">Backend<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
            <input id="kb_gemini_embedding_backend" type="number" style="${inputStyle}" value="${initial.gemini_embedding_backend}" />
          </div>
          <div>
            <div style="${labelStyle}">Dims</div>
            <input id="kb_gemini_embedding_dims" type="number" style="${inputStyle}" value="${initial.gemini_embedding_dims}" />
          </div>
          <div>
            <div style="${labelStyle}">Project</div>
            <input id="kb_gemini_embedding_project" type="text" style="${inputStyle}" value="${initial.gemini_embedding_project}" />
@@ -2653,7 +2732,7 @@
            <input id="kb_http_address" type="text" style="${inputStyle}" value="${initial.http_address}" />
          </div>
          <div>
            <div style="${labelStyle}">Dims<span style="display:inline-block;width:6px;height:6px;background:#e74c3c;border-radius:50%;margin-left:6px;"></span></div>
            <div style="${labelStyle}">Dims</div>
            <input id="kb_http_dims" type="number" style="${inputStyle}" value="${initial.http_dims}" />
          </div>
        </div>
@@ -3041,7 +3120,7 @@
      if (!get('kb_ark_base_url')) missing.push('Base URL');
      if (!get('kb_ark_api_key')) missing.push('API Key');
      if (!get('kb_ark_model')) missing.push('Model');
      if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');

      if (missing.length > 0) {
        showError(t('error.missing_required', { list: missing.join(', ') }));
        return;
@@ -3053,7 +3132,6 @@
      if (!get('kb_openai_base_url')) missing.push('Base URL');
      if (!get('kb_openai_model')) missing.push('Model');
      if (!get('kb_openai_api_key')) missing.push('API Key');
      if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
      if (missing.length > 0) {
        showError(t('error.missing_required', { list: missing.join(', ') }));
        return;
@@ -3064,7 +3142,6 @@
      const missing = [];
      if (!get('kb_ollama_base_url')) missing.push('Base URL');
      if (!get('kb_ollama_model')) missing.push('Model');
      if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
      if (missing.length > 0) {
        showError(t('error.missing_required', { list: missing.join(', ') }));
        return;
@@ -3077,7 +3154,6 @@
      if (!get('kb_gemini_base_url')) missing.push('Base URL');
      if (!get('kb_gemini_api_key')) missing.push('API Key');
      if (!get('kb_gemini_model')) missing.push('Model');
      if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
      if (!backendStr) missing.push('Backend');
      if (missing.length > 0) {
        showError(t('error.missing_required', { list: missing.join(', ') }));
@@ -3088,7 +3164,6 @@
      const dimsVal = Number(document.getElementById('kb_http_dims')?.value);
      const missing = [];
      if (!get('kb_http_address')) missing.push('Address');
      if (!dimsVal || isNaN(dimsVal) || dimsVal <= 0) missing.push('Dims');
      if (missing.length > 0) {
        showError(t('error.missing_required', { list: missing.join(', ') }));
        return;
@@ -3109,7 +3184,7 @@
      }
    } catch (_) {}

    showOverlay(t('overlay.saving'));
    // showOverlay(t('overlay.saving'));
    fetch('/api/admin/config/knowledge/save', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
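In short, the admin form drops the required-field marker on every provider's Dims input, seeds the field from embeddingInfo.dims (defaulting to '0'), and removes the per-provider dims validation, so the config can be submitted without dims; the server-side probe described at the top of this commit fills in and verifies the value.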
@@ -21,11 +21,7 @@ import (
    "os"
    "time"

    "github.com/cloudwego/eino-ext/components/embedding/gemini"
    "github.com/cloudwego/eino-ext/components/embedding/ollama"
    "github.com/cloudwego/eino-ext/components/embedding/openai"
    "github.com/milvus-io/milvus/client/v2/milvusclient"
    "google.golang.org/genai"

    "github.com/coze-dev/coze-studio/backend/api/model/admin/config"
    "github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
@@ -34,9 +30,7 @@ import (
    searchstoreOceanbase "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/oceanbase"
    "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/vikingdb"
    "github.com/coze-dev/coze-studio/backend/infra/embedding"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl"
    "github.com/coze-dev/coze-studio/backend/infra/es/impl/es"
    "github.com/coze-dev/coze-studio/backend/infra/oceanbase"
    "github.com/coze-dev/coze-studio/backend/pkg/envkey"
@@ -82,7 +76,7 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
        return nil, fmt.Errorf("init milvus client failed, err=%w", err)
    }

    emb, err := getEmbedding(ctx, conf.EmbeddingConfig)
    emb, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
    if err != nil {
        return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
    }
@@ -134,7 +128,7 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
            BuiltinEmbedding: nil,
        }
    } else {
        builtinEmbedding, err := getEmbedding(ctx, conf.EmbeddingConfig)
        builtinEmbedding, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
        if err != nil {
            return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
        }
@@ -159,7 +153,7 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
    return mgr, nil

case "oceanbase":
    emb, err := getEmbedding(ctx, conf.EmbeddingConfig)
    emb, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
    if err != nil {
        return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err)
    }
@@ -213,103 +207,3 @@ func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchst
        return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
    }
}

func getEmbedding(ctx context.Context, cfg *config.EmbeddingConfig) (embedding.Embedder, error) {
    var (
        emb           embedding.Embedder
        err           error
        connInfo      = cfg.Connection.BaseConnInfo
        embeddingInfo = cfg.Connection.EmbeddingInfo
    )

    switch cfg.Type {
    case config.EmbeddingType_OpenAI:
        openaiConnCfg := cfg.Connection.Openai
        openAICfg := &openai.EmbeddingConfig{
            APIKey:     connInfo.APIKey,
            BaseURL:    connInfo.BaseURL,
            Model:      connInfo.Model,
            ByAzure:    openaiConnCfg.ByAzure,
            APIVersion: openaiConnCfg.APIVersion,
        }

        if openaiConnCfg.RequestDims > 0 {
            // some openai model not support request dims
            openAICfg.Dimensions = ptr.Of(int(openaiConnCfg.RequestDims))
        }

        emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
        }
    case config.EmbeddingType_Ark:
        arkCfg := cfg.Connection.Ark

        apiType := ark.APITypeText
        if ark.APIType(arkCfg.APIType) == ark.APITypeMultiModal {
            apiType = ark.APITypeMultiModal
        }

        emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
            APIKey:  connInfo.APIKey,
            Model:   connInfo.Model,
            BaseURL: connInfo.BaseURL,
            APIType: &apiType,
        }, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
        }

    case config.EmbeddingType_Ollama:
        emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
            BaseURL: connInfo.BaseURL,
            Model:   connInfo.Model,
        }, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
        }
    case config.EmbeddingType_Gemini:
        geminiCfg := cfg.Connection.Gemini

        if len(connInfo.Model) == 0 {
            return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
        }
        if len(connInfo.APIKey) == 0 {
            return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
        }

        geminiCli, err1 := genai.NewClient(ctx, &genai.ClientConfig{
            APIKey:   connInfo.APIKey,
            Backend:  genai.Backend(geminiCfg.Backend),
            Project:  geminiCfg.Project,
            Location: geminiCfg.Location,
            HTTPOptions: genai.HTTPOptions{
                BaseURL: connInfo.BaseURL,
            },
        })
        if err1 != nil {
            return nil, fmt.Errorf("init gemini client failed, err=%w", err)
        }

        emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
            Client:               geminiCli,
            Model:                connInfo.Model,
            OutputDimensionality: ptr.Of(int32(embeddingInfo.Dims)),
        }, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
        }
    case config.EmbeddingType_HTTP:
        httpCfg := cfg.Connection.HTTP

        emb, err = http.NewEmbedding(httpCfg.Address, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init http embedding failed, err=%w", err)
        }

    default:
        return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
    }

    return emb, nil
}
@@ -57,7 +57,7 @@ type embWrap struct {
    embedding.Embedder
}

func (d embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
func (d *embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
    resp := make([][]float64, 0, len(texts))
    for _, part := range slices.Chunks(texts, d.batchSize) {
        partResult, err := d.Embedder.EmbedStrings(ctx, part, opts...)
@@ -87,19 +87,29 @@ func (d embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embed
    return resp, nil
}

func (d embWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
func (d *embWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
    return nil, nil, fmt.Errorf("[arkEmbedder] EmbedStringsHybrid not support")
}

func (d embWrap) Dimensions() int64 {
func (d *embWrap) Dimensions() int64 {
    if d.dims == 0 {
        embeddings, err := d.Embedder.EmbedStrings(context.Background(), []string{"test"})
        if err != nil {
            return 0
        }
        if len(embeddings) == 0 {
            return 0
        }
        d.dims = int64(len(embeddings[0]))
    }
    return d.dims
}

func (d embWrap) SupportStatus() contract.SupportStatus {
func (d *embWrap) SupportStatus() contract.SupportStatus {
    return contract.SupportDense
}

func (d embWrap) slicedNormL2(vectors [][]float64) ([][]float64, error) {
func (d *embWrap) slicedNormL2(vectors [][]float64) ([][]float64, error) {
    if len(vectors) == 0 {
        return vectors, nil
    }
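The receiver change in this hunk is what makes the lazy probe worthwhile: with a value receiver, the assignment to d.dims inside Dimensions() writes to a copy and is discarded when the method returns, so every call would re-embed the test string. A minimal sketch (hypothetical types, not the repository's embWrap) of the difference:

    package main

    import "fmt"

    type valCache struct{ dims int64 }

    func (v valCache) Dimensions() int64 { // value receiver: mutates a copy
        if v.dims == 0 {
            v.dims = 1536 // pretend this came from an EmbedStrings probe
        }
        return v.dims
    }

    type ptrCache struct{ dims int64 }

    func (p *ptrCache) Dimensions() int64 { // pointer receiver: the probed value sticks
        if p.dims == 0 {
            p.dims = 1536
        }
        return p.dims
    }

    func main() {
        v := valCache{}
        v.Dimensions()
        fmt.Println(v.dims) // 0: the probed value was lost

        p := &ptrCache{}
        p.Dimensions()
        fmt.Println(p.dims) // 1536: cached for subsequent calls
    }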
@@ -195,6 +195,16 @@ func (e *embedder) do(req *http.Request) (*embedResp, error) {
}

func (e *embedder) Dimensions() int64 {
    if e.dim == 0 {
        embeddings, err := e.EmbedStrings(context.Background(), []string{"test"})
        if err != nil {
            return 0
        }
        if len(embeddings) == 0 {
            return 0
        }
        e.dim = int64(len(embeddings[0]))
    }
    return e.dim
}

backend/infra/embedding/impl/impl.go (new file, 134 lines)
@@ -0,0 +1,134 @@
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package impl

import (
    "context"
    "fmt"

    "github.com/cloudwego/eino-ext/components/embedding/gemini"
    "github.com/cloudwego/eino-ext/components/embedding/ollama"
    "github.com/cloudwego/eino-ext/components/embedding/openai"
    "google.golang.org/genai"

    "github.com/coze-dev/coze-studio/backend/api/model/admin/config"
    "github.com/coze-dev/coze-studio/backend/infra/embedding"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/ark"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/http"
    "github.com/coze-dev/coze-studio/backend/infra/embedding/impl/wrap"
    "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)

func GetEmbedding(ctx context.Context, cfg *config.EmbeddingConfig) (embedding.Embedder, error) {
    var (
        emb           embedding.Embedder
        err           error
        connInfo      = cfg.Connection.BaseConnInfo
        embeddingInfo = cfg.Connection.EmbeddingInfo
    )

    switch cfg.Type {
    case config.EmbeddingType_OpenAI:
        openaiConnCfg := cfg.Connection.Openai
        openAICfg := &openai.EmbeddingConfig{
            APIKey:     connInfo.APIKey,
            BaseURL:    connInfo.BaseURL,
            Model:      connInfo.Model,
            ByAzure:    openaiConnCfg.ByAzure,
            APIVersion: openaiConnCfg.APIVersion,
        }

        if openaiConnCfg.RequestDims > 0 {
            // some openai model not support request dims
            openAICfg.Dimensions = ptr.Of(int(openaiConnCfg.RequestDims))
        }

        emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
        }
    case config.EmbeddingType_Ark:
        arkCfg := cfg.Connection.Ark

        apiType := ark.APITypeText
        if ark.APIType(arkCfg.APIType) == ark.APITypeMultiModal {
            apiType = ark.APITypeMultiModal
        }

        emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
            APIKey:  connInfo.APIKey,
            Model:   connInfo.Model,
            BaseURL: connInfo.BaseURL,
            APIType: &apiType,
        }, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
        }

    case config.EmbeddingType_Ollama:
        emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
            BaseURL: connInfo.BaseURL,
            Model:   connInfo.Model,
        }, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
        }
    case config.EmbeddingType_Gemini:
        geminiCfg := cfg.Connection.Gemini

        if len(connInfo.Model) == 0 {
            return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
        }
        if len(connInfo.APIKey) == 0 {
            return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
        }

        geminiCli, err1 := genai.NewClient(ctx, &genai.ClientConfig{
            APIKey:   connInfo.APIKey,
            Backend:  genai.Backend(geminiCfg.Backend),
            Project:  geminiCfg.Project,
            Location: geminiCfg.Location,
            HTTPOptions: genai.HTTPOptions{
                BaseURL: connInfo.BaseURL,
            },
        })
        if err1 != nil {
            return nil, fmt.Errorf("init gemini client failed, err=%w", err)
        }

        emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
            Client:               geminiCli,
            Model:                connInfo.Model,
            OutputDimensionality: ptr.Of(int32(embeddingInfo.Dims)),
        }, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
        }
    case config.EmbeddingType_HTTP:
        httpCfg := cfg.Connection.HTTP

        emb, err = http.NewEmbedding(httpCfg.Address, int64(embeddingInfo.Dims), int(cfg.MaxBatchSize))
        if err != nil {
            return nil, fmt.Errorf("init http embedding failed, err=%w", err)
        }

    default:
        return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
    }

    return emb, nil
}
@@ -32,7 +32,7 @@ type denseOnlyWrap struct {
    embedding.Embedder
}

func (d denseOnlyWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
func (d *denseOnlyWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
    resp := make([][]float64, 0, len(texts))
    for _, part := range slices.Chunks(texts, d.batchSize) {
        partResult, err := d.Embedder.EmbedStrings(ctx, part, opts...)
@@ -44,14 +44,24 @@ func (d denseOnlyWrap) EmbedStrings(ctx context.Context, texts []string, opts ..
    return resp, nil
}

func (d denseOnlyWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
func (d *denseOnlyWrap) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) {
    return nil, nil, fmt.Errorf("[denseOnlyWrap] EmbedStringsHybrid not support")
}

func (d denseOnlyWrap) Dimensions() int64 {
func (d *denseOnlyWrap) Dimensions() int64 {
    if d.dims == 0 {
        embeddings, err := d.Embedder.EmbedStrings(context.Background(), []string{"test"})
        if err != nil {
            return 0
        }
        if len(embeddings) == 0 {
            return 0
        }
        d.dims = int64(len(embeddings[0]))
    }
    return d.dims
}

func (d denseOnlyWrap) SupportStatus() contract.SupportStatus {
func (d *denseOnlyWrap) SupportStatus() contract.SupportStatus {
    return contract.SupportDense
}