mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-04 09:17:48 +08:00
Refact: switch from google-generativeai to google-genai (#13140)
### What problem does this PR solve? Refact: switch from oogle-generativeai to google-genai #13132 Refact: commnet out unused pywencai. ### Type of change - [x] Refactoring
This commit is contained in:
@ -20,7 +20,6 @@ from abc import ABC
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import numpy as np
|
||||
import requests
|
||||
from ollama import Client
|
||||
@ -543,31 +542,87 @@ class BedrockEmbed(Base):
|
||||
class GeminiEmbed(Base):
|
||||
_FACTORY_NAME = "Gemini"
|
||||
|
||||
def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
|
||||
def __init__(self, key, model_name="gemini-embedding-001", **kwargs):
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
self.key = key
|
||||
self.model_name = "models/" + model_name
|
||||
self.model_name = model_name[7:] if model_name.startswith("models/") else model_name
|
||||
self.client = genai.Client(api_key=self.key)
|
||||
self.types = types
|
||||
|
||||
@staticmethod
|
||||
def _parse_embedding_vector(embedding):
|
||||
if isinstance(embedding, dict):
|
||||
values = embedding.get("values")
|
||||
if values is None:
|
||||
values = embedding.get("embedding")
|
||||
if values is not None:
|
||||
return values
|
||||
|
||||
values = getattr(embedding, "values", None)
|
||||
if values is None:
|
||||
values = getattr(embedding, "embedding", None)
|
||||
if values is not None:
|
||||
return values
|
||||
|
||||
raise TypeError(f"Unsupported embedding payload: {type(embedding)}")
|
||||
|
||||
@classmethod
|
||||
def _parse_embedding_response(cls, response):
|
||||
if response is None:
|
||||
raise ValueError("Embedding response is empty")
|
||||
|
||||
embeddings = getattr(response, "embeddings", None)
|
||||
if embeddings is None and isinstance(response, dict):
|
||||
embeddings = response.get("embeddings")
|
||||
|
||||
if embeddings is None:
|
||||
return [cls._parse_embedding_vector(response)]
|
||||
|
||||
return [cls._parse_embedding_vector(item) for item in embeddings]
|
||||
|
||||
def _build_embedding_config(self):
|
||||
task_type = "RETRIEVAL_DOCUMENT"
|
||||
if hasattr(self.types, "TaskType"):
|
||||
task_type = getattr(self.types.TaskType, "RETRIEVAL_DOCUMENT", task_type)
|
||||
try:
|
||||
return self.types.EmbedContentConfig(task_type=task_type, title="Embedding of single string")
|
||||
except TypeError:
|
||||
# Compatible with SDK versions that do not accept title in embed config.
|
||||
return self.types.EmbedContentConfig(task_type=task_type)
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
token_count = sum(num_tokens_from_string(text) for text in texts)
|
||||
genai.configure(api_key=self.key)
|
||||
config = self._build_embedding_config()
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
|
||||
result = None
|
||||
try:
|
||||
ress.extend(result["embedding"])
|
||||
result = self.client.models.embed_content(
|
||||
model=self.model_name,
|
||||
contents=texts[i : i + batch_size],
|
||||
config=config,
|
||||
)
|
||||
ress.extend(self._parse_embedding_response(result))
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
raise Exception(f"Error: {result}")
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
genai.configure(api_key=self.key)
|
||||
result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
|
||||
config = self._build_embedding_config()
|
||||
result = None
|
||||
token_count = num_tokens_from_string(text)
|
||||
try:
|
||||
return np.array(result["embedding"]), token_count
|
||||
result = self.client.models.embed_content(
|
||||
model=self.model_name,
|
||||
contents=[truncate(text, 2048)],
|
||||
config=config,
|
||||
)
|
||||
return np.array(self._parse_embedding_response(result)[0]), token_count
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
raise Exception(f"Error: {result}")
|
||||
|
||||
Reference in New Issue
Block a user