Merge branch 'main' into feat/rag-2

twwu · 2025-08-19 14:59:06 +08:00
194 changed files with 6278 additions and 623 deletions

View File

@ -478,6 +478,13 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# Workflow log cleanup configuration
# Enable automatic cleanup of workflow run logs to manage database size
WORKFLOW_LOG_CLEANUP_ENABLED=true
# Number of days to retain workflow run logs (default: 30 days)
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200

View File

@ -968,6 +968,14 @@ class AccountConfig(BaseSettings):
)
class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1003,5 +1011,6 @@ class FeatureConfig(
HostedServiceConfig,
CeleryBeatConfig,
CeleryScheduleTasksConfig,
WorkflowLogConfig,
):
pass
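For orientation, a minimal usage sketch (assuming the shared dify_config object exported by the configs package, which the cleanup task later in this diff also relies on):

from configs import dify_config

# Values fall back to the Field defaults above when the env vars are unset.
if dify_config.WORKFLOW_LOG_CLEANUP_ENABLED:
    retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS    # 30 by default
    batch_size = dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE    # 100 by default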

View File

@ -1,3 +1,4 @@
import contextlib
import mimetypes
import os
import platform
@ -65,10 +66,8 @@ def guess_file_info_from_response(response: httpx.Response):
# Use python-magic to guess MIME type if still unknown or generic
if mimetype == "application/octet-stream" and magic is not None:
try:
with contextlib.suppress(magic.MagicException):
mimetype = magic.from_buffer(response.content[:1024], mime=True)
except magic.MagicException:
pass
extension = os.path.splitext(filename)[1]
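The contextlib.suppress change above is behavior-preserving; as a standalone illustration (not code from this diff), the two forms below do the same thing:

import contextlib

# before: explicit try/except that swallows the error
try:
    value = int("not a number")
except ValueError:
    pass

# after: same effect with one less level of nesting
with contextlib.suppress(ValueError):
    value = int("not a number")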

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
@ -24,7 +26,7 @@ class AnnotationReplyActionApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action):
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
@ -38,8 +40,6 @@ class AnnotationReplyActionApi(Resource):
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200

View File

@ -1,3 +1,5 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -10,6 +12,8 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
@ -107,6 +111,121 @@ class RuleStructuredOutputGenerateApi(Resource):
return structured_output
class InstructionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
parser.add_argument("node_id", type=str, required=False, default="", location="json")
parser.add_argument("current", type=str, required=False, default="", location="json")
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
)
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
from models import App, db
from services.workflow_service import WorkflowService
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args['flow_id']} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args['flow_id']} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args["node_id"]]
if len(node) == 0:
return {"error": f"node {args['node_id']} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
)
return {"error": "incompatible parameters"}, 400
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
class InstructionGenerationTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self) -> dict:
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT}
case "code":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args['type']}")
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
api.add_resource(InstructionGenerateApi, "/instruction-generate")
api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template")
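For reference, a request body that the InstructionGenerateApi parser above accepts looks roughly like this (illustrative values only; the console-API URL prefix these resources are mounted under is not shown in this hunk):

payload = {
    "flow_id": "<app id>",                 # required
    "node_id": "<workflow node id>",       # empty for legacy apps without a workflow
    "current": "",                         # current prompt or code, may be empty
    "language": "python",                  # defaults to "javascript"
    "instruction": "Return the score as an integer between 0 and 100",
    "model_config": {"provider": "<provider>", "name": "<model name>"},
    "ideal_output": "",                    # optional reference output
}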

View File

@ -1,7 +1,7 @@
import json
import logging
from argparse import ArgumentTypeError
from typing import cast
from typing import Literal, cast
from flask import request
from flask_login import current_user
@ -761,7 +761,7 @@ class DocumentProcessingApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@ -787,8 +787,6 @@ class DocumentProcessingApi(DocumentResource):
document.paused_at = None
document.is_paused = False
db.session.commit()
else:
raise InvalidActionError()
return {"result": "success"}, 200
@ -843,7 +841,7 @@ class DocumentStatusApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
@ -100,7 +102,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id, action):
def post(self, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -39,7 +39,7 @@ class UploadFileApi(Resource):
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
@ -15,7 +17,7 @@ from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action):
def post(self, app_model: App, action: Literal["enable", "disable"]):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
@ -25,8 +27,6 @@ class AnnotationReplyActionApi(Resource):
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask import request
from flask_restful import marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@ -358,14 +360,14 @@ class DatasetApi(DatasetApiResource):
class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations."""
def patch(self, tenant_id, dataset_id, action):
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
"""
Batch update document status.
Args:
tenant_id: tenant id
dataset_id: dataset id
action: action to perform (enable, disable, archive, un_archive)
action: action to perform (Literal["enable", "disable", "archive", "un_archive"])
Returns:
dict: A dictionary with a key 'result' and a value 'success'

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask_login import current_user # type: ignore
from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
@ -77,7 +79,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action):
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -181,7 +181,7 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
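The .filter() to .where() substitutions in this and the following files are mechanical: on SQLAlchemy 1.4+ Query objects, where() is a synonym for filter() added for 2.0-style consistency, so both spellings below produce the same query:

# legacy spelling
db.session.query(MessageFile).filter(MessageFile.id == message_id).first()

# 2.0-style spelling used throughout this merge
db.session.query(MessageFile).where(MessageFile.id == message_id).first()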

View File

@ -1,6 +1,7 @@
import json
import logging
import re
from collections.abc import Sequence
from typing import Optional, cast
import json_repair
@ -11,6 +12,8 @@ from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@ -24,6 +27,9 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db
class LLMGenerator:
@ -388,3 +394,181 @@ class LLMGenerator:
except Exception as e:
logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
) -> dict:
app: App | None = db.session.query(App).where(App.id == flow_id).first()
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
)
if not last_run:
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=None,
current=current,
error_message="",
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
)
last_run_dict = {
"query": last_run.query,
"answer": last_run.answer,
"error": last_run.error,
}
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=str(last_run.error),
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
)
@staticmethod
def instruction_modify_workflow(
tenant_id: str,
flow_id: str,
node_id: str,
current: str,
instruction: str,
model_config: dict,
ideal_output: str | None,
) -> dict:
from services.workflow_service import WorkflowService
app: App | None = db.session.query(App).where(App.id == flow_id).first()
if not app:
raise ValueError("App not found.")
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
raise ValueError("Workflow not found for the given app model.")
last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
try:
node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
except Exception:
try:
node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][
"type"
]
except Exception:
node_type = "llm"
if not last_run: # Node is not executed yet
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=None,
current=current,
error_message="",
instruction=instruction,
node_type=node_type,
ideal_output=ideal_output,
)
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
if not raw_agent_log:
return []
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
def dict_of_event(event: AgentLogEvent) -> dict:
return {
"status": event.status,
"error": event.error,
"data": event.data,
}
return [dict_of_event(event) for event in parsed]
last_run_dict = {
"inputs": last_run.inputs_dict,
"status": last_run.status,
"error": last_run.error,
"agent_log": agent_log_of(last_run),
}
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=last_run.error,
instruction=instruction,
node_type=last_run.node_type,
ideal_output=ideal_output,
)
@staticmethod
def __instruction_modify_common(
tenant_id: str,
model_config: dict,
last_run: dict | None,
current: str | None,
error_message: str | None,
instruction: str,
node_type: str,
ideal_output: str | None,
) -> dict:
LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}"
ERROR_MESSAGE = "{{#error_message#}}"
injected_instruction = instruction
if LAST_RUN in injected_instruction:
injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run))
if CURRENT in injected_instruction:
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
if ERROR_MESSAGE in injected_instruction:
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
match node_type:
case "llm", "agent":
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
case "code":
system_prompt = LLM_MODIFY_CODE_SYSTEM
case _:
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
prompt_messages = [
SystemPromptMessage(content=system_prompt),
UserPromptMessage(
content=json.dumps(
{
"current": current,
"last_run": last_run,
"instruction": injected_instruction,
"ideal_output": ideal_output,
}
)
),
]
model_parameters = {"temperature": 0.4}
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)
generated_raw = cast(str, response.message.content)
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
return {**json.loads(generated_raw[first_brace : last_brace + 1])}
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e)
return {"error": f"An unexpected error occurred: {str(e)}"}

View File

@ -309,3 +309,116 @@ eg:
Here is the JSON schema:
{{schema}}
""" # noqa: E501
LLM_MODIFY_PROMPT_SYSTEM = """
Both your input and output should be in JSON format.
! Below is the schema for input content !
{
"type": "object",
"description": "The user is trying to process some content with a prompt, but the output is not as expected. They hope to achieve their goal by modifying the prompt.",
"properties": {
"current": {
"type": "string",
"description": "The prompt before modification, where placeholders {{}} will be replaced with actual values for the large language model. The content in the placeholders should not be changed."
},
"last_run": {
"type": "object",
"description": "The output result from the large language model after receiving the prompt.",
},
"instruction": {
"type": "string",
"description": "User's instruction to edit the current prompt"
},
"ideal_output": {
"type": "string",
"description": "The ideal output that the user expects from the large language model after modifying the prompt. You should compare the last output with the ideal output and make changes to the prompt to achieve the goal."
}
}
}
! Above is the schema for input content !
! Below is the schema for output content !
{
"type": "object",
"description": "Your feedback to the user after they provide modification suggestions.",
"properties": {
"modified": {
"type": "string",
"description": "Your modified prompt. You should change the original prompt as little as possible to achieve the goal. Keep the language of prompt if not asked to change"
},
"message": {
"type": "string",
"description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
}
},
"required": [
"modified",
"message"
]
}
! Above is the schema for output content !
Your output must strictly follow the schema format; do not output any content outside of the JSON body.
""" # noqa: E501
LLM_MODIFY_CODE_SYSTEM = """
Both your input and output should be in JSON format.
! Below is the schema for input content !
{
"type": "object",
"description": "The user is trying to process some data with a code snippet, but the result is not as expected. They hope to achieve their goal by modifying the code.",
"properties": {
"current": {
"type": "string",
"description": "The code before modification."
},
"last_run": {
"type": "object",
"description": "The result of the code.",
},
"message": {
"type": "string",
"description": "User's instruction to edit the current code"
}
}
}
! Above is the schema for input content !
! Below is the schema for output content !
{
"type": "object",
"description": "Your feedback to the user after they provide modification suggestions.",
"properties": {
"modified": {
"type": "string",
"description": "Your modified code. You should change the original code as little as possible to achieve the goal. Keep the programming language of code if not asked to change"
},
"message": {
"type": "string",
"description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
}
},
"required": [
"modified",
"message"
]
}
! Above is the schema for output content !
When you are modifying the code, you should remember:
- Do not use print; it does not work in the Dify sandbox.
- Do not attempt dangerous calls such as deleting files. This is PROHIBITED.
- Do not use any library that is not built into Python.
- Take inputs from the function's parameters and give them explicit type annotations.
- Write proper imports at the top of the code.
- Use return statement to return the result.
- You should return a `dict`. If you need to return a `result: str`, you should `return {"result": result}`.
Your output must strictly follow the schema format; do not output any content outside of the JSON body.
""" # noqa: E501
INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as expected: {{#last_run#}}.
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors described in {{#error_message#}}."""
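Given the output schema in LLM_MODIFY_PROMPT_SYSTEM, a well-formed model response that the brace-extraction logic in LLMGenerator.__instruction_modify_common can parse would look like this (content invented for illustration):

{
    "modified": "Summarize {{input}} in exactly three bullet points.",
    "message": "I constrained the prompt to three bullet points so the output matches your ideal output."
}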

View File

@ -4,8 +4,8 @@ import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient # type: ignore
from sqlalchemy import JSON, Column, String, func
from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore
from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT
from configs import dify_config
@ -119,14 +119,21 @@ class OceanBaseVector(BaseVector):
)
try:
if self._hybrid_search_enabled:
self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name}
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""")
self._client.create_fts_idx_with_fts_index_param(
table_name=self._collection_name,
fts_idx_param=FtsIndexParam(
index_name="fulltext_index_for_col_text",
field_names=["text"],
parser_type=FtsParser.IK,
),
)
except Exception as e:
raise Exception(
"Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
+ "to support fulltext index and vector index in the same table",
e,
)
self._client.refresh_metadata([self._collection_name])
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool:
@ -252,7 +259,7 @@ class OceanBaseVector(BaseVector):
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
distance_func=l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=_where_clause,

View File

@ -331,6 +331,12 @@ class QdrantVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score_threshold >= 1:
# return an empty list because some versions of Qdrant may respond with 400 Bad Request,
# while a score_threshold of 1 may still be valid for other vector stores
return []
filter = models.Filter(
must=[
models.FieldCondition(
@ -355,7 +361,7 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=float(kwargs.get("score_threshold") or 0.0),
score_threshold=score_threshold,
)
docs = []
for result in results:
@ -363,7 +369,6 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(

View File

@ -145,13 +145,19 @@ def init_app(app: DifyApp) -> Celery:
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
imports.append("schedule.check_upgradable_plugin_task")
beat_schedule["check_upgradable_plugin_task"] = {
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
"schedule": crontab(minute="*/15"),
}
if dify_config.WORKFLOW_LOG_CLEANUP_ENABLED:
# 2:00 AM every day
imports.append("schedule.clean_workflow_runlogs_precise")
beat_schedule["clean_workflow_runlogs_precise"] = {
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@ -205,7 +205,7 @@ vdb = [
"pgvector==0.2.5",
"pymilvus~=2.5.0",
"pymochow==1.3.1",
"pyobvector~=0.1.6",
"pyobvector~=0.2.15",
"qdrant-client==1.9.0",
"tablestore==6.2.0",
"tcvectordb~=1.6.4",

View File

@ -0,0 +1,155 @@
import datetime
import logging
import time
import click
import app
from configs import dify_config
from extensions.ext_database import db
from models.model import (
AppAnnotationHitHistory,
Conversation,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.workflow import ConversationVariable, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
_logger = logging.getLogger(__name__)
MAX_RETRIES = 3
BATCH_SIZE = dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE
@app.celery.task(queue="dataset")
def clean_workflow_runlogs_precise():
"""Clean expired workflow run logs with retry mechanism and complete message cascade"""
click.echo(click.style("Start clean workflow run logs (precise mode with complete cascade).", fg="green"))
start_at = time.perf_counter()
retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days)
try:
total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count()
if total_workflow_runs == 0:
_logger.info("No expired workflow run logs found")
return
_logger.info("Found %s expired workflow run logs to clean", total_workflow_runs)
total_deleted = 0
failed_batches = 0
batch_count = 0
while True:
workflow_runs = (
db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all()
)
if not workflow_runs:
break
workflow_run_ids = [run.id for run in workflow_runs]
batch_count += 1
success = _delete_batch_with_retry(workflow_run_ids, failed_batches)
if success:
total_deleted += len(workflow_run_ids)
failed_batches = 0
else:
failed_batches += 1
if failed_batches >= MAX_RETRIES:
_logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES)
break
else:
# Calculate incremental delay times: 5, 10, 15 minutes
retry_delay_minutes = failed_batches * 5
_logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes)
time.sleep(retry_delay_minutes * 60)
continue
_logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted)
except Exception as e:
db.session.rollback()
_logger.exception("Unexpected error in workflow log cleanup")
raise
end_at = time.perf_counter()
execution_time = end_at - start_at
click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green"))
def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> bool:
"""Delete a single batch with a retry mechanism and complete cascading deletion"""
try:
with db.session.begin_nested():
message_data = (
db.session.query(Message.id, Message.conversation_id)
.filter(Message.workflow_run_id.in_(workflow_run_ids))
.all()
)
message_id_list = [msg.id for msg in message_data]
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
if message_id_list:
db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.message_id.in_(message_id_list)
).delete(synchronize_session=False)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
db.session.query(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
).delete(synchronize_session=False)
if conversation_id_list:
db.session.query(ConversationVariable).where(
ConversationVariable.conversation_id.in_(conversation_id_list)
).delete(synchronize_session=False)
db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
synchronize_session=False
)
db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
db.session.commit()
return True
except Exception as e:
db.session.rollback()
_logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1)
return False
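Because the cleanup function is registered as a Celery task, it can also be triggered outside the beat schedule; a minimal sketch, assuming a worker is consuming the dataset queue:

from schedule.clean_workflow_runlogs_precise import clean_workflow_runlogs_precise

# enqueue one cleanup pass immediately instead of waiting for the 2:00 AM beat entry
clean_workflow_runlogs_precise.delay()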

View File

@ -293,7 +293,7 @@ class AppAnnotationService:
annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
# Step 2: Bulk delete hit histories in a single query
db.session.query(AppAnnotationHitHistory).filter(
db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
).delete(synchronize_session=False)
@ -307,7 +307,7 @@ class AppAnnotationService:
# Step 4: Bulk delete annotations in a single query
deleted_count = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.id.in_(annotation_ids_to_delete))
.where(MessageAnnotation.id.in_(annotation_ids_to_delete))
.delete(synchronize_session=False)
)
@ -505,9 +505,9 @@ class AppAnnotationService:
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
annotations_query = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id)
annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
for annotation in annotations_query.yield_per(100):
annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).filter(
annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id == annotation.id
)
for annotation_hit_history in annotation_hit_histories_query.yield_per(100):

View File

@ -6,7 +6,7 @@ import secrets
import time
import uuid
from collections import Counter
from typing import Any, Optional
from typing import Any, Literal, Optional
from flask_login import current_user
from sqlalchemy import func, select
@ -55,7 +55,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.account import NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
@ -2231,14 +2231,16 @@ class DocumentService:
raise ValueError("Process rule segmentation max_tokens is invalid")
@staticmethod
def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user):
def batch_update_document_status(
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
):
"""
Batch update document status.
Args:
dataset (Dataset): The dataset object
document_ids (list[str]): List of document IDs to update
action (str): Action to perform (enable, disable, archive, un_archive)
action (Literal["enable", "disable", "archive", "un_archive"]): Action to perform
user: Current user performing the action
Raises:
@ -2321,9 +2323,10 @@ class DocumentService:
raise propagation_error
@staticmethod
def _prepare_document_status_update(document, action: str, user):
"""
Prepare document status update information.
def _prepare_document_status_update(
document: Document, action: Literal["enable", "disable", "archive", "un_archive"], user
):
"""Prepare document status update information.
Args:
document: Document object to update
@ -2786,7 +2789,9 @@ class SegmentService:
db.session.commit()
@classmethod
def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document):
def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
):
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
@ -2844,8 +2849,6 @@ class SegmentService:
db.session.commit()
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
else:
raise InvalidActionError()
@classmethod
def create_child_chunk(

View File

@ -1,5 +1,6 @@
import logging
import time
from typing import Literal
import click
from celery import shared_task # type: ignore
@ -13,7 +14,7 @@ from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str):
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
"""
Async deal dataset from index
:param dataset_id: dataset_id

View File

@ -1,4 +1,5 @@
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
@ -18,6 +19,14 @@ class QdrantVectorTest(AbstractVectorTest):
),
)
def search_by_vector(self):
super().search_by_vector()
# only tested for Qdrant; may not work on other vector stores
hits_by_vector: list[Document] = self.vector.search_by_vector(
query_vector=self.example_embedding, score_threshold=1
)
assert len(hits_by_vector) == 0
def test_qdrant_vector(setup_mock_redis):
QdrantVectorTest().run_all_tests()

View File

@ -471,7 +471,7 @@ class TestAnnotationService:
# Verify annotation was deleted
from extensions.ext_database import db
deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
assert deleted_annotation is None
# Verify delete_annotation_index_task was called (when annotation setting exists)
@ -1175,7 +1175,7 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
# Verify annotation was deleted
deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
assert deleted_annotation is None
# Verify delete_annotation_index_task was called

View File

@ -234,7 +234,7 @@ class TestAPIBasedExtensionService:
# Verify extension was deleted
from extensions.ext_database import db
deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first()
deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
assert deleted_extension is None
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):

File diff suppressed because it is too large

View File

@ -484,7 +484,7 @@ class TestMessageService:
# Verify feedback was deleted
from extensions.ext_database import db
deleted_feedback = db.session.query(MessageFeedback).filter(MessageFeedback.id == feedback.id).first()
deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
assert deleted_feedback is None
def test_create_feedback_no_rating_when_not_exists(

View File

@ -469,6 +469,6 @@ class TestModelLoadBalancingService:
# Verify inherit config was created in database
inherit_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all()
db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
)
assert len(inherit_configs) == 1

api/uv.lock (generated)
View File

@ -1602,7 +1602,7 @@ vdb = [
{ name = "pgvector", specifier = "==0.2.5" },
{ name = "pymilvus", specifier = "~=2.5.0" },
{ name = "pymochow", specifier = "==1.3.1" },
{ name = "pyobvector", specifier = "~=0.1.6" },
{ name = "pyobvector", specifier = "~=0.2.15" },
{ name = "qdrant-client", specifier = "==1.9.0" },
{ name = "tablestore", specifier = "==6.2.0" },
{ name = "tcvectordb", specifier = "~=1.6.4" },
@ -4569,17 +4569,19 @@ wheels = [
[[package]]
name = "pyobvector"
version = "0.1.14"
version = "0.2.15"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiomysql" },
{ name = "numpy" },
{ name = "pydantic" },
{ name = "pymysql" },
{ name = "sqlalchemy" },
{ name = "sqlglot" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dc/59/7d762061808948dd6aad165a000b34e22163dc83fb5014184eeacc0fabe5/pyobvector-0.1.14.tar.gz", hash = "sha256:4f85cdd63064d040e94c0a96099a0cd5cda18ce625865382e89429f28422fc02", size = 26780, upload-time = "2024-11-20T11:46:18.017Z" }
sdist = { url = "https://files.pythonhosted.org/packages/0b/7d/3f3aac6acf1fdd1782042d6eecd48efaa2ee355af0dbb61e93292d629391/pyobvector-0.2.15.tar.gz", hash = "sha256:5de258c1e952c88b385b5661e130c1cf8262c498c1f8a4a348a35962d379fce4", size = 39611, upload-time = "2025-08-18T02:49:26.683Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/88/68/ecb21b74c974e7be7f9034e205d08db62d614ff5c221581ae96d37ef853e/pyobvector-0.1.14-py3-none-any.whl", hash = "sha256:828e0bec49a177355b70c7a1270af3b0bf5239200ee0d096e4165b267eeff97c", size = 35526, upload-time = "2024-11-20T11:46:16.809Z" },
{ url = "https://files.pythonhosted.org/packages/5f/1f/a62754ba9b8a02c038d2a96cb641b71d3809f34d2ba4f921fecd7840d7fb/pyobvector-0.2.15-py3-none-any.whl", hash = "sha256:feeefe849ee5400e72a9a4d3844e425a58a99053dd02abe06884206923065ebb", size = 52680, upload-time = "2025-08-18T02:49:25.452Z" },
]
[[package]]
@ -5432,6 +5434,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" },
]
[[package]]
name = "sqlglot"
version = "26.33.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/25/9d/fcd59b4612d5ad1e2257c67c478107f073b19e1097d3bfde2fb517884416/sqlglot-26.33.0.tar.gz", hash = "sha256:2817278779fa51d6def43aa0d70690b93a25c83eb18ec97130fdaf707abc0d73", size = 5353340, upload-time = "2025-07-01T13:09:06.311Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/31/8d/f1d9cb5b18e06aa45689fbeaaea6ebab66d5f01d1e65029a8f7657c06be5/sqlglot-26.33.0-py3-none-any.whl", hash = "sha256:031cee20c0c796a83d26d079a47fdce667604df430598c7eabfa4e4dfd147033", size = 477610, upload-time = "2025-07-01T13:09:03.926Z" },
]
[[package]]
name = "sseclient-py"
version = "1.8.0"