feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@ -2,7 +2,7 @@ import json
from typing import Any
from pydantic import BaseModel
from volcengine.viking_db import (
from volcengine.viking_db import ( # type: ignore
Data,
DistanceType,
Field,
@ -121,11 +121,12 @@ class VikingDBVector(BaseVector):
for i, page_content in enumerate(page_contents):
metadata = {}
if metadatas is not None:
for key, val in metadatas[i].items():
for key, val in (metadatas[i] or {}).items():
metadata[key] = val
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
@ -178,7 +179,7 @@ class VikingDBVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)
def _get_search_res(self, results, score_threshold):
def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:
return []
@ -191,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: