mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
refactor(api): type Chroma and AnalyticDB config params dicts with TypedDicts (#34678)
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@ -13,6 +13,13 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class AnalyticdbClientParamsDict(TypedDict):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
region_id: str
|
||||
read_timeout: int
|
||||
|
||||
|
||||
class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
|
||||
result: AnalyticdbClientParamsDict = {
|
||||
"access_key_id": self.access_key_id,
|
||||
"access_key_secret": self.access_key_secret,
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class AnalyticdbVectorOpenAPI:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class ChromaParamsDict(TypedDict):
|
||||
host: str
|
||||
port: int
|
||||
ssl: bool
|
||||
tenant: str
|
||||
database: str
|
||||
settings: Settings
|
||||
|
||||
|
||||
class ChromaConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
@ -23,14 +32,13 @@ class ChromaConfig(BaseModel):
|
||||
auth_provider: str | None = None
|
||||
auth_credentials: str | None = None
|
||||
|
||||
def to_chroma_params(self):
|
||||
def to_chroma_params(self) -> ChromaParamsDict:
|
||||
settings = Settings(
|
||||
# auth
|
||||
chroma_client_auth_provider=self.auth_provider,
|
||||
chroma_client_auth_credentials=self.auth_credentials,
|
||||
)
|
||||
|
||||
return {
|
||||
result: ChromaParamsDict = {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"ssl": False,
|
||||
@ -38,6 +46,7 @@ class ChromaConfig(BaseModel):
|
||||
"database": self.database,
|
||||
"settings": settings,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class ChromaVector(BaseVector):
|
||||
|
||||
Reference in New Issue
Block a user