Merge branch 'main' into feat/support-extractor-tools

This commit is contained in:
jyong
2024-11-05 14:44:42 +08:00
127 changed files with 2361 additions and 441 deletions

View File

@ -76,6 +76,7 @@ class BaseAppGenerator:
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
user_input_value = inputs.get(var.variable)
if not user_input_value:
if var.required:
raise ValueError(f"{var.variable} is required in input form")
@ -88,6 +89,7 @@ class BaseAppGenerator:
VariableEntityType.PARAGRAPH,
} and not isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
@ -97,25 +99,30 @@ class BaseAppGenerator:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
elif var.type == VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
elif var.type == VariableEntityType.FILE_LIST:
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
match var.type:
case VariableEntityType.SELECT:
if user_input_value not in var.options:
raise ValueError(f"{var.variable} in input form must be one of the following: {var.options}")
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
case VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
case VariableEntityType.FILE_LIST:
# if number of files exceeds the limit, raise ValueError
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} files")
return user_input_value

View File

@ -17,6 +17,7 @@ from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
@ -597,26 +598,9 @@ class IndexingRunner:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
document_text = CleanProcessor.clean(text, rules)
if "pre_processing_rules" in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
# Remove extra spaces
pattern = r"\n{3,}"
text = re.sub(pattern, "\n\n", text)
pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}"
text = re.sub(pattern, " ", text)
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
# Remove email
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL
pattern = r"https?://[^\s]+"
text = re.sub(pattern, "", text)
return text
return document_text
@staticmethod
def format_split_text(text):

Binary file not shown.

After

Width:  |  Height:  |  Size: 277 KiB

View File

@ -0,0 +1,15 @@
<svg width="68" height="24" viewBox="0 0 68 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Gemini">
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M50.6875 4.37014C48.3498 4.59292 46.5349 6.41319 46.3337 8.72764C46.1446 6.44662 44.2677 4.56074 41.9805 4.3737C44.2762 4.1997 46.152 2.28299 46.3373 0C46.4882 2.28911 48.405 4.20047 50.6875 4.37014ZM15.4567 9.41141L13.9579 10.9076C9.92941 6.64892 2.69298 9.97287 3.17317 15.8112C3.22394 23.108 14.5012 24.4317 15.3628 16.8809H9.52096L9.50061 14.9149H17.3595C18.8163 23.1364 8.44367 27.0292 3.19453 21.238C0.847044 18.7556 0.363651 14.7682 1.83717 11.7212C4.1129 6.62089 11.6505 5.29845 15.4567 9.41141ZM45.5915 23.5989H47.6945C47.6944 22.9155 47.6945 22.2307 47.6946 21.5452V21.5325C47.6948 19.8907 47.695 18.2453 47.6924 16.6072C47.6914 15.9407 47.6161 15.2823 47.4024 14.647C46.4188 11.2828 41.4255 11.4067 39.8332 14.214C38.5637 11.4171 34.4009 11.5236 32.8538 14.0084L32.8082 13.9976V12.4806L32.4233 12.4804H32.4224C31.8687 12.4801 31.3324 12.4798 30.7949 12.4811V23.5848L32.8977 23.5672C32.8981 22.9411 32.8979 22.3122 32.8977 21.6822V21.6812V21.6802V21.6791V21.6781V21.6771V21.676V21.676V21.6759V21.6758V21.6757V21.6757V21.6756C32.8973 20.204 32.8969 18.7261 32.904 17.2614C32.8889 15.3646 34.5674 13.5687 36.5358 14.124C37.7794 14.3298 38.1851 15.6148 38.1761 16.7257C38.1821 17.7019 38.18 18.6824 38.178 19.6633V19.6638C38.1752 20.9756 38.1724 22.2881 38.1891 23.5919L40.2846 23.5731C40.2929 22.7511 40.2881 21.9245 40.2832 21.0966C40.2753 19.7402 40.2674 18.3805 40.317 17.0328C40.4418 15.2122 42.0141 13.6186 43.9064 14.1168C45.2685 14.3231 45.6136 15.7748 45.5882 16.9545C45.5938 18.4959 45.5929 20.0492 45.5921 21.5968V21.5991V21.6014V21.6037V21.606V21.6083V21.6106C45.5917 22.2749 45.5913 22.9382 45.5915 23.5989ZM20.6167 18.4408C20.5625 21.9486 25.2121 23.6996 27.2993 20.0558L29.1566 20.9592C27.8157 23.7067 24.2337 24.7424 21.5381 23.4213C18.0052 21.7253 17.41 16.5007 20.0334 13.7517C21.4609 12.1752 23.7291 11.7901 25.7206 12.3653C28.3408 13.1257 29.4974 15.8937 29.326 18.4399C27.5547 18.4415 25.7971 18.4412 24.0364 18.4409C22.8993 18.4407 21.7609 18.4405 20.6167 18.4408ZM27.1041 16.6957C26.7048 13.1033 21.2867 13.2256 20.7494 16.6957H27.1041ZM53.543 23.5999H55.6206L55.6206 22.4361C55.6205 20.7877 55.6205 19.1443 55.6207 17.4939C55.6208 16.8853 55.7234 16.297 56.0063 15.7531C56.6115 14.3862 58.1745 13.7002 59.5927 14.1774C60.7512 14.4455 61.2852 15.6069 61.2762 16.7154C61.2774 18.3497 61.2771 19.9826 61.2769 21.6162V21.6166V21.617V21.6174V21.6179L61.2766 23.6007H63.3698C63.3913 22.0924 63.3869 20.584 63.3826 19.0755V19.0754V19.0753V19.0753V19.0752C63.3799 18.1682 63.3773 17.2612 63.3803 16.3541C63.3796 15.8622 63.3103 15.3765 63.1698 14.9052C62.3248 11.5142 57.3558 11.2385 55.5828 14.0038L55.5336 13.9905V12.4917H53.539C53.4898 12.7313 53.4934 23.4113 53.543 23.5999ZM49.6211 12.4944H51.7065V23.5994H49.6211V12.4944ZM65.1035 23.5991H67.1831C67.2367 23.2198 67.2133 12.6566 67.1634 12.4983H65.1035V23.5991ZM52.1504 8.67829C52.1709 10.4847 49.2418 10.7058 49.1816 8.65714C49.2189 6.5948 52.2437 6.81331 52.1504 8.67829ZM66.1387 10.1324C64.2712 10.1609 64.1316 7.19881 66.1559 7.17114C68.1709 7.19817 68.0215 10.2087 66.1387 10.1324Z" fill="url(#paint0_linear_14286_118464)"/>
</g>
<defs>
<linearGradient id="paint0_linear_14286_118464" x1="-2" y1="0.999998" x2="67.9999" y2="27.5002" gradientUnits="userSpaceOnUse">
<stop stop-color="#7798E0"/>
<stop offset="0.210002" stop-color="#086FFF"/>
<stop offset="0.345945" stop-color="#086FFF"/>
<stop offset="0.591777" stop-color="#479AFF"/>
<stop offset="0.895892" stop-color="#B7C4FA"/>
<stop offset="1" stop-color="#B5C5F9"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

View File

@ -0,0 +1,11 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="6" fill="url(#paint0_linear_7301_16076)"/>
<path d="M20 12.0116C15.7043 12.42 12.3692 15.757 11.9995 20C11.652 15.8183 8.20301 12.361 4 12.0181C8.21855 11.6991 11.6656 8.1853 12.006 4C12.2833 8.19653 15.8057 11.7005 20 12.0116Z" fill="white" fill-opacity="0.88"/>
<defs>
<linearGradient id="paint0_linear_7301_16076" x1="-9" y1="29.5" x2="19.4387" y2="1.43791" gradientUnits="userSpaceOnUse">
<stop offset="0.192878" stop-color="#1C7DFF"/>
<stop offset="0.520213" stop-color="#1C69FF"/>
<stop offset="1" stop-color="#F0DCD6"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 689 B

View File

@ -0,0 +1,10 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class GPUStackProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -0,0 +1,120 @@
provider: gpustack
label:
en_US: GPUStack
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: endpoint_url
label:
zh_Hans: 服务器地址
en_US: Server URL
type: text-input
required: true
placeholder:
zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 输入您的 API Key
en_US: Enter your API Key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择补全类型
en_US: Select completion type
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: "8192"
placeholder:
zh_Hans: 输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens_to_sample
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: "8192"
type: text-input
- variable: function_calling_type
show_on:
- variable: __model_type
value: llm
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: function_call
label:
en_US: Function Call
zh_Hans: Function Call
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- variable: vision_support
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: Vision 支持
en_US: Vision Support
type: select
required: false
default: no_support
options:
- value: support
label:
en_US: Support
zh_Hans: 支持
- value: no_support
label:
en_US: Not Support
zh_Hans: 不支持

View File

@ -0,0 +1,45 @@
from collections.abc import Generator
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
OAIAPICompatLargeLanguageModel,
)
class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return super()._invoke(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
credentials["mode"] = "chat"

View File

@ -0,0 +1,146 @@
from json import dumps
from typing import Optional
import httpx
from requests import post
from yarl import URL
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class GPUStackRerankModel(RerankModel):
"""
Model class for GPUStack rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
endpoint_url = credentials["endpoint_url"]
headers = {
"Authorization": f"Bearer {credentials.get('api_key')}",
"Content-Type": "application/json",
}
data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
try:
response = post(
str(URL(endpoint_url) / "v1" / "rerank"),
headers=headers,
data=dumps(data),
timeout=10,
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results["results"]:
index = result["index"]
if "document" in result:
text = result["document"]["text"]
else:
text = docs[index]
rerank_document = RerankDocument(
index=index,
text=text,
score=result["relevance_score"],
)
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity

View File

@ -0,0 +1,35 @@
from typing import Optional
from yarl import URL
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
)
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
OAICompatEmbeddingModel,
)
class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
"""
Model class for GPUStack text embedding model.
"""
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
return super()._invoke(model, credentials, texts, user, input_type)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")

View File

@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 131072

View File

@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 131072

View File

@ -0,0 +1,498 @@
import copy
import json
import logging
from collections.abc import Iterable
from typing import Any, Optional
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_fixed
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)
class LindormVectorStoreConfig(BaseModel):
hosts: str
username: Optional[str] = None
password: Optional[str] = None
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["hosts"]:
raise ValueError("config URL is required")
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts,
}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self.kwargs = kwargs
def get_type(self) -> str:
return VectorType.LINDORM
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]), **kwargs)
self.add_texts(texts, embeddings)
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def __filter_existed_ids(
self,
texts: list[str],
metadatas: list[dict],
ids: list[str],
bulk_size: int = 1024,
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.error(f"Error fetching batch {batch_ids}: {e}")
return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(
body={
"docs": [
{"_index": self._collection_name, "_id": id, "routing": routing}
for id, routing in zip(batch_ids, route_ids)
]
},
_source=False,
)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.error(f"Error fetching batch {batch_ids}: {e}")
return set()
if ids is None:
return texts, metadatas, ids
if len(texts) != len(ids):
raise RuntimeError(f"texts {len(texts)} != {ids}")
filtered_texts = []
filtered_metadatas = []
filtered_ids = []
def batch(iterable, n):
length = len(iterable)
for idx in range(0, length, n):
yield iterable[idx : min(idx + n, length)]
for ids_batch, texts_batch, metadatas_batch in zip(
batch(ids, bulk_size),
batch(texts, bulk_size),
batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
):
existing_ids_set = __fetch_existing_ids(ids_batch)
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
if doc_id not in existing_ids_set:
filtered_texts.append(text)
filtered_ids.append(doc_id)
if metadatas is not None:
filtered_metadatas.append(metadata)
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuids[i],
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
actions.append(action)
bulk(self._client, actions)
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
for id in ids:
if self._client.exists(index=self._collection_name, id=id):
self._client.delete(index=self._collection_name, id=id)
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
def delete(self) -> None:
try:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e:
logger.error(f"Error occurred while deleting the index: {e}")
raise e
def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name, id=id)
return True
except:
return False
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
# Check whether query_vector is a floating-point number list
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
try:
response = self._client.search(index=self._collection_name, body=query)
except Exception as e:
logger.error(f"Error executing search: {e}")
raise
docs_and_scores = []
for hit in response["hits"]["hits"]:
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
must = kwargs.get("must")
must_not = kwargs.get("must_not")
should = kwargs.get("should")
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
routing = kwargs.get("routing")
full_text_query = default_text_search_query(
query_text=query,
k=top_k,
text_field=Field.CONTENT_KEY.value,
must=must,
must_not=must_not,
should=should,
minimum_should_match=minimum_should_match,
filters=filters,
routing=routing,
)
response = self._client.search(index=self._collection_name, body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)
return docs
def create_collection(self, dimension: int, **kwargs):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if self._client.indices.exists(index=self._collection_name):
logger.info("{self._collection_name.lower()} already exists.")
return
if len(self.kwargs) == 0 and len(kwargs) != 0:
self.kwargs = copy.deepcopy(kwargs)
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
shards = kwargs.pop("shards", 2)
engine = kwargs.pop("engine", "lvector")
method_name = kwargs.pop("method_name", "hnsw")
data_type = kwargs.pop("data_type", "float")
space_type = kwargs.pop("space_type", "cosinesimil")
hnsw_m = kwargs.pop("hnsw_m", 24)
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
nlist = kwargs.pop("nlist", 1000)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
mapping = default_text_mapping(
dimension,
method_name,
shards=shards,
engine=engine,
data_type=data_type,
space_type=space_type,
vector_field=vector_field,
hnsw_m=hnsw_m,
hnsw_ef_construction=hnsw_ef_construction,
nlist=nlist,
ivfpq_m=ivfpq_m,
centroids_use_hnsw=centroids_use_hnsw,
centroids_hnsw_m=centroids_hnsw_m,
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
**kwargs,
)
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
# logger.info(f"create index success: {self._collection_name}")
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
routing_field = kwargs.get("routing_field")
excludes_from_source = kwargs.get("excludes_from_source")
analyzer = kwargs.get("analyzer", "ik_max_word")
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
engine = kwargs["engine"]
shard = kwargs["shards"]
space_type = kwargs["space_type"]
data_type = kwargs["data_type"]
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
nlist = kwargs["nlist"]
centroids_use_hnsw = True if nlist > 10000 else False
centroids_hnsw_m = 24
centroids_hnsw_ef_construct = 500
centroids_hnsw_ef_search = 100
parameters = {
"m": ivfpq_m,
"nlist": nlist,
"centroids_use_hnsw": centroids_use_hnsw,
"centroids_hnsw_m": centroids_hnsw_m,
"centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
"centroids_hnsw_ef_search": centroids_hnsw_ef_search,
}
elif method_name == "hnsw":
neighbor = kwargs["hnsw_m"]
ef_construction = kwargs["hnsw_ef_construction"]
parameters = {"m": neighbor, "ef_construction": ef_construction}
elif method_name == "flat":
parameters = {}
else:
raise RuntimeError(f"unexpected method_name: {method_name}")
mapping = {
"settings": {"index": {"number_of_shards": shard, "knn": True}},
"mappings": {
"properties": {
vector_field: {
"type": "knn_vector",
"dimension": dimension,
"data_type": data_type,
"method": {
"engine": engine,
"name": method_name,
"space_type": space_type,
"parameters": parameters,
},
},
text_field: {"type": "text", "analyzer": analyzer},
}
},
}
if excludes_from_source:
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
if method_name == "ivfpq" and routing_field is not None:
mapping["settings"]["index"]["knn_routing"] = True
mapping["settings"]["index"]["knn.offline.construction"] = True
if method_name == "flat" and routing_field is not None:
mapping["settings"]["index"]["knn_routing"] = True
return mapping
def default_text_search_query(
query_text: str,
k: int = 4,
text_field: str = Field.CONTENT_KEY.value,
must: Optional[list[dict]] = None,
must_not: Optional[list[dict]] = None,
should: Optional[list[dict]] = None,
minimum_should_match: int = 0,
filters: Optional[list[dict]] = None,
routing: Optional[str] = None,
**kwargs,
) -> dict:
if routing is not None:
routing_field = kwargs.get("routing_field", "routing_field")
query_clause = {
"bool": {
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
}
}
else:
query_clause = {"match": {text_field: query_text}}
# build the simplest search_query when only query_text is specified
if not must and not must_not and not should and not filters:
search_query = {"size": k, "query": query_clause}
return search_query
# build complex search_query when either of must/must_not/should/filter is specified
if must:
if not isinstance(must, list):
raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
if query_clause not in must:
must.append(query_clause)
else:
must = [query_clause]
boolean_query = {"must": must}
if must_not:
if not isinstance(must_not, list):
raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
boolean_query["must_not"] = must_not
if should:
if not isinstance(should, list):
raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
boolean_query["should"] = should
if minimum_should_match != 0:
boolean_query["minimum_should_match"] = minimum_should_match
if filters:
if not isinstance(filters, list):
raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
boolean_query["filter"] = filters
search_query = {"size": k, "query": {"bool": boolean_query}}
return search_query
def default_vector_search_query(
query_vector: list[float],
k: int = 4,
min_score: str = "0.0",
ef_search: Optional[str] = None, # only for hnsw
nprobe: Optional[str] = None, # "2000"
reorder_factor: Optional[str] = None, # "20"
client_refactor: Optional[str] = None, # "true"
vector_field: str = Field.VECTOR.value,
filters: Optional[list[dict]] = None,
filter_type: Optional[str] = None,
**kwargs,
) -> dict:
if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filter, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
if min_score != "0.0":
final_ext["lvector"]["min_score"] = min_score
if ef_search:
final_ext["lvector"]["ef_search"] = ef_search
if nprobe:
final_ext["lvector"]["nprobe"] = nprobe
if reorder_factor:
final_ext["lvector"]["reorder_factor"] = reorder_factor
if client_refactor:
final_ext["lvector"]["client_refactor"] = client_refactor
search_query = {
"size": k,
"_source": True, # force return '_source'
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
}
if filters is not None:
# when using filter, transform filter from List[Dict] to Dict as valid format
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
if filter_type:
final_ext["lvector"]["filter_type"] = filter_type
if final_ext != {"lvector": {}}:
search_query["ext"] = final_ext
return search_query
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
)
return LindormVectorStore(collection_name, lindorm_config)

View File

@ -134,6 +134,10 @@ class Vector:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.LINDORM:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

View File

@ -16,6 +16,7 @@ class VectorType(str, Enum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"
VIKINGDB = "vikingdb"

View File

@ -1,5 +1,5 @@
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.font_manager import FontProperties, fontManager
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -17,9 +17,10 @@ def set_chinese_font():
]
for font in font_list:
chinese_font = FontProperties(font)
if chinese_font.get_name() == font:
return chinese_font
if font in fontManager.ttflist:
chinese_font = FontProperties(font)
if chinese_font.get_name() == font:
return chinese_font
return FontProperties()

View File

@ -1,15 +1,19 @@
import concurrent.futures
import io
import random
import warnings
from typing import Any, Literal, Optional, Union
import openai
from pydub import AudioSegment
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from pydub import AudioSegment
class PodcastAudioGeneratorTool(BuiltinTool):
@staticmethod

View File

@ -1,4 +1,4 @@
class DocumentExtractorError(Exception):
class DocumentExtractorError(ValueError):
"""Base exception for errors related to the DocumentExtractorNode."""

View File

@ -6,12 +6,14 @@ import docx
import pandas as pd
import pypdfium2
import yaml
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
@ -263,7 +265,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str:
try:
with io.BytesIO(file_content) as file:
elements = partition_pptx(file=file)
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
elements = partition_via_api(
file=file,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY,
)
else:
elements = partition_pptx(file=file)
return "\n".join([getattr(element, "text", "") for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e

View File

@ -0,0 +1,16 @@
class ListOperatorError(ValueError):
"""Base class for all ListOperator errors."""
pass
class InvalidFilterValueError(ListOperatorError):
pass
class InvalidKeyError(ListOperatorError):
pass
class InvalidConditionError(ListOperatorError):
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Literal
from typing import Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
if isinstance(variable, ArrayFileSegment):
inputs = {"variable": [item.to_dict() for item in variable.value]}
process_data["variable"] = [item.to_dict() for item in variable.value]
else:
inputs = {"variable": variable.value}
process_data["variable"] = variable.value
# Filter
if self.node_data.filter_by.enabled:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
if isinstance(condition.value, str):
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
try:
# Filter
if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Order
if self.node_data.order_by.enabled:
# Order
if self.node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self.node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
except ListOperatorError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
if isinstance(condition.value, str):
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
return variable
# Slice
if self.node_data.limit.enabled:
result = variable.value[: self.node_data.limit.size]
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size]
return variable.model_copy(update={"value": result})
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
case "size":
return lambda x: x.size
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "url":
return lambda x: x.remote_url or ""
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "not empty":
return lambda x: x != ""
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "not in":
return lambda x: not _in(value)(x)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
case "":
return _ge(value)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str):