refactor(api): tighten types in trivial lint and config fixes (#34773)

Co-authored-by: tmimmanuel <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>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
tmimmanuel
2026-04-09 01:14:44 +02:00
committed by GitHub
parent 4c70bfa8b8
commit 1a4eb47e1d
10 changed files with 28 additions and 23 deletions

View File

@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector):
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
self.analyticdb_vector.add_texts(documents, embeddings)
return []
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)

View File

@ -123,7 +123,7 @@ class AnalyticdbVectorOpenAPI:
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
def create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException

View File

@ -1,5 +1,6 @@
import json
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
@ -74,7 +75,7 @@ class AnalyticdbVectorBySql:
)
@contextmanager
def _get_cursor(self):
def _get_cursor(self) -> Iterator[Any]:
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
@ -130,7 +131,7 @@ class AnalyticdbVectorBySql:
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
def create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):

View File

@ -2,7 +2,7 @@ import json
from typing import Any, TypedDict
import chromadb
from chromadb import QueryResult, Settings
from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
from pydantic import BaseModel
from configs import dify_config
@ -106,14 +106,15 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
document_ids_filter = kwargs.get("document_ids_filter")
results: QueryResult
if document_ids_filter:
results: QueryResult = collection.query(
results = collection.query(
query_embeddings=query_vector,
n_results=kwargs.get("top_k", 4),
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
)
else:
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Check if results contain data
@ -165,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
config=ChromaConfig(
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
),

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
import qdrant_client
from flask import current_app
@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
@ -180,7 +179,7 @@ class QdrantVector(BaseVector):
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=points)
self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
added_ids.extend(batch_ids)
return added_ids
@ -472,7 +471,7 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client._load()
self._client._load() # pyright: ignore[reportPrivateUsage]
@classmethod
def _document_from_scored_point(

View File

@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any
Base: Any = declarative_base()
class RelytConfig(BaseModel):

View File

@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__
from unstructured.file_utils.filetype import FileType, detect_filetype
from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage]
FileType,
detect_filetype,
)
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
# check the file extension
try:
import magic # noqa: F401
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
is_doc = detect_filetype(self._file_path) == FileType.DOC
except ImportError:

View File

@ -71,7 +71,7 @@ def test_vector_methods_delegate_to_underlying_implementation():
assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value
vector.delete()
runner._create_collection_if_not_exists.assert_called_once_with(2)
runner.create_collection_if_not_exists.assert_called_once_with(2)
runner.add_texts.assert_any_call(texts, [[0.1, 0.2]])
runner.delete_by_ids.assert_called_once_with(["d1"])
runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1")

View File

@ -249,7 +249,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
vector._client = MagicMock()
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404)
vector._create_collection_if_not_exists(embedding_dimension=1024)
vector.create_collection_if_not_exists(embedding_dimension=1024)
vector._client.create_collection.assert_called_once()
openapi_module.redis_client.set.assert_called_once()
@ -268,7 +268,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
vector.config = _config()
vector._client = MagicMock()
vector._create_collection_if_not_exists(embedding_dimension=1024)
vector.create_collection_if_not_exists(embedding_dimension=1024)
vector._client.describe_collection.assert_not_called()
vector._client.create_collection.assert_not_called()
@ -290,7 +290,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500)
with pytest.raises(ValueError, match="failed to create collection collection_1"):
vector._create_collection_if_not_exists(embedding_dimension=512)
vector.create_collection_if_not_exists(embedding_dimension=512)
def test_openapi_add_delete_and_search_methods(monkeypatch):

View File

@ -374,7 +374,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
vector._get_cursor = _cursor_context
vector._create_collection_if_not_exists(embedding_dimension=3)
vector.create_collection_if_not_exists(embedding_dimension=3)
assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list)
assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list)
@ -404,7 +404,7 @@ def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypat
vector._get_cursor = _cursor_context
with pytest.raises(RuntimeError, match="permission denied"):
vector._create_collection_if_not_exists(embedding_dimension=3)
vector.create_collection_if_not_exists(embedding_dimension=3)
def test_delete_methods_raise_when_error_is_not_missing_table():