mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
Merge branch 'main' into feat/support-extractor-tools
This commit is contained in:
@ -76,6 +76,7 @@ class BaseAppGenerator:
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
|
||||
user_input_value = inputs.get(var.variable)
|
||||
|
||||
if not user_input_value:
|
||||
if var.required:
|
||||
raise ValueError(f"{var.variable} is required in input form")
|
||||
@ -88,6 +89,7 @@ class BaseAppGenerator:
|
||||
VariableEntityType.PARAGRAPH,
|
||||
} and not isinstance(user_input_value, str):
|
||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||
|
||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
@ -97,25 +99,30 @@ class BaseAppGenerator:
|
||||
return int(user_input_value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||
if var.type == VariableEntityType.SELECT:
|
||||
options = var.options
|
||||
if user_input_value not in options:
|
||||
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
||||
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
||||
if var.max_length and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||
elif var.type == VariableEntityType.FILE:
|
||||
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
|
||||
raise ValueError(f"{var.variable} in input form must be a file")
|
||||
elif var.type == VariableEntityType.FILE_LIST:
|
||||
if not (
|
||||
isinstance(user_input_value, list)
|
||||
and (
|
||||
all(isinstance(item, dict) for item in user_input_value)
|
||||
or all(isinstance(item, File) for item in user_input_value)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"{var.variable} in input form must be a list of files")
|
||||
|
||||
match var.type:
|
||||
case VariableEntityType.SELECT:
|
||||
if user_input_value not in var.options:
|
||||
raise ValueError(f"{var.variable} in input form must be one of the following: {var.options}")
|
||||
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
|
||||
if var.max_length and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||
case VariableEntityType.FILE:
|
||||
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
|
||||
raise ValueError(f"{var.variable} in input form must be a file")
|
||||
case VariableEntityType.FILE_LIST:
|
||||
# if number of files exceeds the limit, raise ValueError
|
||||
if not (
|
||||
isinstance(user_input_value, list)
|
||||
and (
|
||||
all(isinstance(item, dict) for item in user_input_value)
|
||||
or all(isinstance(item, File) for item in user_input_value)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"{var.variable} in input form must be a list of files")
|
||||
|
||||
if var.max_length and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} files")
|
||||
|
||||
return user_input_value
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from core.errors.error import ProviderTokenNotInitError
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
@ -597,26 +598,9 @@ class IndexingRunner:
|
||||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
|
||||
document_text = CleanProcessor.clean(text, rules)
|
||||
|
||||
if "pre_processing_rules" in rules:
|
||||
pre_processing_rules = rules["pre_processing_rules"]
|
||||
for pre_processing_rule in pre_processing_rules:
|
||||
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
|
||||
# Remove extra spaces
|
||||
pattern = r"\n{3,}"
|
||||
text = re.sub(pattern, "\n\n", text)
|
||||
pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}"
|
||||
text = re.sub(pattern, " ", text)
|
||||
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
|
||||
# Remove email
|
||||
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
|
||||
text = re.sub(pattern, "", text)
|
||||
|
||||
# Remove URL
|
||||
pattern = r"https?://[^\s]+"
|
||||
text = re.sub(pattern, "", text)
|
||||
|
||||
return text
|
||||
return document_text
|
||||
|
||||
@staticmethod
|
||||
def format_split_text(text):
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 277 KiB |
@ -0,0 +1,15 @@
|
||||
<svg width="68" height="24" viewBox="0 0 68 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="Gemini">
|
||||
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M50.6875 4.37014C48.3498 4.59292 46.5349 6.41319 46.3337 8.72764C46.1446 6.44662 44.2677 4.56074 41.9805 4.3737C44.2762 4.1997 46.152 2.28299 46.3373 0C46.4882 2.28911 48.405 4.20047 50.6875 4.37014ZM15.4567 9.41141L13.9579 10.9076C9.92941 6.64892 2.69298 9.97287 3.17317 15.8112C3.22394 23.108 14.5012 24.4317 15.3628 16.8809H9.52096L9.50061 14.9149H17.3595C18.8163 23.1364 8.44367 27.0292 3.19453 21.238C0.847044 18.7556 0.363651 14.7682 1.83717 11.7212C4.1129 6.62089 11.6505 5.29845 15.4567 9.41141ZM45.5915 23.5989H47.6945C47.6944 22.9155 47.6945 22.2307 47.6946 21.5452V21.5325C47.6948 19.8907 47.695 18.2453 47.6924 16.6072C47.6914 15.9407 47.6161 15.2823 47.4024 14.647C46.4188 11.2828 41.4255 11.4067 39.8332 14.214C38.5637 11.4171 34.4009 11.5236 32.8538 14.0084L32.8082 13.9976V12.4806L32.4233 12.4804H32.4224C31.8687 12.4801 31.3324 12.4798 30.7949 12.4811V23.5848L32.8977 23.5672C32.8981 22.9411 32.8979 22.3122 32.8977 21.6822V21.6812V21.6802V21.6791V21.6781V21.6771V21.676V21.676V21.6759V21.6758V21.6757V21.6757V21.6756C32.8973 20.204 32.8969 18.7261 32.904 17.2614C32.8889 15.3646 34.5674 13.5687 36.5358 14.124C37.7794 14.3298 38.1851 15.6148 38.1761 16.7257C38.1821 17.7019 38.18 18.6824 38.178 19.6633V19.6638C38.1752 20.9756 38.1724 22.2881 38.1891 23.5919L40.2846 23.5731C40.2929 22.7511 40.2881 21.9245 40.2832 21.0966C40.2753 19.7402 40.2674 18.3805 40.317 17.0328C40.4418 15.2122 42.0141 13.6186 43.9064 14.1168C45.2685 14.3231 45.6136 15.7748 45.5882 16.9545C45.5938 18.4959 45.5929 20.0492 45.5921 21.5968V21.5991V21.6014V21.6037V21.606V21.6083V21.6106C45.5917 22.2749 45.5913 22.9382 45.5915 23.5989ZM20.6167 18.4408C20.5625 21.9486 25.2121 23.6996 27.2993 20.0558L29.1566 20.9592C27.8157 23.7067 24.2337 24.7424 21.5381 23.4213C18.0052 21.7253 17.41 16.5007 20.0334 13.7517C21.4609 12.1752 23.7291 11.7901 25.7206 12.3653C28.3408 13.1257 29.4974 15.8937 29.326 18.4399C27.5547 18.4415 25.7971 18.4412 24.0364 18.4409C22.8993 18.4407 21.7609 18.4405 20.6167 18.4408ZM27.1041 16.6957C26.7048 13.1033 21.2867 13.2256 20.7494 16.6957H27.1041ZM53.543 23.5999H55.6206L55.6206 22.4361C55.6205 20.7877 55.6205 19.1443 55.6207 17.4939C55.6208 16.8853 55.7234 16.297 56.0063 15.7531C56.6115 14.3862 58.1745 13.7002 59.5927 14.1774C60.7512 14.4455 61.2852 15.6069 61.2762 16.7154C61.2774 18.3497 61.2771 19.9826 61.2769 21.6162V21.6166V21.617V21.6174V21.6179L61.2766 23.6007H63.3698C63.3913 22.0924 63.3869 20.584 63.3826 19.0755V19.0754V19.0753V19.0753V19.0752C63.3799 18.1682 63.3773 17.2612 63.3803 16.3541C63.3796 15.8622 63.3103 15.3765 63.1698 14.9052C62.3248 11.5142 57.3558 11.2385 55.5828 14.0038L55.5336 13.9905V12.4917H53.539C53.4898 12.7313 53.4934 23.4113 53.543 23.5999ZM49.6211 12.4944H51.7065V23.5994H49.6211V12.4944ZM65.1035 23.5991H67.1831C67.2367 23.2198 67.2133 12.6566 67.1634 12.4983H65.1035V23.5991ZM52.1504 8.67829C52.1709 10.4847 49.2418 10.7058 49.1816 8.65714C49.2189 6.5948 52.2437 6.81331 52.1504 8.67829ZM66.1387 10.1324C64.2712 10.1609 64.1316 7.19881 66.1559 7.17114C68.1709 7.19817 68.0215 10.2087 66.1387 10.1324Z" fill="url(#paint0_linear_14286_118464)"/>
|
||||
</g>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_14286_118464" x1="-2" y1="0.999998" x2="67.9999" y2="27.5002" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#7798E0"/>
|
||||
<stop offset="0.210002" stop-color="#086FFF"/>
|
||||
<stop offset="0.345945" stop-color="#086FFF"/>
|
||||
<stop offset="0.591777" stop-color="#479AFF"/>
|
||||
<stop offset="0.895892" stop-color="#B7C4FA"/>
|
||||
<stop offset="1" stop-color="#B5C5F9"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.6 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 57 KiB |
@ -0,0 +1,11 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="24" height="24" rx="6" fill="url(#paint0_linear_7301_16076)"/>
|
||||
<path d="M20 12.0116C15.7043 12.42 12.3692 15.757 11.9995 20C11.652 15.8183 8.20301 12.361 4 12.0181C8.21855 11.6991 11.6656 8.1853 12.006 4C12.2833 8.19653 15.8057 11.7005 20 12.0116Z" fill="white" fill-opacity="0.88"/>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_7301_16076" x1="-9" y1="29.5" x2="19.4387" y2="1.43791" gradientUnits="userSpaceOnUse">
|
||||
<stop offset="0.192878" stop-color="#1C7DFF"/>
|
||||
<stop offset="0.520213" stop-color="#1C69FF"/>
|
||||
<stop offset="1" stop-color="#F0DCD6"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 689 B |
10
api/core/model_runtime/model_providers/gpustack/gpustack.py
Normal file
10
api/core/model_runtime/model_providers/gpustack/gpustack.py
Normal file
@ -0,0 +1,10 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GPUStackProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
120
api/core/model_runtime/model_providers/gpustack/gpustack.yaml
Normal file
120
api/core/model_runtime/model_providers/gpustack/gpustack.yaml
Normal file
@ -0,0 +1,120 @@
|
||||
provider: gpustack
|
||||
label:
|
||||
en_US: GPUStack
|
||||
icon_small:
|
||||
en_US: icon_s_en.png
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
zh_Hans: 服务器地址
|
||||
en_US: Server URL
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
|
||||
en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: mode
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Completion mode
|
||||
type: select
|
||||
required: false
|
||||
default: chat
|
||||
placeholder:
|
||||
zh_Hans: 选择补全类型
|
||||
en_US: Select completion type
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
zh_Hans: 补全
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
||||
zh_Hans: 对话
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 模型上下文长度
|
||||
en_US: Model context size
|
||||
required: true
|
||||
type: text-input
|
||||
default: "8192"
|
||||
placeholder:
|
||||
zh_Hans: 输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
- variable: max_tokens_to_sample
|
||||
label:
|
||||
zh_Hans: 最大 token 上限
|
||||
en_US: Upper bound for max tokens
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
default: "8192"
|
||||
type: text-input
|
||||
- variable: function_calling_type
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Function calling
|
||||
type: select
|
||||
required: false
|
||||
default: no_call
|
||||
options:
|
||||
- value: function_call
|
||||
label:
|
||||
en_US: Function Call
|
||||
zh_Hans: Function Call
|
||||
- value: tool_call
|
||||
label:
|
||||
en_US: Tool Call
|
||||
zh_Hans: Tool Call
|
||||
- value: no_call
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- variable: vision_support
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
zh_Hans: Vision 支持
|
||||
en_US: Vision Support
|
||||
type: select
|
||||
required: false
|
||||
default: no_support
|
||||
options:
|
||||
- value: support
|
||||
label:
|
||||
en_US: Support
|
||||
zh_Hans: 支持
|
||||
- value: no_support
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
45
api/core/model_runtime/model_providers/gpustack/llm/llm.py
Normal file
45
api/core/model_runtime/model_providers/gpustack/llm/llm.py
Normal file
@ -0,0 +1,45 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
|
||||
OAIAPICompatLargeLanguageModel,
|
||||
)
|
||||
|
||||
|
||||
class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
return super()._invoke(
|
||||
model,
|
||||
credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools,
|
||||
stop,
|
||||
stream,
|
||||
user,
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
||||
credentials["mode"] = "chat"
|
||||
146
api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
Normal file
146
api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
Normal file
@ -0,0 +1,146 @@
|
||||
from json import dumps
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from requests import post
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class GPUStackRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for GPUStack rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n documents to return
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials.get('api_key')}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
|
||||
|
||||
try:
|
||||
response = post(
|
||||
str(URL(endpoint_url) / "v1" / "rerank"),
|
||||
headers=headers,
|
||||
data=dumps(data),
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
rerank_documents = []
|
||||
for result in results["results"]:
|
||||
index = result["index"]
|
||||
if "document" in result:
|
||||
text = result["document"]["text"]
|
||||
else:
|
||||
text = docs[index]
|
||||
|
||||
rerank_document = RerankDocument(
|
||||
index=index,
|
||||
text=text,
|
||||
score=result["relevance_score"],
|
||||
)
|
||||
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [httpx.ConnectError],
|
||||
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
)
|
||||
|
||||
return entity
|
||||
@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
TextEmbeddingResult,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
|
||||
OAICompatEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
|
||||
"""
|
||||
Model class for GPUStack text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
return super()._invoke(model, credentials, texts, user, input_type)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
||||
@ -5,6 +5,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
|
||||
@ -5,6 +5,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
|
||||
0
api/core/rag/datasource/vdb/lindorm/__init__.py
Normal file
0
api/core/rag/datasource/vdb/lindorm/__init__.py
Normal file
498
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
Normal file
498
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
Normal file
@ -0,0 +1,498 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy.helpers import bulk
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger("lindorm").setLevel(logging.WARN)
|
||||
|
||||
|
||||
class LindormVectorStoreConfig(BaseModel):
|
||||
hosts: str
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["hosts"]:
|
||||
raise ValueError("config URL is required")
|
||||
if not values["username"]:
|
||||
raise ValueError("config USERNAME is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": self.hosts,
|
||||
}
|
||||
if self.username and self.password:
|
||||
params["http_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
|
||||
class LindormVectorStore(BaseVector):
|
||||
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
|
||||
super().__init__(collection_name.lower())
|
||||
self._client_config = config
|
||||
self._client = OpenSearch(**config.to_opensearch_params())
|
||||
self.kwargs = kwargs
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.LINDORM
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self.create_collection(len(embeddings[0]), **kwargs)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def refresh(self):
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
|
||||
def __filter_existed_ids(
|
||||
self,
|
||||
texts: list[str],
|
||||
metadatas: list[dict],
|
||||
ids: list[str],
|
||||
bulk_size: int = 1024,
|
||||
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
|
||||
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
|
||||
try:
|
||||
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
|
||||
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching batch {batch_ids}: {e}")
|
||||
return set()
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
|
||||
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
|
||||
try:
|
||||
existing_docs = self._client.mget(
|
||||
body={
|
||||
"docs": [
|
||||
{"_index": self._collection_name, "_id": id, "routing": routing}
|
||||
for id, routing in zip(batch_ids, route_ids)
|
||||
]
|
||||
},
|
||||
_source=False,
|
||||
)
|
||||
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching batch {batch_ids}: {e}")
|
||||
return set()
|
||||
|
||||
if ids is None:
|
||||
return texts, metadatas, ids
|
||||
|
||||
if len(texts) != len(ids):
|
||||
raise RuntimeError(f"texts {len(texts)} != {ids}")
|
||||
|
||||
filtered_texts = []
|
||||
filtered_metadatas = []
|
||||
filtered_ids = []
|
||||
|
||||
def batch(iterable, n):
|
||||
length = len(iterable)
|
||||
for idx in range(0, length, n):
|
||||
yield iterable[idx : min(idx + n, length)]
|
||||
|
||||
for ids_batch, texts_batch, metadatas_batch in zip(
|
||||
batch(ids, bulk_size),
|
||||
batch(texts, bulk_size),
|
||||
batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
|
||||
):
|
||||
existing_ids_set = __fetch_existing_ids(ids_batch)
|
||||
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
|
||||
if doc_id not in existing_ids_set:
|
||||
filtered_texts.append(text)
|
||||
filtered_ids.append(doc_id)
|
||||
if metadatas is not None:
|
||||
filtered_metadatas.append(metadata)
|
||||
|
||||
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
actions = []
|
||||
uuids = self._get_uuids(documents)
|
||||
for i in range(len(documents)):
|
||||
action = {
|
||||
"_op_type": "index",
|
||||
"_index": self._collection_name.lower(),
|
||||
"_id": uuids[i],
|
||||
"_source": {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
},
|
||||
}
|
||||
actions.append(action)
|
||||
bulk(self._client, actions)
|
||||
self.refresh()
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
|
||||
response = self._client.search(index=self._collection_name, body=query)
|
||||
if response["hits"]["hits"]:
|
||||
return [hit["_id"] for hit in response["hits"]["hits"]]
|
||||
else:
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str)
|
||||
ids = [hit["_id"] for hit in results["hits"]["hits"]]
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
for id in ids:
|
||||
if self._client.exists(index=self._collection_name, id=id):
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
else:
|
||||
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
if self._client.indices.exists(index=self._collection_name):
|
||||
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
|
||||
logger.info("Delete index success")
|
||||
else:
|
||||
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while deleting the index: {e}")
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
try:
|
||||
self._client.get(index=self._collection_name, id=id)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
# Make sure query_vector is a list
|
||||
if not isinstance(query_vector, list):
|
||||
raise ValueError("query_vector should be a list of floats")
|
||||
|
||||
# Check whether query_vector is a floating-point number list
|
||||
if not all(isinstance(x, float) for x in query_vector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name, body=query)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing search: {e}")
|
||||
raise
|
||||
|
||||
docs_and_scores = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
docs_and_scores.append(
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
)
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
|
||||
if score > score_threshold:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
must = kwargs.get("must")
|
||||
must_not = kwargs.get("must_not")
|
||||
should = kwargs.get("should")
|
||||
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
filters = kwargs.get("filter")
|
||||
routing = kwargs.get("routing")
|
||||
full_text_query = default_text_search_query(
|
||||
query_text=query,
|
||||
k=top_k,
|
||||
text_field=Field.CONTENT_KEY.value,
|
||||
must=must,
|
||||
must_not=must_not,
|
||||
should=should,
|
||||
minimum_should_match=minimum_should_match,
|
||||
filters=filters,
|
||||
routing=routing,
|
||||
)
|
||||
response = self._client.search(index=self._collection_name, body=full_text_query)
|
||||
docs = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
)
|
||||
)
|
||||
|
||||
return docs
|
||||
|
||||
def create_collection(self, dimension: int, **kwargs):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
if self._client.indices.exists(index=self._collection_name):
|
||||
logger.info("{self._collection_name.lower()} already exists.")
|
||||
return
|
||||
if len(self.kwargs) == 0 and len(kwargs) != 0:
|
||||
self.kwargs = copy.deepcopy(kwargs)
|
||||
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
|
||||
shards = kwargs.pop("shards", 2)
|
||||
|
||||
engine = kwargs.pop("engine", "lvector")
|
||||
method_name = kwargs.pop("method_name", "hnsw")
|
||||
data_type = kwargs.pop("data_type", "float")
|
||||
space_type = kwargs.pop("space_type", "cosinesimil")
|
||||
|
||||
hnsw_m = kwargs.pop("hnsw_m", 24)
|
||||
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
|
||||
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
|
||||
nlist = kwargs.pop("nlist", 1000)
|
||||
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
|
||||
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
|
||||
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
|
||||
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
|
||||
mapping = default_text_mapping(
|
||||
dimension,
|
||||
method_name,
|
||||
shards=shards,
|
||||
engine=engine,
|
||||
data_type=data_type,
|
||||
space_type=space_type,
|
||||
vector_field=vector_field,
|
||||
hnsw_m=hnsw_m,
|
||||
hnsw_ef_construction=hnsw_ef_construction,
|
||||
nlist=nlist,
|
||||
ivfpq_m=ivfpq_m,
|
||||
centroids_use_hnsw=centroids_use_hnsw,
|
||||
centroids_hnsw_m=centroids_hnsw_m,
|
||||
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
|
||||
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
|
||||
**kwargs,
|
||||
)
|
||||
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
# logger.info(f"create index success: {self._collection_name}")
|
||||
|
||||
|
||||
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
|
||||
routing_field = kwargs.get("routing_field")
|
||||
excludes_from_source = kwargs.get("excludes_from_source")
|
||||
analyzer = kwargs.get("analyzer", "ik_max_word")
|
||||
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
|
||||
engine = kwargs["engine"]
|
||||
shard = kwargs["shards"]
|
||||
space_type = kwargs["space_type"]
|
||||
data_type = kwargs["data_type"]
|
||||
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
|
||||
|
||||
if method_name == "ivfpq":
|
||||
ivfpq_m = kwargs["ivfpq_m"]
|
||||
nlist = kwargs["nlist"]
|
||||
centroids_use_hnsw = True if nlist > 10000 else False
|
||||
centroids_hnsw_m = 24
|
||||
centroids_hnsw_ef_construct = 500
|
||||
centroids_hnsw_ef_search = 100
|
||||
parameters = {
|
||||
"m": ivfpq_m,
|
||||
"nlist": nlist,
|
||||
"centroids_use_hnsw": centroids_use_hnsw,
|
||||
"centroids_hnsw_m": centroids_hnsw_m,
|
||||
"centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
|
||||
"centroids_hnsw_ef_search": centroids_hnsw_ef_search,
|
||||
}
|
||||
elif method_name == "hnsw":
|
||||
neighbor = kwargs["hnsw_m"]
|
||||
ef_construction = kwargs["hnsw_ef_construction"]
|
||||
parameters = {"m": neighbor, "ef_construction": ef_construction}
|
||||
elif method_name == "flat":
|
||||
parameters = {}
|
||||
else:
|
||||
raise RuntimeError(f"unexpected method_name: {method_name}")
|
||||
|
||||
mapping = {
|
||||
"settings": {"index": {"number_of_shards": shard, "knn": True}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
vector_field: {
|
||||
"type": "knn_vector",
|
||||
"dimension": dimension,
|
||||
"data_type": data_type,
|
||||
"method": {
|
||||
"engine": engine,
|
||||
"name": method_name,
|
||||
"space_type": space_type,
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
text_field: {"type": "text", "analyzer": analyzer},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if excludes_from_source:
|
||||
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
|
||||
|
||||
if method_name == "ivfpq" and routing_field is not None:
|
||||
mapping["settings"]["index"]["knn_routing"] = True
|
||||
mapping["settings"]["index"]["knn.offline.construction"] = True
|
||||
|
||||
if method_name == "flat" and routing_field is not None:
|
||||
mapping["settings"]["index"]["knn_routing"] = True
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def default_text_search_query(
|
||||
query_text: str,
|
||||
k: int = 4,
|
||||
text_field: str = Field.CONTENT_KEY.value,
|
||||
must: Optional[list[dict]] = None,
|
||||
must_not: Optional[list[dict]] = None,
|
||||
should: Optional[list[dict]] = None,
|
||||
minimum_should_match: int = 0,
|
||||
filters: Optional[list[dict]] = None,
|
||||
routing: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
if routing is not None:
|
||||
routing_field = kwargs.get("routing_field", "routing_field")
|
||||
query_clause = {
|
||||
"bool": {
|
||||
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
|
||||
}
|
||||
}
|
||||
else:
|
||||
query_clause = {"match": {text_field: query_text}}
|
||||
# build the simplest search_query when only query_text is specified
|
||||
if not must and not must_not and not should and not filters:
|
||||
search_query = {"size": k, "query": query_clause}
|
||||
return search_query
|
||||
|
||||
# build complex search_query when either of must/must_not/should/filter is specified
|
||||
if must:
|
||||
if not isinstance(must, list):
|
||||
raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
|
||||
if query_clause not in must:
|
||||
must.append(query_clause)
|
||||
else:
|
||||
must = [query_clause]
|
||||
|
||||
boolean_query = {"must": must}
|
||||
|
||||
if must_not:
|
||||
if not isinstance(must_not, list):
|
||||
raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
|
||||
boolean_query["must_not"] = must_not
|
||||
|
||||
if should:
|
||||
if not isinstance(should, list):
|
||||
raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
|
||||
boolean_query["should"] = should
|
||||
if minimum_should_match != 0:
|
||||
boolean_query["minimum_should_match"] = minimum_should_match
|
||||
|
||||
if filters:
|
||||
if not isinstance(filters, list):
|
||||
raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
|
||||
boolean_query["filter"] = filters
|
||||
|
||||
search_query = {"size": k, "query": {"bool": boolean_query}}
|
||||
return search_query
|
||||
|
||||
|
||||
def default_vector_search_query(
|
||||
query_vector: list[float],
|
||||
k: int = 4,
|
||||
min_score: str = "0.0",
|
||||
ef_search: Optional[str] = None, # only for hnsw
|
||||
nprobe: Optional[str] = None, # "2000"
|
||||
reorder_factor: Optional[str] = None, # "20"
|
||||
client_refactor: Optional[str] = None, # "true"
|
||||
vector_field: str = Field.VECTOR.value,
|
||||
filters: Optional[list[dict]] = None,
|
||||
filter_type: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
if filters is not None:
|
||||
filter_type = "post_filter" if filter_type is None else filter_type
|
||||
if not isinstance(filter, list):
|
||||
raise RuntimeError(f"unexpected filter with {type(filters)}")
|
||||
final_ext = {"lvector": {}}
|
||||
if min_score != "0.0":
|
||||
final_ext["lvector"]["min_score"] = min_score
|
||||
if ef_search:
|
||||
final_ext["lvector"]["ef_search"] = ef_search
|
||||
if nprobe:
|
||||
final_ext["lvector"]["nprobe"] = nprobe
|
||||
if reorder_factor:
|
||||
final_ext["lvector"]["reorder_factor"] = reorder_factor
|
||||
if client_refactor:
|
||||
final_ext["lvector"]["client_refactor"] = client_refactor
|
||||
|
||||
search_query = {
|
||||
"size": k,
|
||||
"_source": True, # force return '_source'
|
||||
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
|
||||
}
|
||||
|
||||
if filters is not None:
|
||||
# when using filter, transform filter from List[Dict] to Dict as valid format
|
||||
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
|
||||
search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
|
||||
if filter_type:
|
||||
final_ext["lvector"]["filter_type"] = filter_type
|
||||
|
||||
if final_ext != {"lvector": {}}:
|
||||
search_query["ext"] = final_ext
|
||||
return search_query
|
||||
|
||||
|
||||
class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
|
||||
lindorm_config = LindormVectorStoreConfig(
|
||||
hosts=dify_config.LINDORM_URL,
|
||||
username=dify_config.LINDORM_USERNAME,
|
||||
password=dify_config.LINDORM_PASSWORD,
|
||||
)
|
||||
return LindormVectorStore(collection_name, lindorm_config)
|
||||
@ -134,6 +134,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
|
||||
|
||||
return TidbOnQdrantVectorFactory
|
||||
case VectorType.LINDORM:
|
||||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
|
||||
|
||||
return LindormVectorStoreFactory
|
||||
case VectorType.OCEANBASE:
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ class VectorType(str, Enum):
|
||||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
LINDORM = "lindorm"
|
||||
COUCHBASE = "couchbase"
|
||||
BAIDU = "baidu"
|
||||
VIKINGDB = "vikingdb"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.font_manager import FontProperties
|
||||
from matplotlib.font_manager import FontProperties, fontManager
|
||||
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -17,9 +17,10 @@ def set_chinese_font():
|
||||
]
|
||||
|
||||
for font in font_list:
|
||||
chinese_font = FontProperties(font)
|
||||
if chinese_font.get_name() == font:
|
||||
return chinese_font
|
||||
if font in fontManager.ttflist:
|
||||
chinese_font = FontProperties(font)
|
||||
if chinese_font.get_name() == font:
|
||||
return chinese_font
|
||||
|
||||
return FontProperties()
|
||||
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import openai
|
||||
from pydub import AudioSegment
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
class PodcastAudioGeneratorTool(BuiltinTool):
|
||||
@staticmethod
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
class DocumentExtractorError(Exception):
|
||||
class DocumentExtractorError(ValueError):
|
||||
"""Base exception for errors related to the DocumentExtractorNode."""
|
||||
|
||||
|
||||
|
||||
@ -6,12 +6,14 @@ import docx
|
||||
import pandas as pd
|
||||
import pypdfium2
|
||||
import yaml
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.email import partition_email
|
||||
from unstructured.partition.epub import partition_epub
|
||||
from unstructured.partition.msg import partition_msg
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
@ -263,7 +265,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_pptx(file=file)
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
else:
|
||||
elements = partition_pptx(file=file)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
|
||||
16
api/core/workflow/nodes/list_operator/exc.py
Normal file
16
api/core/workflow/nodes/list_operator/exc.py
Normal file
@ -0,0 +1,16 @@
|
||||
class ListOperatorError(ValueError):
|
||||
"""Base class for all ListOperator errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFilterValueError(ListOperatorError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidKeyError(ListOperatorError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidConditionError(ListOperatorError):
|
||||
pass
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Literal
|
||||
from typing import Literal, Union
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
if not variable.value:
|
||||
inputs = {"variable": []}
|
||||
process_data = {"variable": []}
|
||||
outputs = {"result": [], "first_record": None, "last_record": None}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
error_message = (
|
||||
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
|
||||
"or ArrayStringSegment"
|
||||
@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
)
|
||||
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
inputs = {"variable": [item.to_dict() for item in variable.value]}
|
||||
process_data["variable"] = [item.to_dict() for item in variable.value]
|
||||
else:
|
||||
inputs = {"variable": variable.value}
|
||||
process_data["variable"] = variable.value
|
||||
|
||||
# Filter
|
||||
if self.node_data.filter_by.enabled:
|
||||
for condition in self.node_data.filter_by.conditions:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
raise ValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
raise ValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
if isinstance(condition.value, str):
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
else:
|
||||
value = condition.value
|
||||
filter_func = _get_file_filter_func(
|
||||
key=condition.key,
|
||||
condition=condition.comparison_operator,
|
||||
value=value,
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
try:
|
||||
# Filter
|
||||
if self.node_data.filter_by.enabled:
|
||||
variable = self._apply_filter(variable)
|
||||
|
||||
# Order
|
||||
if self.node_data.order_by.enabled:
|
||||
# Order
|
||||
if self.node_data.order_by.enabled:
|
||||
variable = self._apply_order(variable)
|
||||
|
||||
# Slice
|
||||
if self.node_data.limit.enabled:
|
||||
variable = self._apply_slice(variable)
|
||||
|
||||
outputs = {
|
||||
"result": variable.value,
|
||||
"first_record": variable.value[0] if variable.value else None,
|
||||
"last_record": variable.value[-1] if variable.value else None,
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
except ListOperatorError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
for condition in self.node_data.filter_by.conditions:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
|
||||
if not isinstance(condition.value, str):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
|
||||
if not isinstance(condition.value, str):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
result = _order_file(
|
||||
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
|
||||
if isinstance(condition.value, str):
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
else:
|
||||
value = condition.value
|
||||
filter_func = _get_file_filter_func(
|
||||
key=condition.key,
|
||||
condition=condition.comparison_operator,
|
||||
value=value,
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
return variable
|
||||
|
||||
# Slice
|
||||
if self.node_data.limit.enabled:
|
||||
result = variable.value[: self.node_data.limit.size]
|
||||
def _apply_order(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
result = _order_file(
|
||||
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
return variable
|
||||
|
||||
outputs = {
|
||||
"result": variable.value,
|
||||
"first_record": variable.value[0] if variable.value else None,
|
||||
"last_record": variable.value[-1] if variable.value else None,
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
def _apply_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
result = variable.value[: self.node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
|
||||
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
|
||||
@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
|
||||
case "size":
|
||||
return lambda x: x.size
|
||||
case _:
|
||||
raise ValueError(f"Invalid key: {key}")
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
case "url":
|
||||
return lambda x: x.remote_url or ""
|
||||
case _:
|
||||
raise ValueError(f"Invalid key: {key}")
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
|
||||
@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
|
||||
@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
|
||||
@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
||||
case "≥":
|
||||
return _ge(value)
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
extract_func = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
|
||||
else:
|
||||
raise ValueError(f"Invalid key: {key}")
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _contains(value: str):
|
||||
|
||||
Reference in New Issue
Block a user