Feat/plugins (#12547)

Co-authored-by: AkaraChen <akarachen@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: kurokobo <kuro664@gmail.com>
Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com>
This commit is contained in:
zxhlyh
2025-01-09 18:47:41 +08:00
committed by GitHub
parent e4c4490175
commit 3c014f3ae5
719 changed files with 48585 additions and 8553 deletions

View File

@ -60,17 +60,20 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if response.status_code not in STATUS_FORCELIST:
return response
else:
logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
logging.warning(
f"Received status code {response.status_code} for URL {url} which is in the force list")
except httpx.RequestError as e:
logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
logging.warning(f"Request to URL {url} failed on attempt {
retries + 1}: {e}")
if max_retries == 0:
raise
retries += 1
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
raise MaxRetriesExceededError(
f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

View File

@ -17,7 +17,8 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.basicConfig(level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)
ROUTING_FIELD = "routing_field"
@ -134,7 +135,8 @@ class LindormVectorStore(BaseVector):
self._client.delete(index=self._collection_name, id=id, params=params)
self.refresh()
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
logger.warning(
f"DELETE BY ID: ID {id} does not exist in the index.")
def delete(self) -> None:
if self._using_ugc:
@ -145,7 +147,8 @@ class LindormVectorStore(BaseVector):
self.refresh()
else:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
self._client.indices.delete(
index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
@ -168,7 +171,8 @@ class LindormVectorStore(BaseVector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
query = default_vector_search_query(
query_vector=query_vector, k=top_k, **kwargs)
try:
params = {}
if self._using_ugc:
@ -220,7 +224,8 @@ class LindormVectorStore(BaseVector):
routing=routing,
routing_field=self._routing_field,
)
response = self._client.search(index=self._collection_name, body=full_text_query)
response = self._client.search(
index=self._collection_name, body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
docs.append(
@ -238,7 +243,8 @@ class LindormVectorStore(BaseVector):
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
logger.info(
f"Collection {self._collection_name} already exists.")
return
if self._client.indices.exists(index=self._collection_name):
logger.info(f"{self._collection_name.lower()} already exists.")
@ -258,10 +264,13 @@ class LindormVectorStore(BaseVector):
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
nlist = kwargs.pop("nlist", 1000)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_use_hnsw = kwargs.pop(
"centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
centroids_hnsw_ef_construct = kwargs.pop(
"centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop(
"centroids_hnsw_ef_search", 100)
mapping = default_text_mapping(
dimension,
method_name,
@ -281,7 +290,8 @@ class LindormVectorStore(BaseVector):
using_ugc=self._using_ugc,
**kwargs,
)
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
self._client.indices.create(
index=self._collection_name.lower(), body=mapping)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
# logger.info(f"create index success: {self._collection_name}")
@ -347,7 +357,8 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
}
if excludes_from_source:
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
# e.g. {"excludes": ["vector_field"]}
mapping["mappings"]["_source"] = {"excludes": excludes_from_source}
if using_ugc and method_name == "ivfpq":
mapping["settings"]["index"]["knn_routing"] = True
@ -385,7 +396,8 @@ def default_text_search_query(
# build complex search_query when either of must/must_not/should/filter is specified
if must:
if not isinstance(must, list):
raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
raise RuntimeError(
f"unexpected [must] clause with {type(filters)}")
if query_clause not in must:
must.append(query_clause)
else:
@ -395,19 +407,22 @@ def default_text_search_query(
if must_not:
if not isinstance(must_not, list):
raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
raise RuntimeError(
f"unexpected [must_not] clause with {type(filters)}")
boolean_query["must_not"] = must_not
if should:
if not isinstance(should, list):
raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
raise RuntimeError(
f"unexpected [should] clause with {type(filters)}")
boolean_query["should"] = should
if minimum_should_match != 0:
boolean_query["minimum_should_match"] = minimum_should_match
if filters:
if not isinstance(filters, list):
raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
raise RuntimeError(
f"unexpected [filter] clause with {type(filters)}")
boolean_query["filter"] = filters
search_query = {"size": k, "query": {"bool": boolean_query}}

View File

@ -50,7 +50,7 @@ class WordExtractor(BaseExtractor):
self.web_path = self.file_path
# TODO: use a better way to handle the file
self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115
self.temp_file = tempfile.NamedTemporaryFile()
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):

View File

@ -44,11 +44,13 @@ class QuestionClassifierNode(LLMNode):
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
variable = variable_pool.get(
node_data.query_variable_selector) if node_data.query_variable_selector else None
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance, model_config = self._fetch_model_config(
node_data.model)
# fetch memory
memory = self._fetch_memory(
node_data_memory=node_data.memory,
@ -56,7 +58,8 @@ class QuestionClassifierNode(LLMNode):
)
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
node_data.instruction = variable_pool.convert_template(
node_data.instruction).text
files = (
self._fetch_files(
@ -178,12 +181,15 @@ class QuestionClassifierNode(LLMNode):
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
variable_template_parser = VariableTemplateParser(
template=node_data.instruction)
variable_selectors.extend(
variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
variable_mapping = {node_id + "." + key: value for key,
value in variable_mapping.items()}
return variable_mapping
@ -204,7 +210,8 @@ class QuestionClassifierNode(LLMNode):
context: Optional[str],
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_template = self._get_prompt_template(
node_data, query, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
@ -217,13 +224,15 @@ class QuestionClassifierNode(LLMNode):
)
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_config.model_schema.model_properties.get(
ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
curr_message_tokens = model_instance.get_llm_num_tokens(
prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
@ -264,7 +273,8 @@ class QuestionClassifierNode(LLMNode):
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(
histories=memory_str)
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
@ -305,4 +315,5 @@ class QuestionClassifierNode(LLMNode):
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
raise InvalidModelTypeError(
f"Model mode {model_mode} not support.")

View File

@ -68,7 +68,8 @@ def test_executor_with_json_body_and_object_variable():
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
variable_pool.add(["pre_node_id", "object"], {
"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
@ -123,7 +124,8 @@ def test_executor_with_json_body_and_nested_object_variable():
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
variable_pool.add(["pre_node_id", "object"], {
"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(

View File

@ -18,6 +18,14 @@ from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_plain_text_to_dict():
assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {
"aa": "bb", "cc": "dd"}
def test_http_request_node_binary_file(monkeypatch):
data = HttpRequestNodeData(
title="test",
@ -183,7 +191,8 @@ def test_http_request_node_form_with_file(monkeypatch):
def attr_checker(*args, **kwargs):
assert kwargs["data"] == {"name": "test"}
assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
assert kwargs["files"] == {
"file": (None, b"test", "application/octet-stream")}
return httpx.Response(200, content=b"")
monkeypatch.setattr(