Compare commits

..

12 Commits

375 changed files with 4623 additions and 20557 deletions

View File

@ -4,7 +4,8 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "build/feat/hitl"
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed
@ -13,7 +14,10 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1

View File

@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

View File

@ -553,8 +553,6 @@ WORKFLOW_LOG_CLEANUP_ENABLED=false
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# Comma-separated list of workflow IDs to clean logs for
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS=
# App configuration
APP_MAX_EXECUTION_TIME=1200
@ -717,7 +715,6 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL=200
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000

View File

@ -52,12 +52,14 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@ -102,7 +104,10 @@ forbidden_modules =
core.schemas
core.tools
core.trigger
core.variables
ignore_imports =
core.workflow.nodes.agent.agent_node -> core.db.session_factory
core.workflow.nodes.agent.agent_node -> models.tools
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability
@ -123,6 +128,11 @@ ignore_imports =
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.file.models
@ -143,6 +153,7 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
@ -158,6 +169,9 @@ ignore_imports =
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@ -172,9 +186,11 @@ ignore_imports =
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
core.workflow.nodes.list_operator.node -> core.file
core.workflow.nodes.llm.file_saver -> core.file
core.workflow.nodes.llm.llm_utils -> core.variables.segments
core.workflow.nodes.llm.node -> core.file
core.workflow.nodes.llm.node -> core.file.file_manager
core.workflow.nodes.llm.node -> core.file.models
core.workflow.nodes.loop.entities -> core.variables.types
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
core.workflow.nodes.protocols -> core.file
core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
@ -189,14 +205,12 @@ ignore_imports =
core.workflow.utils.condition.processor -> core.file.file_manager
core.workflow.workflow_entry -> core.file.models
core.workflow.workflow_type_encoder -> core.file.models
core.workflow.variables.segments -> core.file
core.workflow.variables.types -> core.file.models
core.workflow.variables.variables -> core.helper.encrypter
core.workflow.nodes.agent.agent_node -> models.model
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
@ -206,6 +220,7 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
@ -221,15 +236,67 @@ ignore_imports =
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
core.workflow.nodes.llm.file_saver -> core.tools.signature
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
core.workflow.nodes.tool.tool_node -> core.tools.errors
core.workflow.conversation_variable_updater -> core.variables
core.workflow.graph_engine.entities.commands -> core.variables.variables
core.workflow.nodes.agent.agent_node -> core.variables.segments
core.workflow.nodes.answer.answer_node -> core.variables
core.workflow.nodes.code.code_node -> core.variables.segments
core.workflow.nodes.code.code_node -> core.variables.types
core.workflow.nodes.code.entities -> core.variables.types
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
core.workflow.nodes.document_extractor.node -> core.variables
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
core.workflow.nodes.http_request.node -> core.variables.segments
core.workflow.nodes.human_input.entities -> core.variables.consts
core.workflow.nodes.iteration.iteration_node -> core.variables
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
core.workflow.nodes.list_operator.node -> core.variables
core.workflow.nodes.list_operator.node -> core.variables.segments
core.workflow.nodes.llm.node -> core.variables
core.workflow.nodes.loop.loop_node -> core.variables
core.workflow.nodes.parameter_extractor.entities -> core.variables.types
core.workflow.nodes.parameter_extractor.exc -> core.variables.types
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
core.workflow.nodes.tool.tool_node -> core.variables.segments
core.workflow.nodes.tool.tool_node -> core.variables.variables
core.workflow.nodes.trigger_webhook.node -> core.variables.types
core.workflow.nodes.trigger_webhook.node -> core.variables.variables
core.workflow.nodes.variable_aggregator.entities -> core.variables.types
core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
core.workflow.nodes.variable_assigner.common.helpers -> core.variables
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
core.workflow.nodes.variable_assigner.v1.node -> core.variables
core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
core.workflow.runtime.read_only_wrappers -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables
core.workflow.runtime.variable_pool -> core.variables.consts
core.workflow.runtime.variable_pool -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables.variables
core.workflow.utils.condition.processor -> core.variables
core.workflow.utils.condition.processor -> core.variables.segments
core.workflow.variable_loader -> core.variables
core.workflow.variable_loader -> core.variables.consts
core.workflow.workflow_type_encoder -> core.variables
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
@ -286,6 +353,7 @@ forbidden_modules =
core.schemas
core.tools
core.trigger
core.variables
core.workflow
ignore_imports =
core.model_runtime.model_providers.__base.ai_model -> configs

View File

@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,workflow_based_app_execution,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
]
}
]

View File

@ -122,7 +122,7 @@ These commands assume you start from the repository root.
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

View File

@ -1180,16 +1180,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
default=0,
)
# API token last_used_at batch update
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
description="Enable periodic batch update of API token last_used_at timestamps",
default=True,
)
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
description="Interval in minutes for batch updating API token last_used_at (default 30)",
default=30,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",
@ -1314,9 +1304,6 @@ class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
)
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS: str = Field(
default="", description="Comma-separated list of workflow IDs to clean logs for"
)
class SwaggerUIConfig(BaseSettings):
@ -1347,10 +1334,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
description="Maximum interval in milliseconds between batches",
default=200,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,

View File

@ -259,20 +259,11 @@ class CeleryConfig(DatabaseConfig):
description="Password of the Redis Sentinel master.",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
default=0.1,
)
CELERY_TASK_ANNOTATIONS: dict[str, Any] | None = Field(
description=(
"Annotations for Celery tasks as a JSON mapping of task name -> options "
"(for example, rate limits or other task-specific settings)."
),
default=None,
)
@computed_field
def CELERY_RESULT_BACKEND(self) -> str | None:
if self.CELERY_BACKEND in ("database", "rabbitmq"):

View File

@ -21,7 +21,6 @@ language_timezone_mapping = {
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
"nl-NL": "Europe/Amsterdam",
}
languages = list(language_timezone_mapping.keys())

View File

@ -5,6 +5,8 @@ from enum import StrEnum
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter
from controllers.console import console_ns
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -22,9 +24,6 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
def get_or_create_model(model_name: str, field_def):
# Import lazily to avoid circular imports between console controllers and schema helpers.
from controllers.console import console_ns
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)

View File

@ -10,7 +10,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
@ -132,11 +131,6 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@ -599,12 +599,7 @@ def _get_conversation(app_model, conversation_id):
db.session.execute(
sa.update(Conversation)
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
# Keep updated_at unchanged when only marking a conversation as read.
.values(
read_at=naive_utc_now(),
read_account_id=current_user.id,
updated_at=Conversation.updated_at,
)
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
)
db.session.commit()
db.session.refresh(conversation)

View File

@ -16,10 +16,10 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.variables.segment_group import SegmentGroup
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type

View File

@ -463,9 +463,8 @@ class WorkflowRunNodeExecutionListApi(Resource):
class ConsoleWorkflowPauseDetailsApi(Resource):
"""Console API for getting workflow pause details."""
@setup_required
@login_required
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow pause details.
@ -478,14 +477,10 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Query WorkflowRun to determine if workflow is suspended
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
if not workflow_run:
raise NotFoundError("Workflow run not found")
if workflow_run.tenant_id != current_user.current_tenant_id:
raise NotFoundError("Workflow run not found")
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:

View File

@ -55,7 +55,6 @@ from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
# Register models for flask_restx to avoid dict type issues in Swagger
@ -821,11 +820,6 @@ class DatasetApiDeleteApi(Resource):
if key is None:
console_ns.abort(404, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type

View File

@ -1,7 +1,6 @@
import urllib.parse
import httpx
from flask_restx import Resource
from pydantic import BaseModel, Field
import services
@ -11,12 +10,12 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.console import console_ns
from controllers.fastopenapi import console_router
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant, login_required
from libs.login import current_account_with_tenant
from services.file_service import FileService
@ -24,73 +23,69 @@ class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
@console_ns.route("/remote-files/<path:url>")
class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
@console_router.get(
"/remote-files/<path:url>",
response_model=RemoteFileInfo,
tags=["console"],
)
def get_remote_file_info(url: str) -> RemoteFileInfo:
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
@console_router.post(
"/remote-files/upload",
response_model=FileWithSignedUrl,
tags=["console"],
status_code=201,
)
def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
url = payload.url
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
).model_dump(mode="json")
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)
@console_ns.route("/remote-files/upload")
class RemoteFileUpload(Resource):
@login_required
def post(self):
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = payload.url
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
# Try to fetch remote file metadata/content first
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
file_info = helpers.guess_file_info_from_response(resp)
# Enforce file size limit with 400 (Bad Request) per tests' expectation
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError()
# Load content if needed
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
# Success: return created resource with 201 status
return (
FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
).model_dump(mode="json"),
201,
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)

View File

@ -42,15 +42,7 @@ class SetupResponse(BaseModel):
tags=["console"],
)
def get_setup_status_api() -> SetupStatusResponse:
"""Get system setup status.
NOTE: This endpoint is unauthenticated by design.
During first-time bootstrap there is no admin account yet, so frontend initialization must be
able to query setup progress before any login flow exists.
Only bootstrap-safe status information should be returned by this endpoint.
"""
"""Get system setup status."""
if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status()
if setup_status and not isinstance(setup_status, bool):
@ -69,12 +61,7 @@ def get_setup_status_api() -> SetupStatusResponse:
)
@only_edition_self_hosted
def setup_system(payload: SetupRequestPayload) -> SetupResponse:
"""Initialize system setup with admin account.
NOTE: This endpoint is unauthenticated by design for first-time bootstrap.
Access is restricted by deployment mode (`SELF_HOSTED`), one-time setup guards,
and init-password validation rather than user session authentication.
"""
"""Initialize system setup with admin account."""
if get_setup_status():
raise AlreadySetupError()

View File

@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource):
TagService.delete_tag(tag_id)
return "", 204
return 204
@console_ns.route("/tag-bindings/create")

View File

@ -34,8 +34,6 @@ from .dataset import (
metadata,
segment,
)
from .dataset.rag_pipeline import rag_pipeline_workflow
from .end_user import end_user
from .workspace import models
__all__ = [
@ -46,7 +44,6 @@ __all__ = [
"conversation",
"dataset",
"document",
"end_user",
"file",
"file_preview",
"hit_testing",
@ -54,7 +51,6 @@ __all__ = [
"message",
"metadata",
"models",
"rag_pipeline_workflow",
"segment",
"site",
"workflow",

View File

@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return "", 204
return 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
return "", 204
return 204
@service_api_ns.route("/datasets/tags/binding")
@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource):
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return "", 204
return 204
@service_api_ns.route("/datasets/tags/unbinding")
@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return "", 204
return 204
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")

View File

@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return "", 204
return 204

View File

@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return "", 204
return 204
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")

View File

@ -1,3 +1,5 @@
import string
import uuid
from collections.abc import Generator
from typing import Any
@ -10,7 +12,6 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
@ -40,7 +41,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@ -75,7 +76,7 @@ class DatasourcePluginsApi(DatasetApiResource):
return datasource_plugins, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@ -130,7 +131,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
class PipelineRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@ -231,4 +232,12 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return serialize_upload_file(upload_file), 201
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@ -1,22 +0,0 @@
"""
Serialization helpers for Service API knowledge pipeline endpoints.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import UploadFile
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.isoformat() if upload_file.created_at else None,
}

View File

@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return "", 204
return 204
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return "", 204
return 204
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")

View File

@ -1,3 +0,0 @@
from . import end_user
__all__ = ["end_user"]

View File

@ -1,41 +0,0 @@
from uuid import UUID
from flask_restx import Resource
from controllers.service_api import service_api_ns
from controllers.service_api.end_user.error import EndUserNotFoundError
from controllers.service_api.wraps import validate_app_token
from fields.end_user_fields import EndUserDetail
from models.model import App
from services.end_user_service import EndUserService
@service_api_ns.route("/end-users/<uuid:end_user_id>")
class EndUserApi(Resource):
"""Resource for retrieving end user details by ID."""
@service_api_ns.doc("get_end_user")
@service_api_ns.doc(description="Get an end user by ID")
@service_api_ns.doc(
params={"end_user_id": "End user ID"},
responses={
200: "End user retrieved successfully",
401: "Unauthorized - invalid API token",
404: "End user not found",
},
)
@validate_app_token
def get(self, app_model: App, end_user_id: UUID):
"""Get end user detail.
This endpoint is scoped to the current app token's tenant/app to prevent
cross-tenant/app access when an end-user ID is known.
"""
end_user = EndUserService.get_end_user_by_id(
tenant_id=app_model.tenant_id, app_id=app_model.id, end_user_id=str(end_user_id)
)
if end_user is None:
raise EndUserNotFoundError()
return EndUserDetail.model_validate(end_user).model_dump(mode="json")

View File

@ -1,7 +0,0 @@
from libs.exception import BaseHTTPException
class EndUserNotFoundError(BaseHTTPException):
error_code = "end_user_not_found"
description = "End user not found."
code = 404

View File

@ -1,24 +1,27 @@
import logging
import time
from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
from typing import Concatenate, ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
@ -217,8 +220,6 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None
@ -255,18 +256,12 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.first()
)
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
@ -301,14 +296,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token with Redis caching.
This function uses a two-tier approach:
1. First checks Redis cache for the token
2. If not cached, queries database and caches the result
The last_used_at field is updated asynchronously via Celery task
to avoid blocking the request.
Validate and get API token.
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:
@ -320,18 +308,29 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
# Try to get token from cache first
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
cached_token = ApiTokenCache.get(auth_token, scope)
if cached_token is not None:
logger.debug("Token validation served from cache for scope: %s", scope)
# Record usage in Redis for later batch update (no Celery task per request)
record_token_usage(auth_token, scope)
return cast(ApiToken, cached_token)
current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
)
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
result = session.execute(update_stmt)
api_token = session.scalar(stmt)
# Cache miss - use Redis lock for single-flight mode
# This ensures only one request queries DB for the same token concurrently
return fetch_token_with_single_flight(auth_token, scope)
if hasattr(result, "rowcount") and result.rowcount > 0:
session.commit()
if not api_token:
raise Unauthorized("Access token is invalid")
return api_token
class DatasetApiResource(Resource):

View File

@ -65,12 +65,15 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
# TODO(QuantumGhost): disable authorization for web app
# form api temporarily
@web_ns.route("/form/human_input/<string:form_token>")
# class HumanInputFormApi(WebApiResource):
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""
# NOTE(QuantumGhost): this endpoint is unauthenticated on purpose for now.
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
def get(self, form_token: str):
"""

View File

@ -25,6 +25,7 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -33,7 +34,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import Variable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client

View File

@ -50,6 +50,7 @@ from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (
@ -61,7 +62,6 @@ from core.workflow.enums import (
)
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
@ -346,7 +346,7 @@ class WorkflowResponseConverter:
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED,
status=WorkflowExecutionStatus.PAUSED.value,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
@ -422,7 +422,7 @@ class WorkflowResponseConverter:
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
status=workflow_run.status.value,
outputs=encoded_outputs,
error=workflow_run.error,
elapsed_time=elapsed_time,

View File

@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
@ -20,7 +21,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline

View File

@ -34,7 +34,7 @@ def stream_topic_events(
on_subscribe()
while True:
try:
msg = sub.receive(timeout=1)
msg = sub.receive(timeout=0.1)
except SubscriptionClosedError:
return
if msg is None:

View File

@ -262,7 +262,7 @@ class WorkflowPauseStreamResponse(StreamResponse):
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: WorkflowExecutionStatus
status: str
created_at: int
elapsed_time: float
total_tokens: int

View File

@ -1,12 +1,12 @@
import logging
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.variables import VariableBase
logger = logging.getLogger(__name__)

View File

@ -45,8 +45,6 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.file import helpers as file_helpers
from core.file.enums import FileTransferMethod
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
@ -58,11 +56,10 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@ -466,85 +463,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
metadata=metadata_dict,
)
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

View File

@ -64,13 +64,7 @@ class MessageCycleManager:
# Use SQLAlchemy 2.x style session.scalar(select(...))
with session_factory.create_session() as session:
message_file = session.scalar(
select(MessageFile)
.where(
MessageFile.message_id == message_id,
)
.where(MessageFile.belongs_to == "assistant")
)
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
if message_file:
self._message_has_file.add(message_id)

View File

@ -8,7 +8,6 @@ from core.file.file_manager import file_manager
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
@ -17,7 +16,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
@ -77,7 +75,6 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._rag_retrieval = DatasetRetrieval()
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@ -143,15 +140,6 @@ class DifyNodeFactory(NodeFactory):
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
from core.workflow.variables.utils import dumps_with_segments
from core.variables.utils import dumps_with_segments
class TemplateTransformer(ABC):

View File

@ -1,15 +1,13 @@
import json
import logging
import math
import re
import threading
import time
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@ -20,7 +18,6 @@ from core.app.app_config.entities import (
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
@ -61,30 +58,12 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
from models.dataset import (
ChildChunk,
Dataset,
DatasetMetadata,
DatasetQuery,
DocumentSegment,
RateLimitLog,
SegmentAttachmentBinding,
)
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
@ -94,8 +73,6 @@ default_retrieval_model: dict[str, Any] = {
"score_threshold_enabled": False,
}
logger = logging.getLogger(__name__)
class DatasetRetrieval:
def __init__(self, application_generate_entity=None):
@ -114,233 +91,6 @@ class DatasetRetrieval:
else:
self._llm_usage = self._llm_usage.plus(usage)
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
self._check_knowledge_rate_limit(request.tenant_id)
available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids)
available_datasets_ids = [i.id for i in available_datasets]
if not available_datasets_ids:
return []
if not request.query:
return []
metadata_filter_document_ids, metadata_condition = None, None
if request.metadata_filtering_mode != "disabled":
# Convert workflow layer types to app_config layer types
if not request.metadata_model_config:
raise ValueError("metadata_model_config is required for this method")
app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())
app_metadata_filtering_conditions = None
if request.metadata_filtering_conditions is not None:
app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate(
request.metadata_filtering_conditions.model_dump()
)
query = request.query if request.query is not None else ""
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
dataset_ids=available_datasets_ids,
query=query,
tenant_id=request.tenant_id,
user_id=request.user_id,
metadata_filtering_mode=request.metadata_filtering_mode,
metadata_model_config=app_metadata_model_config,
metadata_filtering_conditions=app_metadata_filtering_conditions,
inputs={},
)
if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
planning_strategy = PlanningStrategy.REACT_ROUTER
# Ensure required fields are not None for single retrieval mode
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
provider=request.model_provider,
model=request.model_name,
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=request.model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise exc.ModelCredentialsNotInitializedError(
f"Model {request.model_name} credentials is not initialized."
)
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.")
stop = []
completion_params = (request.completion_params or {}).copy()
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials)
if not model_schema:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
model_config = ModelConfigWithCredentialsEntity(
provider=request.model_provider,
model=request.model_name,
model_schema=model_schema,
mode=request.model_mode or "chat",
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
all_documents = self.single_retrieve(
request.app_id,
request.tenant_id,
request.user_id,
request.user_from,
request.query,
available_datasets,
model_instance,
model_config,
planning_strategy,
None, # message_id
metadata_filter_document_ids,
metadata_condition,
)
else:
all_documents = self.multiple_retrieve(
app_id=request.app_id,
tenant_id=request.tenant_id,
user_id=request.user_id,
user_from=request.user_from,
available_datasets=available_datasets,
query=request.query,
top_k=request.top_k,
score_threshold=request.score_threshold,
reranking_mode=request.reranking_mode,
reranking_model=request.reranking_model,
weights=request.weights,
reranking_enable=request.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=request.attachment_ids,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from="workflow",
score=item.metadata.get("score"),
doc_metadata=item.metadata,
),
title=item.metadata.get("title"),
content=item.page_content,
)
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
dataset_ids = [i.segment.dataset_id for i in records]
document_ids = [i.segment.document_id for i in records]
with session_factory.create_session() as session:
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
if records:
for record in records:
segment = record.segment
dataset = dataset_map.get(segment.dataset_id)
document = document_map.get(segment.document_id)
if dataset and document:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from="workflow",
score=record.score or 0.0,
segment_hit_count=segment.hit_count,
segment_word_count=segment.word_count,
segment_position=segment.position,
segment_index_node_hash=segment.index_node_hash,
doc_metadata=document.doc_metadata,
child_chunks=[
SourceChildChunk(
id=str(getattr(chunk, "id", "")),
content=str(getattr(chunk, "content", "")),
position=int(getattr(chunk, "position", 0)),
score=float(getattr(chunk, "score", 0.0)),
)
for chunk in (record.child_chunks or [])
],
position=None,
),
title=document.name,
files=list(record.files) if record.files else None,
content=segment.get_sign_content(),
)
if segment.answer:
source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
if record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
def _score(item: Source) -> float:
meta = item.metadata
score = meta.score
if isinstance(score, (int, float)):
return float(score)
return 0.0
retrieval_resource_list = sorted(
retrieval_resource_list,
key=_score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item.metadata.position = position # type: ignore[index]
return retrieval_resource_list
def retrieve(
self,
app_id: str,
@ -400,7 +150,14 @@ class DatasetRetrieval:
if features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = self._get_available_datasets(tenant_id, dataset_ids)
available_datasets = []
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
@ -1404,6 +1161,7 @@ class DatasetRetrieval:
query=query or "",
)
result_text = ""
try:
# handle invoke result
invoke_result = cast(
@ -1434,8 +1192,7 @@ class DatasetRetrieval:
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
logger.warning(e, exc_info=True)
except Exception:
return None
return automatic_metadata_filters
@ -1649,12 +1406,7 @@ class DatasetRetrieval:
usage = None
for result in invoke_result:
text = result.delta.message.content
if isinstance(text, str):
full_text += text
elif isinstance(text, list):
for i in text:
if i.data:
full_text += i.data
full_text += text
if not model:
model = result.model
@ -1772,53 +1524,3 @@ class DatasetRetrieval:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
with session_factory.create_session() as session:
subquery = (
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
.where(
DocumentModel.indexing_status == "completed",
DocumentModel.enabled == True,
DocumentModel.archived == False,
DocumentModel.dataset_id.in_(dataset_ids),
)
.group_by(DocumentModel.dataset_id)
.having(func.count(DocumentModel.id) > 0)
.subquery()
)
results = (
session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
available_datasets = []
for dataset in results:
if not dataset:
continue
available_datasets.append(dataset)
return available_datasets
def _check_knowledge_rate_limit(self, tenant_id: str):
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with session_factory.create_session() as session:
rate_limit_log = RateLimitLog(
tenant_id=tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
raise exc.RateLimitExceededError(
"you have reached the knowledge base request rate limit of your subscription."
)

View File

@ -5,7 +5,7 @@ from collections.abc import Generator
from copy import deepcopy
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from models.model import File
from core.tools.__base.tool_runtime import ToolRuntime
@ -171,7 +171,7 @@ class Tool(ABC):
def create_file_message(self, file: File) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(file_marker="file_marker"),
message=ToolInvokeMessage.FileMessage(),
meta={"file": file},
)

View File

@ -3,8 +3,8 @@ from __future__ import annotations
import base64
import json
import logging
from collections.abc import Generator, Mapping
from typing import Any, cast
from collections.abc import Generator
from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
@ -17,7 +17,6 @@ from core.mcp.types import (
TextContent,
TextResourceContents,
)
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -47,7 +46,6 @@ class MCPTool(Tool):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self._latest_usage = LLMUsage.empty_usage()
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@ -61,10 +59,6 @@ class MCPTool(Tool):
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters)
# Extract usage metadata from MCP protocol's _meta field
self._latest_usage = self._derive_usage_from_result(result)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
@ -126,99 +120,6 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage:
"""
Extract usage metadata from MCP tool result's _meta field.
The MCP protocol's _meta field (aliased as 'meta' in Python) can contain
usage information such as token counts, costs, and other metadata.
Args:
result: The CallToolResult from MCP tool invocation
Returns:
LLMUsage instance with values from meta or empty_usage if not found
"""
# Extract usage from the meta field if present
if result.meta:
usage_dict = cls._extract_usage_dict(result.meta)
if usage_dict is not None:
return LLMUsage.from_metadata(cast(LLMUsageMetadata, cast(object, dict(usage_dict))))
return LLMUsage.empty_usage()
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
"""
Recursively search for usage dictionary in the payload.
The MCP protocol's _meta field can contain usage data in various formats:
- Direct usage field: {"usage": {...}}
- Nested in metadata: {"metadata": {"usage": {...}}}
- Or nested within other fields
Args:
payload: The payload to search for usage data
Returns:
The usage dictionary if found, None otherwise
"""
# Check for direct usage field
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
# Check for metadata nested usage
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
# Check for common token counting fields directly in payload
# Some MCP servers may include token counts directly
if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload:
usage_dict: dict[str, Any] = {}
for key in (
"prompt_tokens",
"completion_tokens",
"total_tokens",
"prompt_unit_price",
"completion_unit_price",
"total_price",
"currency",
"prompt_price_unit",
"completion_price_unit",
"prompt_price",
"completion_price",
"latency",
"time_to_first_token",
"time_to_generate",
):
if key in payload:
usage_dict[key] = payload[key]
if usage_dict:
return usage_dict
# Recursively search through nested structures
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,

View File

@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.workflow.variables import VariableBase
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):

View File

@ -10,7 +10,6 @@ from pydantic import BaseModel, Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState
from core.workflow.runtime.graph_runtime_state import GraphExecutionProtocol
from .node_execution import NodeExecution
@ -237,6 +236,3 @@ class GraphExecution:
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1
_: GraphExecutionProtocol = GraphExecution(workflow_id="")

View File

@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.workflow.variables.variables import Variable
from core.variables.variables import Variable
class CommandType(StrEnum):

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Union, cast
from packaging.version import Version
from pydantic import ValidationError
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.db.session_factory import session_factory
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -26,6 +27,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -43,12 +45,17 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
MCPToolProvider,
WorkflowToolProvider,
)
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
llm_usage=llm_usage,
)
)
@staticmethod
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
provider_type_str = tool_config.get("type")
if provider_type_str:
return ToolProviderType(provider_type_str)
provider_id = tool_config.get("provider_name")
if not provider_id:
return ToolProviderType.BUILT_IN
with session_factory.create_session() as session:
provider_map: dict[
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
ToolProviderType,
] = {
WorkflowToolProvider: ToolProviderType.WORKFLOW,
MCPToolProvider: ToolProviderType.MCP,
ApiToolProvider: ToolProviderType.API,
BuiltinToolProvider: ToolProviderType.BUILT_IN,
}
for provider_model, provider_type in provider_map.items():
stmt = select(provider_model).where(
provider_model.id == provider_id,
provider_model.tenant_id == tenant_id,
)
if session.scalar(stmt):
return provider_type
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")

View File

@ -1,13 +1,13 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables import ArrayFileSegment, FileSegment, Segment
class AnswerNode(Node[AnswerNodeData]):

View File

@ -6,13 +6,13 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.variables.segments import ArrayFileSegment
from core.workflow.variables.types import SegmentType
from .exc import (
CodeNodeError,

View File

@ -3,9 +3,9 @@ from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.variables.types import SegmentType
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[

View File

@ -17,6 +17,8 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@ -24,8 +26,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment
from core.workflow.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile

View File

@ -23,11 +23,11 @@ from docx.text.paragraph import Paragraph
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment
from core.workflow.variables.segments import ArrayStringSegment, FileSegment
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

View File

@ -14,8 +14,8 @@ from configs import dify_config
from core.file.enums import FileTransferMethod
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, FileSegment
from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (

View File

@ -8,6 +8,7 @@ from core.file import File, FileTransferMethod
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@ -15,7 +16,6 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.variables.segments import ArrayFileSegment
from factories import file_factory
from .entities import (

View File

@ -10,10 +10,10 @@ from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.consts import SELECTORS_LENGTH
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit

View File

@ -7,6 +7,9 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@ -33,9 +36,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from core.workflow.variables import IntegerVariable, NoneSegment
from core.workflow.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.variables.variables import Variable
from libs.datetime_utils import naive_utc_now
from .exc import (

View File

@ -20,7 +20,3 @@ class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
"""Raised when the model is not a Large Language Model."""
class RateLimitExceededError(KnowledgeRetrievalNodeError):
"""Raised when the rate limit is exceeded."""

View File

@ -1,32 +1,70 @@
import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from core.workflow.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.variables.segments import ArrayObjectSegment
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
RateLimitExceededError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
)
if TYPE_CHECKING:
@ -35,6 +73,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
"score_threshold_enabled": False,
}
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
@ -51,7 +97,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@ -63,7 +108,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -77,7 +121,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return "1"
def _run(self) -> NodeRunResult:
usage = LLMUsage.empty_usage()
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -85,7 +128,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
process_data={},
outputs={},
metadata={},
llm_usage=usage,
llm_usage=LLMUsage.empty_usage(),
)
variables: dict[str, Any] = {}
# extract variables
@ -113,9 +156,36 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
else:
variables["attachments"] = [variable.value]
# TODO(-LAN-): Move this check outside.
# check rate limit
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with sessionmaker(db.engine).begin() as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@ -128,17 +198,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
},
llm_usage=usage,
)
except RateLimitExceededError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node", exc_info=True)
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@ -148,7 +210,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@ -156,47 +217,92 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
error_type=type(e).__name__,
llm_usage=usage,
)
finally:
db.session.close()
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
retrieval_resource_list = []
metadata_filter_document_ids = None
metadata_condition = None
metadata_usage = LLMUsage.empty_usage()
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.dataset_id.in_(dataset_ids),
)
.group_by(Document.dataset_id)
.having(func.count(Document.id) > 0)
.subquery()
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
if node_data.metadata_filtering_mode is not None:
metadata_filtering_mode = node_data.metadata_filtering_mode
results = (
db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
# avoid blocking at retrieval
db.session.close()
for dataset in results:
# pass if dataset is not available
if not dataset:
continue
available_datasets.append(dataset)
if query:
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required for single retrieval mode")
model = node_data.single_retrieval_config.model
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
raise ValueError("single_retrieval_config is required")
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model, credentials=model_config.credentials
)
if model_schema:
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
all_documents = dataset_retrieval.single_retrieve(
available_datasets=available_datasets,
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
completion_params=model.completion_params,
model_provider=model.provider,
model_mode=model.mode,
model_name=model.name,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
query=query,
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
)
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
reranking_model = None
weights = None
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
@ -223,36 +329,284 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
},
}
case _:
# Handle any other reranking_mode values
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": "workflow",
"score": item.metadata.get("score"),
"doc_metadata": item.metadata,
},
"title": item.metadata.get("title"),
"content": item.page_content,
}
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
"_source": "knowledge",
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": record.score or 0.0,
"child_chunks": [
{
"id": str(getattr(chunk, "id", "")),
"content": str(getattr(chunk, "content", "")),
"position": int(getattr(chunk, "position", 0)),
"score": float(getattr(chunk, "score", 0.0)),
}
for chunk in (record.child_chunks or [])
],
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
"segment_index_node_hash": segment.index_node_hash,
"doc_metadata": document.doc_metadata,
},
"title": document.name,
"files": list(record.files) if record.files else None,
}
if segment.answer:
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
# Add summary if available
if record.summary:
source["summary"] = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=self._score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position # type: ignore[index]
return retrieval_resource_list, usage
def _score(self, item: dict[str, Any]) -> float:
meta = item.get("metadata")
if isinstance(meta, dict):
s = meta.get("score")
if isinstance(s, (int, float)):
return float(s)
return 0.0
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
filters: list[Any] = []
metadata_condition = None
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
case _:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
model_config=model_config,
sys_files=[],
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
try:
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
)
usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = self._merge_usage(usage, event.usage)
break
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception:
return [], usage
return automatic_metadata_filters, usage
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -272,3 +626,107 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model.mode
if not model_mode:
raise ModelNotExistError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@ -2,11 +2,11 @@ from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError

View File

@ -14,10 +14,10 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation

View File

@ -49,6 +49,14 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
@ -69,14 +77,6 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile

View File

@ -3,9 +3,9 @@ from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.variables.types import SegmentType
_VALID_VAR_TYPE = frozenset(
[

View File

@ -6,6 +6,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
NodeExecutionType,
NodeType,
@ -30,7 +31,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from core.workflow.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now

View File

@ -8,9 +8,9 @@ from pydantic import (
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
from core.workflow.variables.types import SegmentType
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"

View File

@ -1,6 +1,6 @@
from typing import Any
from core.workflow.variables.types import SegmentType
from core.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):

View File

@ -26,13 +26,13 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from core.workflow.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData

View File

@ -12,6 +12,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -21,8 +23,6 @@ from core.workflow.enums import (
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.workflow.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile

View File

@ -3,13 +3,13 @@ from collections.abc import Mapping
from typing import Any
from core.file import FileTransferMethod
from core.variables.types import SegmentType
from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables.types import SegmentType
from core.workflow.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type

View File

@ -1,7 +1,7 @@
from pydantic import BaseModel
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.types import SegmentType
class AdvancedSettings(BaseModel):

View File

@ -1,10 +1,10 @@
from collections.abc import Mapping
from core.variables.segments import Segment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData
from core.workflow.variables.segments import Segment
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):

View File

@ -3,9 +3,9 @@ from typing import Any, TypeVar
from pydantic import BaseModel
from core.workflow.variables import Segment
from core.workflow.variables.consts import SELECTORS_LENGTH
from core.workflow.variables.types import SegmentType
from core.variables import Segment
from core.variables.consts import SELECTORS_LENGTH
from core.variables.types import SegmentType
# Use double underscore (`__`) prefix for internal variables
# to minimize risk of collision with user-defined variable names.

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -8,7 +9,6 @@ from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.variables import SegmentType, VariableBase
from .node_data import VariableAssignerData, WriteMode

View File

@ -1,6 +1,6 @@
from typing import Any
from core.workflow.variables import SegmentType
from core.variables import SegmentType
from .enums import Operation

View File

@ -2,14 +2,14 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.variables import SegmentType, VariableBase
from core.workflow.variables.consts import SELECTORS_LENGTH
from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem

View File

@ -1,108 +0,0 @@
from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from core.model_runtime.entities import LLMUsage
from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from core.workflow.nodes.llm.entities import ModelConfig
class SourceChildChunk(BaseModel):
id: str = Field(default="", description="Child chunk ID")
content: str = Field(default="", description="Child chunk content")
position: int = Field(default=0, description="Child chunk position")
score: float = Field(default=0.0, description="Child chunk relevance score")
class SourceMetadata(BaseModel):
source: str = Field(
default="knowledge",
serialization_alias="_source",
description="Data source identifier",
)
dataset_id: str = Field(description="Dataset unique identifier")
dataset_name: str = Field(description="Dataset display name")
document_id: str = Field(description="Document unique identifier")
document_name: str = Field(description="Document display name")
data_source_type: str = Field(description="Type of data source")
segment_id: str | None = Field(default=None, description="Segment unique identifier")
retriever_from: str = Field(default="workflow", description="Retriever source context")
score: float = Field(default=0.0, description="Retrieval relevance score")
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
segment_position: int | None = Field(default=0, description="Position of segment in document")
segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
position: int | None = Field(default=0, description="Position of the document in the dataset")
class Config:
populate_by_name = True
class Source(BaseModel):
metadata: SourceMetadata = Field(description="Source metadata information")
title: str = Field(description="Document title")
files: list[Any] | None = Field(default=None, description="Associated file references")
content: str | None = Field(description="Segment content text")
summary: str | None = Field(default=None, description="Content summary if available")
class KnowledgeRetrievalRequest(BaseModel):
tenant_id: str = Field(description="Tenant unique identifier")
user_id: str = Field(description="User unique identifier")
app_id: str = Field(description="Application unique identifier")
user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
query: str | None = Field(default=None, description="Query text for knowledge retrieval")
retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
completion_params: dict[str, Any] | None = Field(
default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
)
model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
metadata_model_config: ModelConfig | None = Field(
default=None, description="Model config for metadata-based filtering"
)
metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
default=None, description="Conditions for filtering by metadata"
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
)
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
class RAGRetrievalProtocol(Protocol):
"""Protocol for RAG-based knowledge retrieval implementations.
Implementations of this protocol handle knowledge retrieval from datasets
including rate limiting, dataset filtering, and document retrieval.
"""
@property
def llm_usage(self) -> LLMUsage:
"""Return accumulated LLM usage for retrieval operations."""
...
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
"""Retrieve knowledge from datasets based on the provided request.
Args:
request: Knowledge retrieval request with search parameters
Returns:
List of sources matching the search criteria
Raises:
RateLimitExceededError: If rate limit is exceeded
ModelNotExistError: If specified model doesn't exist
"""
...

View File

@ -64,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""
@ -446,7 +446,7 @@ class GraphRuntimeState:
graph_execution_cls = module.GraphExecution
workflow_id = self._pending_graph_execution_workflow_id or ""
self._pending_graph_execution_workflow_id = None
return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type]
return graph_execution_cls(workflow_id=workflow_id)
def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.

View File

@ -2,8 +2,8 @@ from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
from core.workflow.variables.segments import Segment
class ReadOnlyVariablePool(Protocol):

View File

@ -5,8 +5,8 @@ from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
from core.workflow.variables.segments import Segment
from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool

View File

@ -9,6 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@ -16,10 +20,6 @@ from core.workflow.constants import (
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.system_variable import SystemVariable
from core.workflow.variables import Segment, SegmentGroup, VariableBase
from core.workflow.variables.consts import SELECTORS_LENGTH
from core.workflow.variables.segments import FileSegment, ObjectSegment
from core.workflow.variables.variables import RAGPipelineVariableInput, Variable
from factories import variable_factory
VariableValue = Union[str, int, float, dict[str, object], list[object], File]

View File

@ -3,9 +3,9 @@ from collections.abc import Mapping, Sequence
from typing import Literal, NamedTuple
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.runtime import VariablePool
from core.workflow.variables import ArrayFileSegment
from core.workflow.variables.segments import ArrayBooleanSegment, BooleanSegment
from .entities import Condition, SubCondition, SupportedComparisonOperator

View File

@ -2,9 +2,9 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
from core.workflow.variables import VariableBase
from core.workflow.variables.consts import SELECTORS_LENGTH
class VariableLoader(Protocol):

View File

@ -5,7 +5,7 @@ from typing import Any, overload
from pydantic import BaseModel
from core.file.models import File
from core.workflow.variables import Segment
from core.variables import Segment
class WorkflowRuntimeTypeConverter:

View File

@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

View File

@ -80,14 +80,8 @@ def init_app(app: DifyApp) -> Celery:
worker_hijack_root_logger=False,
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
task_ignore_result=True,
task_annotations=dify_config.CELERY_TASK_ANNOTATIONS,
)
if dify_config.CELERY_BACKEND == "redis":
celery_app.conf.update(
result_backend_transport_options=broker_transport_options,
)
# Apply SSL configuration if enabled
ssl_options = _get_celery_ssl_options()
if ssl_options:
@ -196,14 +190,6 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}
if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
imports.append("schedule.update_api_token_last_used_task")
beat_schedule["batch_update_api_token_last_used"] = {
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@ -119,7 +119,7 @@ class RedisClientWrapper:
redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
@ -232,7 +232,7 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
return client
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
if use_clusters:
return RedisCluster.from_url(pubsub_url)
return redis.Redis.from_url(pubsub_url)
@ -256,19 +256,23 @@ def init_app(app: DifyApp):
redis_client.initialize(client)
app.extensions["redis"] = redis_client
global _pubsub_redis_client
_pubsub_redis_client = client
pubsub_client = client
if dify_config.normalized_pubsub_redis_url:
_pubsub_redis_client = _create_pubsub_client(
pubsub_client = _create_pubsub_client(
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
)
pubsub_redis_client.initialize(pubsub_client)
def get_pubsub_redis_client() -> RedisClientWrapper:
return pubsub_redis_client
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
redis_conn = get_pubsub_redis_client()
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
return RedisBroadcastChannel(_pubsub_redis_client)
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
P = ParamSpec("P")

View File

@ -1,6 +1,6 @@
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar, cast
from typing import Any, TypeVar, cast
from opentelemetry.trace import get_tracer
@ -8,8 +8,7 @@ from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound=Callable[..., Any])
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
@ -21,7 +20,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
return _HANDLER_INSTANCES[handler_class]
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], T]:
"""
Decorator that traces a function with an OpenTelemetry span.
@ -31,9 +30,9 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Call
:param handler_class: Optional handler class to use for this span. If None, uses the default SpanHandler.
"""
def decorator(func: Callable[P, R]) -> Callable[P, R]:
def decorator(func: T) -> T:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(*args: Any, **kwargs: Any) -> Any:
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
return func(*args, **kwargs)
@ -47,6 +46,6 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Call
kwargs=kwargs,
)
return cast(Callable[P, R], wrapper)
return cast(T, wrapper)
return decorator

View File

@ -1,11 +1,9 @@
import inspect
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
R = TypeVar("R")
class SpanHandler:
"""
@ -33,9 +31,9 @@ class SpanHandler:
def _extract_arguments(
self,
wrapped: Callable[..., R],
args: tuple[object, ...],
kwargs: Mapping[str, object],
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> dict[str, Any] | None:
"""
Extract function arguments using inspect.signature.
@ -64,10 +62,10 @@ class SpanHandler:
def wrapper(
self,
tracer: Any,
wrapped: Callable[..., R],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
"""
Fully control the wrapper behavior.

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.util.types import AttributeValue
@ -12,19 +12,16 @@ from models.model import Account
logger = logging.getLogger(__name__)
R = TypeVar("R")
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
def wrapper(
self,
tracer: Any,
wrapped: Callable[..., R],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
if not arguments:

View File

@ -10,10 +10,10 @@ from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel
from core.file.models import File
from core.variables import Segment
from core.workflow.enums import NodeType
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.variables import Segment
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes

View File

@ -8,9 +8,9 @@ from typing import Any
from opentelemetry.trace import Span
from core.variables import Segment
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.variables import Segment
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import RetrieverAttributes

View File

@ -4,12 +4,8 @@ from uuid import uuid4
from configs import dify_config
from core.file import File
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from core.workflow.variables.exc import VariableError
from core.workflow.variables.segments import (
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
@ -26,8 +22,8 @@ from core.workflow.variables.segments import (
Segment,
StringSegment,
)
from core.workflow.variables.types import SegmentType
from core.workflow.variables.variables import (
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayBooleanVariable,
ArrayFileVariable,
@ -44,6 +40,10 @@ from core.workflow.variables.variables import (
StringVariable,
VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
class UnsupportedSegmentTypeError(Exception):

Some files were not shown because too many files have changed in this diff Show More