mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
refactor(api): tighten types in trivial lint and config fixes (#34773)
Co-authored-by: tmimmanuel <email redacted>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self.analyticdb_vector._create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
self.analyticdb_vector.add_texts(documents, embeddings)
|
||||
return []
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self.analyticdb_vector.text_exists(id)
|
||||
|
||||
@ -123,7 +123,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
else:
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
@ -74,7 +75,7 @@ class AnalyticdbVectorBySql:
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
def _get_cursor(self) -> Iterator[Any]:
|
||||
assert self.pool is not None, "Connection pool is not initialized"
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
@ -130,7 +131,7 @@ class AnalyticdbVectorBySql:
|
||||
)
|
||||
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
@ -106,14 +106,15 @@ class ChromaVector(BaseVector):
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
results: QueryResult
|
||||
if document_ids_filter:
|
||||
results: QueryResult = collection.query(
|
||||
results = collection.query(
|
||||
query_embeddings=query_vector,
|
||||
n_results=kwargs.get("top_k", 4),
|
||||
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
|
||||
)
|
||||
else:
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
||||
results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
# Check if results contain data
|
||||
@ -165,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
||||
config=ChromaConfig(
|
||||
host=dify_config.CHROMA_HOST or "",
|
||||
port=dify_config.CHROMA_PORT,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
|
||||
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
|
||||
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
|
||||
),
|
||||
|
||||
@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import qdrant_client
|
||||
from flask import current_app
|
||||
@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions import common_types
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@ -180,7 +179,7 @@ class QdrantVector(BaseVector):
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
@ -472,7 +471,7 @@ class QdrantVector(BaseVector):
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client._load()
|
||||
self._client._load() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
|
||||
@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Base = declarative_base() # type: Any
|
||||
Base: Any = declarative_base()
|
||||
|
||||
|
||||
class RelytConfig(BaseModel):
|
||||
|
||||
@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage]
|
||||
FileType,
|
||||
detect_filetype,
|
||||
)
|
||||
|
||||
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
|
||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||
except ImportError:
|
||||
|
||||
@ -71,7 +71,7 @@ def test_vector_methods_delegate_to_underlying_implementation():
|
||||
assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value
|
||||
vector.delete()
|
||||
|
||||
runner._create_collection_if_not_exists.assert_called_once_with(2)
|
||||
runner.create_collection_if_not_exists.assert_called_once_with(2)
|
||||
runner.add_texts.assert_any_call(texts, [[0.1, 0.2]])
|
||||
runner.delete_by_ids.assert_called_once_with(["d1"])
|
||||
runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1")
|
||||
|
||||
@ -249,7 +249,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
|
||||
vector._client = MagicMock()
|
||||
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404)
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=1024)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=1024)
|
||||
|
||||
vector._client.create_collection.assert_called_once()
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
@ -268,7 +268,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=1024)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=1024)
|
||||
|
||||
vector._client.describe_collection.assert_not_called()
|
||||
vector._client.create_collection.assert_not_called()
|
||||
@ -290,7 +290,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
|
||||
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500)
|
||||
|
||||
with pytest.raises(ValueError, match="failed to create collection collection_1"):
|
||||
vector._create_collection_if_not_exists(embedding_dimension=512)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=512)
|
||||
|
||||
|
||||
def test_openapi_add_delete_and_search_methods(monkeypatch):
|
||||
|
||||
@ -374,7 +374,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
|
||||
|
||||
vector._get_cursor = _cursor_context
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=3)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=3)
|
||||
|
||||
assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list)
|
||||
assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list)
|
||||
@ -404,7 +404,7 @@ def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypat
|
||||
vector._get_cursor = _cursor_context
|
||||
|
||||
with pytest.raises(RuntimeError, match="permission denied"):
|
||||
vector._create_collection_if_not_exists(embedding_dimension=3)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=3)
|
||||
|
||||
|
||||
def test_delete_methods_raise_when_error_is_not_missing_table():
|
||||
|
||||
Reference in New Issue
Block a user