feat: add file access controller

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-20 17:29:08 +08:00
parent 6d9b4bb78e
commit 98815b55d0
42 changed files with 1852 additions and 1186 deletions

View File

@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
return file_objs

View File

@ -15,6 +15,7 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.file import helpers as file_helpers
from dify_graph.variables.segment_group import SegmentGroup
@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -389,13 +391,21 @@ class VariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@ -21,6 +21,7 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
def _create_pagination_parser():
@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ModelConfigWithCredentialsEntity,
)
from core.app.file_access import DatabaseFileAccessController
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
@ -46,6 +47,7 @@ from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class BaseAgentRunner(AppRunner):
@ -524,7 +526,10 @@ class BaseAgentRunner(AppRunner):
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
message_files=files,
tenant_id=self.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
if not file_objs:
return UserPromptMessage(content=message.query)

View File

@ -147,85 +147,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
access_controller=self._file_access_controller,
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
def resume(
self,
@ -457,94 +459,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = conversation is None
with self._bind_file_access_scope(
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
invoke_from=invoke_from,
):
is_first_conversation = conversation is None
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# get conversation dialogue count
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
# get conversation dialogue count
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context and contextvars
context = contextvars.copy_context()
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
# new thread with request context and contextvars
context = contextvars.copy_context()
worker_thread.start()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
# message_ = session.get(Message, message.id)
# assert message_ is not None
# message = message_
# db.session.refresh(workflow)
# db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
worker_thread.start()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
db.session.close()
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(
override_model_config_dict or app_model_config.to_dict()
)
else:
file_objs = []
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
access_controller=self._file_access_controller,
)
else:
file_objs = []
# convert to app config
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
)
# convert to app config
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
)
# get tracing instance
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
# get tracing instance
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
call_depth=0,
trace_manager=trace_manager,
)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
call_depth=0,
trace_manager=trace_manager,
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context and contextvars
context = contextvars.copy_context()
# new thread with request context and contextvars
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start()
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@ -1,4 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
@ -8,7 +9,8 @@ from core.app.apps.draft_variable_saver import (
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
from dify_graph.enums import NodeType
from dify_graph.file import File, FileUploadConfig
from dify_graph.variables.input_entities import VariableEntityType
@ -57,6 +59,31 @@ class _DebuggerDraftVariableSaver:
class BaseAppGenerator:
_file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController()
@staticmethod
def _bind_file_access_scope(
    *,
    tenant_id: str,
    user: Account | EndUser,
    invoke_from: InvokeFrom,
) -> AbstractContextManager[None]:
    """Bind request-scoped file ownership markers for downstream file lookups."""
    # A usable scope requires a real user id; anonymous callers get a no-op
    # context so downstream file lookups proceed without ownership filtering.
    uid = getattr(user, "id", None)
    if not (isinstance(uid, str) and uid):
        return nullcontext()
    scope = FileAccessScope(
        tenant_id=tenant_id,
        user_id=uid,
        user_from=UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER,
        invoke_from=invoke_from,
    )
    return bind_file_access_scope(scope)
def _prepare_user_inputs(
self,
*,
@ -85,6 +112,7 @@ class BaseAppGenerator:
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
access_controller=self._file_access_controller,
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
@ -99,6 +127,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
access_controller=self._file_access_controller,
)
for k, v in user_inputs.items()
if isinstance(v, list)

View File

@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(
override_model_config_dict or app_model_config.to_dict()
)
else:
file_objs = []
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
access_controller=self._file_access_controller,
)
else:
file_objs = []
# convert to app config
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
)
# convert to app config
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=streaming,
)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=streaming,
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context
@copy_current_request_context
def worker_with_context():
return self._generate_worker(
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)
context = contextvars.copy_context()
worker_thread.start()
# new thread with request context
@copy_current_request_context
def worker_with_context():
return context.run(
self._generate_worker,
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
worker_thread = threading.Thread(target=worker_with_context)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(
override_model_config_dict or app_model_config.to_dict()
)
else:
file_objs = []
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
access_controller=self._file_access_controller,
)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras={},
trace_manager=trace_manager,
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras={},
trace_manager=trace_manager,
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context
@copy_current_request_context
def worker_with_context():
return self._generate_worker(
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)
context = contextvars.copy_context()
worker_thread.start()
# new thread with request context
@copy_current_request_context
def worker_with_context():
return context.run(
self._generate_worker,
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
worker_thread = threading.Thread(target=worker_with_context)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict
# parse files
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=message.message_files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
# parse files
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=message.message_files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
access_controller=self._file_access_controller,
)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={},
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={},
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context
@copy_current_request_context
def worker_with_context():
return self._generate_worker(
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)
context = contextvars.copy_context()
worker_thread.start()
# new thread with request context
@copy_current_request_context
def worker_with_context():
return context.run(
self._generate_worker,
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
)
worker_thread = threading.Thread(target=worker_with_context)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File

@ -128,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow,
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id,
user_id=user.id if isinstance(user, Account) else user.session_id,
)
inputs: Mapping[str, Any] = args["inputs"]
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
access_controller=self._file_access_controller,
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=inputs,
files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager,
workflow_execution_id=workflow_run_id,
extras=extras,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow,
)
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id,
user_id=user.id if isinstance(user, Account) else user.session_id,
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_state_config,
)
inputs: Mapping[str, Any] = args["inputs"]
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=inputs,
files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager,
workflow_execution_id=workflow_run_id,
extras=extras,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_state_config,
)
def resume(
self,
@ -291,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
with self._bind_file_access_scope(
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
invoke_from=invoke_from,
):
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode,
)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode,
)
# new thread with request context and contextvars
context = contextvars.copy_context()
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# release database connection, because the following new thread operations may take a long time
db.session.close()
# new thread with request context and contextvars
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
# release database connection, because the following new thread operations may take a long time
db.session.close()
worker_thread.start()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,

View File

@ -0,0 +1,11 @@
from .controller import DatabaseFileAccessController
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope
__all__ = [
"DatabaseFileAccessController",
"FileAccessControllerProtocol",
"FileAccessScope",
"bind_file_access_scope",
"get_current_file_access_scope",
]

View File

@ -0,0 +1,103 @@
from __future__ import annotations
from collections.abc import Callable
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, get_current_file_access_scope
class DatabaseFileAccessController(FileAccessControllerProtocol):
    """Workflow-layer authorization helper for database-backed file lookups.

    Tenant scoping remains mandatory. When the current execution belongs to an
    end user, the lookup is additionally constrained to that end user's file
    ownership markers.
    """

    _scope_getter: Callable[[], FileAccessScope | None]

    def __init__(
        self,
        *,
        scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope,
    ) -> None:
        self._scope_getter = scope_getter

    def current_scope(self) -> FileAccessScope | None:
        """Return the scope bound to the current execution context, or ``None``."""
        return self._scope_getter()

    def _resolve(self, scope: FileAccessScope | None) -> FileAccessScope | None:
        # Prefer an explicitly supplied scope; otherwise fall back to the
        # context-bound one. ``None`` means access checks are disabled.
        return scope if scope is not None else self.current_scope()

    def apply_upload_file_filters(
        self,
        stmt: Select[tuple[UploadFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[UploadFile]]:
        """Append tenant (and, for end users, ownership) predicates to *stmt*.

        The caller's existing predicates are preserved; only access-control
        conditions are added.
        """
        active = self._resolve(scope)
        if active is None:
            # Unscoped execution: leave the statement untouched.
            return stmt
        constrained = stmt.where(UploadFile.tenant_id == active.tenant_id)
        if active.requires_user_ownership:
            constrained = constrained.where(
                UploadFile.created_by_role == CreatorUserRole.END_USER,
                UploadFile.created_by == active.user_id,
            )
        return constrained

    def apply_tool_file_filters(
        self,
        stmt: Select[tuple[ToolFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[ToolFile]]:
        """Append tenant (and, for end users, ownership) predicates to *stmt*.

        The caller's existing predicates are preserved; only access-control
        conditions are added.
        """
        active = self._resolve(scope)
        if active is None:
            # Unscoped execution: leave the statement untouched.
            return stmt
        constrained = stmt.where(ToolFile.tenant_id == active.tenant_id)
        if active.requires_user_ownership:
            constrained = constrained.where(ToolFile.user_id == active.user_id)
        return constrained

    def get_upload_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> UploadFile | None:
        """Load one authorized ``UploadFile``, or ``None`` when absent or denied."""
        active = self._resolve(scope)
        if active is None:
            # No scope bound: keep the legacy unrestricted primary-key lookup.
            return session.get(UploadFile, file_id)
        query = self.apply_upload_file_filters(
            select(UploadFile).where(UploadFile.id == file_id),
            scope=active,
        )
        return session.scalar(query)

    def get_tool_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> ToolFile | None:
        """Load one authorized ``ToolFile``, or ``None`` when absent or denied."""
        active = self._resolve(scope)
        if active is None:
            # No scope bound: keep the legacy unrestricted primary-key lookup.
            return session.get(ToolFile, file_id)
        query = self.apply_tool_file_filters(
            select(ToolFile).where(ToolFile.id == file_id),
            scope=active,
        )
        return session.scalar(query)

View File

@ -0,0 +1,81 @@
from __future__ import annotations
from typing import Protocol
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from models import ToolFile, UploadFile
from .scope import FileAccessScope
class FileAccessControllerProtocol(Protocol):
    """Contract for applying access rules to file lookups.

    Implementations translate an optional execution scope into query constraints
    and authorized record retrieval. The contract is intentionally limited to
    ownership and tenancy rules for workflow-layer file access.
    """

    def current_scope(self) -> FileAccessScope | None:
        """Return the scope active for the current execution, if one exists.

        Callers use this to decide whether embedded file metadata may be trusted
        or whether a fresh authorized lookup is required.
        """
        ...

    def apply_upload_file_filters(
        self,
        stmt: Select[tuple[UploadFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[UploadFile]]:
        """Return an upload-file query constrained by the supplied access scope.

        The returned statement must preserve the caller's existing predicates and
        append only access-control conditions. When *scope* is ``None``,
        implementations may fall back to the context-bound scope.
        """
        ...

    def apply_tool_file_filters(
        self,
        stmt: Select[tuple[ToolFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[ToolFile]]:
        """Return a tool-file query constrained by the supplied access scope.

        The returned statement must preserve the caller's existing predicates and
        append only access-control conditions. When *scope* is ``None``,
        implementations may fall back to the context-bound scope.
        """
        ...

    def get_upload_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> UploadFile | None:
        """Load one authorized upload-file record for the given identifier.

        Returns ``None`` when the file does not exist or when the scope does not
        permit access to that record.
        """
        ...

    def get_tool_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> ToolFile | None:
        """Load one authorized tool-file record for the given identifier.

        Returns ``None`` when the file does not exist or when the scope does not
        permit access to that record.
        """
        ...

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
# Context variable holding the file-access scope bound to the current
# execution context. ``None`` (the default) means no scope is active, in
# which case consumers fall back to their legacy, unrestricted behavior.
_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
    "current_file_access_scope",
    default=None,
)
@dataclass(frozen=True, slots=True)
class FileAccessScope:
    """Request-scoped ownership context used by workflow-layer file lookups."""

    # Tenant that owns the files being accessed; tenancy is always enforced.
    tenant_id: str
    # Identifier of the acting user (account id or end-user id).
    user_id: str
    # Whether the actor is an internal account or an end user.
    user_from: UserFrom
    # Entry point that triggered the execution (e.g. service API, debugger).
    invoke_from: InvokeFrom

    @property
    def requires_user_ownership(self) -> bool:
        """Whether lookups must additionally match the end user's ownership markers."""
        return self.user_from == UserFrom.END_USER
def get_current_file_access_scope() -> FileAccessScope | None:
    """Return the scope bound to the current context, or ``None`` when unbound."""
    return _current_file_access_scope.get()
@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
    """Bind *scope* to the current context for the duration of the ``with`` block.

    The previous value is restored on exit — even when an exception propagates —
    so nested bindings unwind correctly.
    """
    token = _current_file_access_scope.set(scope)
    try:
        yield
    finally:
        _current_file_access_scope.reset(token)

View File

@ -10,6 +10,7 @@ from collections.abc import Generator
from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.signature import sign_tool_file
@ -18,14 +19,22 @@ from dify_graph.file.enums import FileTransferMethod
from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from dify_graph.file.runtime import set_workflow_file_runtime
from extensions.ext_storage import storage
from models import ToolFile, UploadFile
if TYPE_CHECKING:
from dify_graph.file.models import File
class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
"""Production runtime wiring for ``dify_graph.file``."""
"""Production runtime wiring for ``dify_graph.file``.
When a request-scoped file access scope is present, opaque file references are
re-validated against the database before URLs are signed or storage keys are used.
"""
_file_access_controller: FileAccessControllerProtocol
def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None:
self._file_access_controller = file_access_controller
@property
def multimodal_send_format(self) -> str:
@ -55,7 +64,16 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
upload_file_id=parsed_reference.record_id,
for_external=for_external,
)
if file.transfer_method in {FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE}:
if file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
if file.extension is None:
raise ValueError("Missing file extension")
self._assert_upload_file_access(upload_file_id=parsed_reference.record_id)
return sign_tool_file(
tool_file_id=parsed_reference.record_id,
extension=file.extension,
for_external=for_external,
)
if file.transfer_method == FileTransferMethod.TOOL_FILE:
if file.extension is None:
raise ValueError("Missing file extension")
return self.resolve_tool_file_url(
@ -72,6 +90,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
as_attachment: bool = False,
for_external: bool = True,
) -> str:
self._assert_upload_file_access(upload_file_id=upload_file_id)
base_url = self._base_url(for_external=for_external)
url = f"{base_url}/files/{upload_file_id}/file-preview"
query = self._sign_query(payload=f"file-preview|{upload_file_id}")
@ -80,6 +99,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return f"{url}?{urllib.parse.urlencode(query)}"
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._assert_tool_file_access(tool_file_id=tool_file_id)
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
def verify_preview_signature(
@ -121,7 +141,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
parsed_reference = parse_file_reference(file.reference)
if parsed_reference is None:
raise ValueError("Missing file reference")
if parsed_reference.storage_key:
if parsed_reference.storage_key and self._file_access_controller.current_scope() is None:
return parsed_reference.storage_key
record_id = parsed_reference.record_id
@ -131,16 +151,34 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
FileTransferMethod.REMOTE_URL,
FileTransferMethod.DATASOURCE_FILE,
}:
upload_file = session.get(UploadFile, record_id)
upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id)
if upload_file is None:
raise ValueError(f"Upload file {record_id} not found")
return upload_file.key
tool_file = session.get(ToolFile, record_id)
tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id)
if tool_file is None:
raise ValueError(f"Tool file {record_id} not found")
return tool_file.file_key
def _assert_upload_file_access(self, *, upload_file_id: str) -> None:
if self._file_access_controller.current_scope() is None:
return
with session_factory.create_session() as session:
upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id)
if upload_file is None:
raise ValueError(f"Upload file {upload_file_id} not found")
def _assert_tool_file_access(self, *, tool_file_id: str) -> None:
if self._file_access_controller.current_scope() is None:
return
with session_factory.create_session() as session:
tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id)
if tool_file is None:
raise ValueError(f"Tool file {tool_file_id} not found")
def bind_dify_workflow_file_runtime() -> None:
set_workflow_file_runtime(DifyWorkflowFileRuntime())
set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()))

View File

@ -6,6 +6,7 @@ from typing import Any, cast
from sqlalchemy import select
import contexts
from core.app.file_access import DatabaseFileAccessController
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import (
@ -37,6 +38,7 @@ from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class DatasourceManager:
@ -284,7 +286,11 @@ class DatasourceManager:
"transfer_method": FileTransferMethod.TOOL_FILE,
"url": url,
}
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
file_out = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
elif mtype == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)

View File

@ -4,6 +4,7 @@ from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.file_access import DatabaseFileAccessController
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from dify_graph.file import file_manager
@ -23,6 +24,8 @@ from models.workflow import Workflow
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
_file_access_controller = DatabaseFileAccessController()
class TokenBufferMemory:
def __init__(
@ -85,7 +88,10 @@ class TokenBufferMemory:
# Build files directly without filtering by belongs_to
file_objs = [
file_factory.build_from_message_file(
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
message_file=message_file,
tenant_id=app_record.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
for message_file in message_files
]

View File

@ -8,6 +8,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.file_access import DatabaseFileAccessController
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
@ -49,6 +50,8 @@ from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController()
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
@ -556,6 +559,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
file_obj = build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
file_objects.append(file_obj)
except Exception as e:

View File

@ -7,6 +7,7 @@ from typing import Any, cast
from sqlalchemy import select
from core.app.file_access import DatabaseFileAccessController
from core.db.session_factory import session_factory
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
@ -26,6 +27,7 @@ from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class WorkflowTool(Tool):
@ -331,6 +333,7 @@ class WorkflowTool(Tool):
file = build_from_mapping(
mapping=item,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
@ -338,6 +341,7 @@ class WorkflowTool(Tool):
file = build_from_mapping(
mapping=value,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)

View File

@ -8,6 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.file_access import DatabaseFileAccessController
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@ -82,6 +83,9 @@ if TYPE_CHECKING:
from dify_graph.nodes.tool.entities import ToolNodeData
_file_access_controller = DatabaseFileAccessController()
def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext:
if isinstance(run_context, DifyRunContext):
return run_context
@ -127,6 +131,7 @@ class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
return file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self._run_context.tenant_id,
access_controller=_file_access_controller,
)

View File

@ -6,6 +6,7 @@ from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.file_access import DatabaseFileAccessController
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -27,6 +28,8 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
_file_access_controller = DatabaseFileAccessController()
class AgentMessageTransformer:
def transform(
@ -93,6 +96,7 @@ class AgentMessageTransformer:
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
@ -116,6 +120,7 @@ class AgentMessageTransformer:
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:

View File

@ -7,6 +7,7 @@ from configs import dify_config
from context import capture_current_context
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.file_access import DatabaseFileAccessController
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class
@ -35,6 +36,7 @@ from factories import file_factory
from models.workflow import Workflow
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class _WorkflowChildEngineBuilder:
@ -508,13 +510,21 @@ class WorkflowEntry:
continue
if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
input_value = file_factory.build_from_mapping(
mapping=input_value,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
if (
isinstance(input_value, list)
and all(isinstance(item, dict) for item in input_value)
and all("type" in item and "transfer_method" in item for item in input_value)
):
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
input_value = file_factory.build_from_mappings(
mappings=input_value,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
# append variable and value to variable pool
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:

View File

@ -1,582 +0,0 @@
import logging
import mimetypes
import os
import re
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
from typing import Any
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.http import parse_options_header
from core.helper import ssrf_proxy
from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
from dify_graph.file.file_factory import standardize_file_type
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
logger = logging.getLogger(__name__)
def _resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None:
    """Return the first resolvable file record id found under *keys*.

    The explicit keys are tried first, followed by the generic
    "reference" and "related_id" fallbacks. Non-string or empty values
    are skipped, as are values that do not resolve to a record id.
    """
    for candidate_key in (*keys, "reference", "related_id"):
        value = mapping.get(candidate_key)
        if not isinstance(value, str) or not value:
            continue
        record_id = resolve_file_record_id(value)
        if record_id:
            return record_id
    return None
def build_from_message_files(
    *,
    message_files: Sequence["MessageFile"],
    tenant_id: str,
    config: FileUploadConfig | None = None,
) -> Sequence[File]:
    """Convert message-file records into graph-layer ``File`` values.

    Files produced by the assistant are excluded; each remaining record
    is delegated to :func:`build_from_message_file`.
    """
    files: list[File] = []
    for message_file in message_files:
        if message_file.belongs_to == FileBelongsTo.ASSISTANT:
            continue
        files.append(
            build_from_message_file(message_file=message_file, tenant_id=tenant_id, config=config)
        )
    return files
def build_from_message_file(
    *,
    message_file: "MessageFile",
    tenant_id: str,
    config: FileUploadConfig | None,
):
    """Translate a single ``MessageFile`` record into a ``File``.

    The record is normalized into a mapping understood by
    :func:`build_from_mapping`. The id-bearing key depends on the
    transfer method: tool files carry ``tool_file_id``, everything else
    ``upload_file_id``.
    """
    is_tool_file = message_file.transfer_method == FileTransferMethod.TOOL_FILE
    id_key = "tool_file_id" if is_tool_file else "upload_file_id"
    mapping: dict[str, Any] = {
        "transfer_method": message_file.transfer_method,
        "url": message_file.url,
        "type": message_file.type,
        id_key: message_file.upload_file_id,
    }
    # Only include id when the record has been committed to the database.
    if message_file.id:
        mapping["id"] = message_file.id
    return build_from_mapping(mapping=mapping, tenant_id=tenant_id, config=config)
def build_from_mapping(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    config: FileUploadConfig | None = None,
    strict_type_validation: bool = False,
) -> File:
    """Build a ``File`` from a raw mapping, dispatching on transfer method.

    :raises ValueError: if the transfer method is missing/unknown or the
        built file does not satisfy *config*.
    """
    transfer_method_value = mapping.get("transfer_method")
    if not transfer_method_value:
        raise ValueError("transfer_method is required in file mapping")
    transfer_method = FileTransferMethod.value_of(transfer_method_value)

    # Dispatch table keyed by transfer method; each builder shares the
    # same keyword interface.
    dispatch: dict[FileTransferMethod, Callable] = {
        FileTransferMethod.LOCAL_FILE: _build_from_local_file,
        FileTransferMethod.REMOTE_URL: _build_from_remote_url,
        FileTransferMethod.TOOL_FILE: _build_from_tool_file,
        FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
    }
    builder = dispatch.get(transfer_method)
    if builder is None:
        raise ValueError(f"Invalid file transfer method: {transfer_method}")

    built: File = builder(
        mapping=mapping,
        tenant_id=tenant_id,
        transfer_method=transfer_method,
        strict_type_validation=strict_type_validation,
    )
    if config and not _is_file_valid_with_config(
        input_file_type=mapping.get("type", FileType.CUSTOM),
        file_extension=built.extension or "",
        file_transfer_method=built.transfer_method,
        config=config,
    ):
        raise ValueError(f"File validation failed for file: {built.filename}")
    return built
def build_from_mappings(
    *,
    mappings: Sequence[Mapping[str, Any]],
    config: FileUploadConfig | None = None,
    tenant_id: str,
    strict_type_validation: bool = False,
) -> Sequence[File]:
    """Build files for every usable mapping and enforce *config* limits.

    Empty mappings, mappings without a transfer method, and remote-URL
    mappings lacking a URL are silently dropped before building.

    :raises ValueError: when the image count or total file count exceeds
        the limits configured in *config*.
    """
    # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
    # Implement batch processing to reduce database load when handling multiple files.

    def _usable(candidate: Mapping[str, Any]) -> bool:
        if not candidate or not candidate.get("transfer_method"):
            return False
        # REMOTE_URL additionally requires a non-empty url/remote_url.
        if candidate.get("transfer_method") == FileTransferMethod.REMOTE_URL:
            if not (candidate.get("url") or candidate.get("remote_url")):
                return False
        return True

    files = [
        build_from_mapping(
            mapping=candidate,
            tenant_id=tenant_id,
            config=config,
            strict_type_validation=strict_type_validation,
        )
        for candidate in mappings
        if _usable(candidate)
    ]

    if config and config.image_config:
        image_count = sum(1 for f in files if f.type == FileType.IMAGE)
        if image_count > config.image_config.number_limits:
            raise ValueError(
                f"Number of image files exceeds the maximum limit {config.image_config.number_limits}"
            )
    if config and config.number_limits and len(files) > config.number_limits:
        raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
    return files
def _build_from_local_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
) -> File:
    """Build a ``File`` from a locally uploaded file referenced in *mapping*.

    Looks up the ``UploadFile`` row scoped to *tenant_id* and derives the
    file type from the stored extension/MIME type unless the mapping
    specifies a concrete (non-"custom") type.

    :raises ValueError: for a missing/malformed upload file id, a record
        not found for this tenant, or a strict type-validation mismatch.
    """
    upload_file_id = _resolve_mapping_file_id(mapping, "upload_file_id")
    if not upload_file_id:
        raise ValueError("Invalid upload file id")
    # check if upload_file_id is a valid uuid
    try:
        uuid.UUID(upload_file_id)
    except ValueError:
        raise ValueError("Invalid upload file id format")
    # Tenant scoping is enforced directly in the query.
    stmt = select(UploadFile).where(
        UploadFile.id == upload_file_id,
        UploadFile.tenant_id == tenant_id,
    )
    row = db.session.scalar(stmt)
    if row is None:
        raise ValueError("Invalid upload file")
    detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
    specified_type = mapping.get("type", "custom")
    if strict_type_validation and detected_file_type.value != specified_type:
        raise ValueError("Detected file type does not match the specified type. Please verify the file.")
    # An explicit non-"custom" type from the caller wins over detection.
    if specified_type and specified_type != "custom":
        file_type = FileType(specified_type)
    else:
        file_type = detected_file_type
    return File(
        id=mapping.get("id"),
        filename=row.name,
        extension="." + row.extension,
        mime_type=row.mime_type,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=row.source_url,
        reference=build_file_reference(record_id=str(row.id), storage_key=row.key),
        size=row.size,
    )
def _build_from_remote_url(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
) -> File:
    """Build a ``File`` for a remote-URL mapping.

    Two paths: when the mapping resolves to an ``upload_file_id``, the
    metadata comes from the tenant's ``UploadFile`` row and a signed URL
    is issued; otherwise the remote URL itself is probed (HEAD request)
    for name/MIME type/size.

    :raises ValueError: for a malformed/unknown upload file id, a missing
        URL, or a strict type-validation mismatch.
    """
    upload_file_id = _resolve_mapping_file_id(mapping, "upload_file_id")
    if upload_file_id:
        try:
            uuid.UUID(upload_file_id)
        except ValueError:
            raise ValueError("Invalid upload file id format")
        # Tenant scoping is enforced directly in the query.
        stmt = select(UploadFile).where(
            UploadFile.id == upload_file_id,
            UploadFile.tenant_id == tenant_id,
        )
        upload_file = db.session.scalar(stmt)
        if upload_file is None:
            raise ValueError("Invalid upload file")
        detected_file_type = standardize_file_type(
            extension="." + upload_file.extension, mime_type=upload_file.mime_type
        )
        specified_type = mapping.get("type")
        if strict_type_validation and specified_type and detected_file_type.value != specified_type:
            raise ValueError("Detected file type does not match the specified type. Please verify the file.")
        # An explicit non-"custom" type from the caller wins over detection.
        if specified_type and specified_type != "custom":
            file_type = FileType(specified_type)
        else:
            file_type = detected_file_type
        return File(
            id=mapping.get("id"),
            filename=upload_file.name,
            extension="." + upload_file.extension,
            mime_type=upload_file.mime_type,
            type=file_type,
            transfer_method=transfer_method,
            remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
            reference=build_file_reference(record_id=str(upload_file.id), storage_key=upload_file.key),
            size=upload_file.size,
        )
    url = mapping.get("url") or mapping.get("remote_url")
    if not url:
        raise ValueError("Invalid file url")
    # No DB record: derive metadata from the remote endpoint itself.
    mime_type, filename, file_size = _get_remote_file_info(url)
    extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
    detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type)
    specified_type = mapping.get("type")
    if strict_type_validation and specified_type and detected_file_type.value != specified_type:
        raise ValueError("Detected file type does not match the specified type. Please verify the file.")
    if specified_type and specified_type != "custom":
        file_type = FileType(specified_type)
    else:
        file_type = detected_file_type
    return File(
        id=mapping.get("id"),
        filename=filename,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=url,
        mime_type=mime_type,
        extension=extension,
        size=file_size,
    )
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
filename: str | None = None
# Try to extract from Content-Disposition header first
if content_disposition:
# Manually extract filename* parameter since parse_options_header doesn't support it
filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
if filename_star_match:
raw_star = filename_star_match.group(1).strip()
# Remove trailing quotes if present
raw_star = raw_star.removesuffix('"')
# format: charset'lang'value
try:
parts = raw_star.split("'", 2)
charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8"
value = parts[2] if len(parts) == 3 else parts[-1]
filename = urllib.parse.unquote(value, encoding=charset, errors="replace")
except Exception:
# Fallback: try to extract value after the last single quote
if "''" in raw_star:
filename = urllib.parse.unquote(raw_star.split("''")[-1])
else:
filename = urllib.parse.unquote(raw_star)
if not filename:
# Fallback to regular filename parameter
_, params = parse_options_header(content_disposition)
raw = params.get("filename")
if raw:
# Strip surrounding quotes and percent-decode if present
if len(raw) >= 2 and raw[0] == raw[-1] == '"':
raw = raw[1:-1]
filename = urllib.parse.unquote(raw)
# Fallback to URL path if no filename from header
if not filename:
candidate = os.path.basename(url_path)
filename = urllib.parse.unquote(candidate) if candidate else None
# Defense-in-depth: ensure basename only
if filename:
filename = os.path.basename(filename)
# Return None if filename is empty or only whitespace
if not filename or not filename.strip():
filename = None
return filename or None
def _guess_mime_type(filename: str) -> str:
"""Guess MIME type from filename, returning empty string if None."""
guessed_mime, _ = mimetypes.guess_type(filename)
return guessed_mime or ""
def _get_remote_file_info(url: str):
    """Probe *url* with a HEAD request and return ``(mime_type, filename, file_size)``.

    Uses the SSRF-protected proxy client. ``file_size`` defaults to -1
    when the server does not report Content-Length; filename and MIME
    type fall back to URL-path and header heuristics, and a random name
    is synthesized when nothing usable is found.
    """
    file_size = -1
    parsed_url = urllib.parse.urlparse(url)
    url_path = parsed_url.path
    filename = os.path.basename(url_path)
    # Initialize mime_type from filename as fallback
    mime_type = _guess_mime_type(filename)
    resp = ssrf_proxy.head(url, follow_redirects=True)
    if resp.status_code == httpx.codes.OK:
        content_disposition = resp.headers.get("Content-Disposition")
        extracted_filename = _extract_filename(url_path, content_disposition)
        if extracted_filename:
            # A header-provided name is more trustworthy than the URL path.
            filename = extracted_filename
            mime_type = _guess_mime_type(filename)
        file_size = int(resp.headers.get("Content-Length", file_size))
        # Fallback to Content-Type header if mime_type is still empty
        if not mime_type:
            mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
    if not filename:
        # No name anywhere: synthesize one from the guessed extension.
        extension = mimetypes.guess_extension(mime_type) or ".bin"
        filename = f"{uuid.uuid4().hex}{extension}"
    if not mime_type:
        mime_type = _guess_mime_type(filename)
    return mime_type, filename, file_size
def _build_from_tool_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
) -> File:
    """Build a ``File`` from a tool-generated file referenced in *mapping*.

    :raises ValueError: if the id is missing, no ``ToolFile`` row exists
        for this tenant, or a strict type-validation mismatch occurs.
    """
    tool_file_id = _resolve_mapping_file_id(mapping, "tool_file_id")
    if not tool_file_id:
        raise ValueError(f"ToolFile {tool_file_id} not found")
    # Tenant scoping is enforced directly in the query.
    tool_file = db.session.scalar(
        select(ToolFile).where(
            ToolFile.id == tool_file_id,
            ToolFile.tenant_id == tenant_id,
        )
    )
    if tool_file is None:
        raise ValueError(f"ToolFile {tool_file_id} not found")
    # Extension comes from the storage key; default to .bin when absent.
    extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
    detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
    specified_type = mapping.get("type")
    if strict_type_validation and specified_type and detected_file_type.value != specified_type:
        raise ValueError("Detected file type does not match the specified type. Please verify the file.")
    # An explicit non-"custom" type from the caller wins over detection.
    if specified_type and specified_type != "custom":
        file_type = FileType(specified_type)
    else:
        file_type = detected_file_type
    return File(
        id=mapping.get("id"),
        filename=tool_file.name,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=tool_file.original_url,
        reference=build_file_reference(record_id=str(tool_file.id), storage_key=tool_file.file_key),
        extension=extension,
        mime_type=tool_file.mimetype,
        size=tool_file.size,
    )
def _build_from_datasource_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
) -> File:
    """Build a ``File`` backed by a datasource upload record.

    :raises ValueError: if the id is missing, no ``UploadFile`` row exists
        for this tenant, or a strict type-validation mismatch occurs.
    """
    datasource_file_id = _resolve_mapping_file_id(mapping, "datasource_file_id")
    if not datasource_file_id:
        raise ValueError(f"DatasourceFile {datasource_file_id} not found")
    # Tenant scoping is enforced directly in the query.
    datasource_file = db.session.scalar(
        select(UploadFile).where(
            UploadFile.id == datasource_file_id,
            UploadFile.tenant_id == tenant_id,
        )
    )
    if datasource_file is None:
        raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
    extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
    # BUGFIX: `extension` already carries its leading dot; the previous
    # revision passed `"." + extension`, producing values like "..pdf" and
    # defeating extension-based type detection (cf. _build_from_tool_file).
    detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
    specified_type = mapping.get("type")
    if strict_type_validation and specified_type and detected_file_type.value != specified_type:
        raise ValueError("Detected file type does not match the specified type. Please verify the file.")
    if specified_type and specified_type != "custom":
        file_type = FileType(specified_type)
    else:
        file_type = detected_file_type
    return File(
        id=mapping.get("datasource_file_id"),
        filename=datasource_file.name,
        type=file_type,
        # NOTE(review): the transfer_method parameter is ignored and TOOL_FILE
        # is hard-coded here — presumably so downstream treats datasource files
        # like tool files; confirm before changing.
        transfer_method=FileTransferMethod.TOOL_FILE,
        remote_url=datasource_file.source_url,
        reference=build_file_reference(record_id=str(datasource_file.id), storage_key=datasource_file.key),
        extension=extension,
        mime_type=datasource_file.mime_type,
        size=datasource_file.size,
        url=datasource_file.source_url,
    )
def _is_file_valid_with_config(
    *,
    input_file_type: str,
    file_extension: str,
    file_transfer_method: FileTransferMethod,
    config: FileUploadConfig,
) -> bool:
    """Check whether a file satisfies the upload restrictions in *config*."""
    # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
    # These are internally generated and should bypass user upload restrictions
    if file_transfer_method == FileTransferMethod.TOOL_FILE:
        return True

    # Type allow-list: CUSTOM is exempt here and checked by extension below.
    allowed_types = config.allowed_file_types
    if allowed_types and input_file_type != FileType.CUSTOM and input_file_type not in allowed_types:
        return False

    # Extension allow-list applies only to CUSTOM-typed files.
    if input_file_type == FileType.CUSTOM:
        allowed_extensions = config.allowed_file_extensions
        if allowed_extensions is not None and file_extension not in allowed_extensions:
            return False

    # Transfer-method restrictions: images have their own config; everything
    # else falls under the general upload-method allow-list.
    if input_file_type == FileType.IMAGE:
        image_config = config.image_config
        if image_config and image_config.transfer_methods and file_transfer_method not in image_config.transfer_methods:
            return False
    elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
        return False

    return True
class StorageKeyLoader:
    """FileKeyLoader load the storage key from database for a list of files.
    This loader is batched, the database query count is constant regardless of the input size.
    """

    def __init__(self, session: Session, tenant_id: str):
        # The session is caller-owned; this loader only reads through it.
        self._session = session
        self._tenant_id = tenant_id

    def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
        # Single batched query for all upload-file rows, keyed by UUID.
        stmt = select(UploadFile).where(
            UploadFile.id.in_(upload_file_ids),
            UploadFile.tenant_id == self._tenant_id,
        )
        return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}

    def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
        # Single batched query for all tool-file rows, keyed by UUID.
        stmt = select(ToolFile).where(
            ToolFile.id.in_(tool_file_ids),
            ToolFile.tenant_id == self._tenant_id,
        )
        return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}

    def load_storage_keys(self, files: Sequence[File]) -> None:
        """Loads storage keys for a sequence of files by retrieving the corresponding
        `UploadFile` or `ToolFile` records from the database based on their transfer method.
        This method doesn't modify the input sequence structure but updates the `_storage_key`
        property of each file object by extracting the relevant key from its database record.
        Tenant scoping is enforced by this loader's context rather than by embedding tenant
        identity inside graph-layer ``File`` values.
        Performance note: This is a batched operation where database query count remains constant
        regardless of input size. However, for optimal performance, input sequences should contain
        fewer than 1000 files. For larger collections, split into smaller batches and process each
        batch separately.
        """
        # First pass: bucket record ids by transfer method so each bucket
        # can be fetched with one query.
        upload_file_ids: list[uuid.UUID] = []
        tool_file_ids: list[uuid.UUID] = []
        for file in files:
            parsed_reference = parse_file_reference(file.reference)
            if parsed_reference is None:
                raise ValueError("file id should not be None.")
            model_id = uuid.UUID(parsed_reference.record_id)
            if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
                upload_file_ids.append(model_id)
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_file_ids.append(model_id)
        tool_files = self._load_tool_files(tool_file_ids)
        upload_files = self._load_upload_files(upload_file_ids)
        # Second pass: write the resolved storage key (and a rebuilt
        # reference) back onto each file.
        for file in files:
            parsed_reference = parse_file_reference(file.reference)
            if parsed_reference is None:
                raise ValueError("file id should not be None.")
            model_id = uuid.UUID(parsed_reference.record_id)
            if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
                upload_file_row = upload_files.get(model_id)
                if upload_file_row is None:
                    raise ValueError(f"Upload file not found for id: {model_id}")
                file.reference = build_file_reference(
                    record_id=str(upload_file_row.id),
                    storage_key=upload_file_row.key,
                )
                file.storage_key = upload_file_row.key
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_file_row = tool_files.get(model_id)
                if tool_file_row is None:
                    raise ValueError(f"Tool file not found for id: {model_id}")
                file.reference = build_file_reference(
                    record_id=str(tool_file_row.id),
                    storage_key=tool_file_row.file_key,
                )
                file.storage_key = tool_file_row.file_key

View File

@ -0,0 +1,31 @@
"""Workflow file factory package.
This package normalizes workflow-layer file payloads into graph-layer ``File``
values. It keeps tenancy and ownership checks in the application layer and
preserves the historical ``factories.file_factory`` import surface for callers.
"""
from core.helper import ssrf_proxy
from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig
from extensions.ext_database import db
from .builders import build_from_mapping, build_from_mappings
from .message_files import build_from_message_file, build_from_message_files
from .remote import _extract_filename, _get_remote_file_info
from .storage_keys import StorageKeyLoader
__all__ = [
"File",
"FileTransferMethod",
"FileType",
"FileUploadConfig",
"StorageKeyLoader",
"_extract_filename",
"_get_remote_file_info",
"build_from_mapping",
"build_from_mappings",
"build_from_message_file",
"build_from_message_files",
"db",
"ssrf_proxy",
]

View File

@ -0,0 +1,325 @@
"""Core builders for workflow file mappings."""
from __future__ import annotations
import mimetypes
import uuid
from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import select
from core.app.file_access import FileAccessControllerProtocol
from core.workflow.file_reference import build_file_reference
from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers
from dify_graph.file.file_factory import standardize_file_type
from extensions.ext_database import db
from models import ToolFile, UploadFile
from .common import resolve_mapping_file_id
from .remote import _get_remote_file_info
from .validation import is_file_valid_with_config
def build_from_mapping(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    config: FileUploadConfig | None = None,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    """Build a ``File`` from a raw mapping, dispatching on transfer method.

    Database lookups performed by the chosen builder are filtered through
    *access_controller*.

    :raises ValueError: if the transfer method is missing/unknown or the
        built file fails validation against *config*.
    """
    transfer_method_value = mapping.get("transfer_method")
    if not transfer_method_value:
        raise ValueError("transfer_method is required in file mapping")
    transfer_method = FileTransferMethod.value_of(transfer_method_value)
    build_func = _get_build_function(transfer_method)
    file = build_func(
        mapping=mapping,
        tenant_id=tenant_id,
        transfer_method=transfer_method,
        strict_type_validation=strict_type_validation,
        access_controller=access_controller,
    )
    if config and not is_file_valid_with_config(
        input_file_type=mapping.get("type", FileType.CUSTOM),
        file_extension=file.extension or "",
        file_transfer_method=file.transfer_method,
        config=config,
    ):
        raise ValueError(f"File validation failed for file: {file.filename}")
    return file
def build_from_mappings(
    *,
    mappings: Sequence[Mapping[str, Any]],
    config: FileUploadConfig | None = None,
    tenant_id: str,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> Sequence[File]:
    """Build files for every usable mapping and enforce *config* limits.

    Invalid/empty mappings are silently dropped before building; each
    remaining mapping is delegated to :func:`build_from_mapping`.

    :raises ValueError: when the image count or total file count exceeds
        the limits configured in *config*.
    """
    # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
    # Implement batch processing to reduce database load when handling multiple files.
    valid_mappings = [mapping for mapping in mappings if _is_valid_mapping(mapping)]
    files = [
        build_from_mapping(
            mapping=mapping,
            tenant_id=tenant_id,
            config=config,
            strict_type_validation=strict_type_validation,
            access_controller=access_controller,
        )
        for mapping in valid_mappings
    ]
    if (
        config
        and config.image_config
        and sum(1 for file in files if file.type == FileType.IMAGE) > config.image_config.number_limits
    ):
        raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
    if config and config.number_limits and len(files) > config.number_limits:
        raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
    return files
def _get_build_function(transfer_method: FileTransferMethod):
    """Return the builder callable registered for *transfer_method*.

    :raises ValueError: for transfer methods without a registered builder.
    """
    dispatch = {
        FileTransferMethod.LOCAL_FILE: _build_from_local_file,
        FileTransferMethod.REMOTE_URL: _build_from_remote_url,
        FileTransferMethod.TOOL_FILE: _build_from_tool_file,
        FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
    }
    builder = dispatch.get(transfer_method)
    if builder is None:
        raise ValueError(f"Invalid file transfer method: {transfer_method}")
    return builder
def _resolve_file_type(
    *,
    detected_file_type: FileType,
    specified_type: str | None,
    strict_type_validation: bool,
) -> FileType:
    """Reconcile the caller-specified type with the detected one.

    A concrete specified type (anything other than "custom"/empty) wins;
    otherwise the detected type is used.

    :raises ValueError: in strict mode when a specified type disagrees
        with the detected one.
    """
    mismatch = bool(specified_type) and detected_file_type.value != specified_type
    if strict_type_validation and mismatch:
        raise ValueError("Detected file type does not match the specified type. Please verify the file.")
    use_specified = bool(specified_type) and specified_type != "custom"
    return FileType(specified_type) if use_specified else detected_file_type
def _build_from_local_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    """Build a ``File`` from a locally uploaded file referenced in *mapping*.

    The ``UploadFile`` lookup is tenant-scoped and additionally narrowed
    through *access_controller*'s upload-file filters.

    :raises ValueError: for a missing/malformed upload file id, no
        accessible record, or a strict type-validation mismatch.
    """
    upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id")
    if not upload_file_id:
        raise ValueError("Invalid upload file id")
    try:
        uuid.UUID(upload_file_id)
    except ValueError as exc:
        raise ValueError("Invalid upload file id format") from exc
    stmt = select(UploadFile).where(
        UploadFile.id == upload_file_id,
        UploadFile.tenant_id == tenant_id,
    )
    # The access controller may add further filters (e.g. ownership scope).
    row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
    if row is None:
        raise ValueError("Invalid upload file")
    detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
    file_type = _resolve_file_type(
        detected_file_type=detected_file_type,
        specified_type=mapping.get("type", "custom"),
        strict_type_validation=strict_type_validation,
    )
    return File(
        id=mapping.get("id"),
        filename=row.name,
        extension="." + row.extension,
        mime_type=row.mime_type,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=row.source_url,
        reference=build_file_reference(record_id=str(row.id), storage_key=row.key),
        size=row.size,
    )
def _build_from_remote_url(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    """Build a ``File`` for a remote URL, preferring a stored ``UploadFile`` row when referenced.

    When the mapping carries an upload-file reference, metadata comes from the
    authorized database row; otherwise the remote URL is probed for metadata.

    Raises:
        ValueError: If the id is malformed, the authorized row is missing, or no URL is given.
    """
    file_id = resolve_mapping_file_id(mapping, "upload_file_id")
    if file_id:
        try:
            uuid.UUID(file_id)
        except ValueError as exc:
            raise ValueError("Invalid upload file id format") from exc
        query = access_controller.apply_upload_file_filters(
            select(UploadFile).where(
                UploadFile.id == file_id,
                UploadFile.tenant_id == tenant_id,
            )
        )
        record = db.session.scalar(query)
        if record is None:
            raise ValueError("Invalid upload file")
        extension = "." + record.extension
        resolved_type = _resolve_file_type(
            detected_file_type=standardize_file_type(extension=extension, mime_type=record.mime_type),
            specified_type=mapping.get("type"),
            strict_type_validation=strict_type_validation,
        )
        return File(
            id=mapping.get("id"),
            filename=record.name,
            extension=extension,
            mime_type=record.mime_type,
            type=resolved_type,
            transfer_method=transfer_method,
            # Expose the stored file through a signed URL instead of the raw source.
            remote_url=helpers.get_signed_file_url(upload_file_id=str(file_id)),
            reference=build_file_reference(record_id=str(record.id), storage_key=record.key),
            size=record.size,
        )
    url = mapping.get("url") or mapping.get("remote_url")
    if not url:
        raise ValueError("Invalid file url")
    mime_type, filename, file_size = _get_remote_file_info(url)
    extension = mimetypes.guess_extension(mime_type) or (
        "." + filename.split(".")[-1] if "." in filename else ".bin"
    )
    resolved_type = _resolve_file_type(
        detected_file_type=standardize_file_type(extension=extension, mime_type=mime_type),
        specified_type=mapping.get("type"),
        strict_type_validation=strict_type_validation,
    )
    return File(
        id=mapping.get("id"),
        filename=filename,
        type=resolved_type,
        transfer_method=transfer_method,
        remote_url=url,
        mime_type=mime_type,
        extension=extension,
        size=file_size,
    )
def _build_from_tool_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    """Build a ``File`` from the tenant-owned ``ToolFile`` row referenced in *mapping*.

    Raises:
        ValueError: If the tool file id is missing or no authorized row matches.
    """
    file_id = resolve_mapping_file_id(mapping, "tool_file_id")
    if not file_id:
        raise ValueError(f"ToolFile {file_id} not found")
    query = access_controller.apply_tool_file_filters(
        select(ToolFile).where(
            ToolFile.id == file_id,
            ToolFile.tenant_id == tenant_id,
        )
    )
    record = db.session.scalar(query)
    if record is None:
        raise ValueError(f"ToolFile {file_id} not found")
    # Tool files carry no extension column; derive it from the storage key.
    extension = "." + record.file_key.split(".")[-1] if "." in record.file_key else ".bin"
    resolved_type = _resolve_file_type(
        detected_file_type=standardize_file_type(extension=extension, mime_type=record.mimetype),
        specified_type=mapping.get("type"),
        strict_type_validation=strict_type_validation,
    )
    return File(
        id=mapping.get("id"),
        filename=record.name,
        type=resolved_type,
        transfer_method=transfer_method,
        remote_url=record.original_url,
        reference=build_file_reference(record_id=str(record.id), storage_key=record.file_key),
        extension=extension,
        mime_type=record.mimetype,
        size=record.size,
    )
def _build_from_datasource_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    """Build a ``File`` from a datasource-backed ``UploadFile`` row referenced in *mapping*.

    Raises:
        ValueError: If the datasource file id is missing or no authorized row matches.
    """
    datasource_file_id = resolve_mapping_file_id(mapping, "datasource_file_id")
    if not datasource_file_id:
        raise ValueError(f"DatasourceFile {datasource_file_id} not found")
    stmt = select(UploadFile).where(
        UploadFile.id == datasource_file_id,
        UploadFile.tenant_id == tenant_id,
    )
    datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
    if datasource_file is None:
        raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
    # Derive the extension from the storage key; it already includes the leading dot.
    extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
    # BUG FIX: the previous code prepended another dot ("." + extension), so type
    # detection received values like "..pdf". Siblings pass the extension as-is.
    detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
    file_type = _resolve_file_type(
        detected_file_type=detected_file_type,
        specified_type=mapping.get("type"),
        strict_type_validation=strict_type_validation,
    )
    return File(
        # NOTE(review): sibling builders use mapping.get("id") here; the raw
        # datasource_file_id key is kept as-is — confirm this is intentional.
        id=mapping.get("datasource_file_id"),
        filename=datasource_file.name,
        type=file_type,
        # NOTE(review): transfer_method is forced to TOOL_FILE rather than the
        # *transfer_method* argument — presumably so datasource files reuse the
        # tool-file URL path; confirm against callers.
        transfer_method=FileTransferMethod.TOOL_FILE,
        remote_url=datasource_file.source_url,
        reference=build_file_reference(record_id=str(datasource_file.id), storage_key=datasource_file.key),
        extension=extension,
        mime_type=datasource_file.mime_type,
        size=datasource_file.size,
        url=datasource_file.source_url,
    )
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:
    """Return True when *mapping* carries enough information to build a file."""
    if not mapping:
        return False
    method = mapping.get("transfer_method")
    if not method:
        return False
    if method != FileTransferMethod.REMOTE_URL:
        return True
    # Remote files additionally require a URL under either accepted key.
    return bool(mapping.get("url") or mapping.get("remote_url"))

View File

@ -0,0 +1,27 @@
"""Shared helpers for workflow file factory modules."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from core.workflow.file_reference import resolve_file_record_id
def resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None:
    """Resolve historical file identifiers from persisted mapping payloads.

    Workflow and model payloads can outlive file schema changes. Older rows may
    still carry concrete identifiers in legacy fields such as ``related_id``,
    while newer payloads use opaque references. Keep this compatibility lookup in
    the factory layer so historical data remains readable without reintroducing
    storage details into graph-layer ``File`` values.
    """
    candidate_keys = (*keys, "reference", "related_id")
    for candidate in candidate_keys:
        value = mapping.get(candidate)
        if not isinstance(value, str) or not value:
            continue
        record_id = resolve_file_record_id(value)
        if record_id:
            return record_id
    return None

View File

@ -0,0 +1,59 @@
"""Adapters from persisted message files to graph-layer file values."""
from __future__ import annotations
from collections.abc import Sequence
from core.app.file_access import FileAccessControllerProtocol
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig
from models import MessageFile
from .builders import build_from_mapping
def build_from_message_files(
    *,
    message_files: Sequence[MessageFile],
    tenant_id: str,
    config: FileUploadConfig | None = None,
    access_controller: FileAccessControllerProtocol,
) -> Sequence[File]:
    """Convert persisted user-side message files into graph-layer ``File`` values."""
    results: list[File] = []
    for message_file in message_files:
        # Assistant-generated files are outputs, not user inputs — skip them.
        if message_file.belongs_to == FileBelongsTo.ASSISTANT:
            continue
        results.append(
            build_from_message_file(
                message_file=message_file,
                tenant_id=tenant_id,
                config=config,
                access_controller=access_controller,
            )
        )
    return results
def build_from_message_file(
    *,
    message_file: MessageFile,
    tenant_id: str,
    config: FileUploadConfig | None,
    access_controller: FileAccessControllerProtocol,
) -> File:
    """Rehydrate one persisted ``MessageFile`` row into a graph-layer ``File``."""
    payload = {
        "transfer_method": message_file.transfer_method,
        "url": message_file.url,
        "type": message_file.type,
    }
    if message_file.id:
        payload["id"] = message_file.id
    # The backing id lives in one column; route it to the mapping key the
    # builder for this transfer method expects.
    id_key = (
        "tool_file_id"
        if message_file.transfer_method == FileTransferMethod.TOOL_FILE
        else "upload_file_id"
    )
    payload[id_key] = message_file.upload_file_id
    return build_from_mapping(
        mapping=payload,
        tenant_id=tenant_id,
        config=config,
        access_controller=access_controller,
    )

View File

@ -0,0 +1,84 @@
"""Remote file metadata helpers used by workflow file normalization."""
from __future__ import annotations
import mimetypes
import os
import re
import urllib.parse
import uuid
import httpx
from werkzeug.http import parse_options_header
from core.helper import ssrf_proxy
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
filename: str | None = None
if content_disposition:
filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
if filename_star_match:
raw_star = filename_star_match.group(1).strip()
raw_star = raw_star.removesuffix('"')
try:
parts = raw_star.split("'", 2)
charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8"
value = parts[2] if len(parts) == 3 else parts[-1]
filename = urllib.parse.unquote(value, encoding=charset, errors="replace")
except Exception:
if "''" in raw_star:
filename = urllib.parse.unquote(raw_star.split("''")[-1])
else:
filename = urllib.parse.unquote(raw_star)
if not filename:
_, params = parse_options_header(content_disposition)
raw = params.get("filename")
if raw:
if len(raw) >= 2 and raw[0] == raw[-1] == '"':
raw = raw[1:-1]
filename = urllib.parse.unquote(raw)
if not filename:
candidate = os.path.basename(url_path)
filename = urllib.parse.unquote(candidate) if candidate else None
if filename:
filename = os.path.basename(filename)
if not filename or not filename.strip():
filename = None
return filename or None
def _guess_mime_type(filename: str) -> str:
guessed_mime, _ = mimetypes.guess_type(filename)
return guessed_mime or ""
def _get_remote_file_info(url: str) -> tuple[str, str, int]:
    """Probe *url* with a HEAD request and return ``(mime_type, filename, size)``.

    The size is -1 when the server does not advertise Content-Length. A random
    filename is synthesized when none can be derived from the URL or headers.
    """
    size = -1
    path = urllib.parse.urlparse(url).path
    filename = os.path.basename(path)
    mime_type = _guess_mime_type(filename)
    # HEAD through the SSRF proxy so redirects are followed safely.
    response = ssrf_proxy.head(url, follow_redirects=True)
    if response.status_code == httpx.codes.OK:
        header_filename = _extract_filename(path, response.headers.get("Content-Disposition"))
        if header_filename:
            filename = header_filename
            mime_type = _guess_mime_type(filename)
        size = int(response.headers.get("Content-Length", size))
        if not mime_type:
            # Fall back to the server-declared type, stripping any parameters.
            mime_type = response.headers.get("Content-Type", "").split(";")[0].strip()
    if not filename:
        fallback_extension = mimetypes.guess_extension(mime_type) or ".bin"
        filename = f"{uuid.uuid4().hex}{fallback_extension}"
    if not mime_type:
        mime_type = _guess_mime_type(filename)
    return mime_type, filename, size

View File

@ -0,0 +1,107 @@
"""Batched storage-key hydration for workflow files."""
from __future__ import annotations
import uuid
from collections.abc import Mapping, Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.file_access import FileAccessControllerProtocol
from core.workflow.file_reference import build_file_reference, parse_file_reference
from dify_graph.file import File, FileTransferMethod
from models import ToolFile, UploadFile
class StorageKeyLoader:
    """Load storage keys for files with a constant number of database queries."""

    _session: Session
    _tenant_id: str
    _access_controller: FileAccessControllerProtocol

    # Transfer methods whose backing row is an UploadFile (vs. a ToolFile).
    _UPLOAD_BACKED_METHODS = (
        FileTransferMethod.LOCAL_FILE,
        FileTransferMethod.REMOTE_URL,
        FileTransferMethod.DATASOURCE_FILE,
    )

    def __init__(
        self,
        session: Session,
        tenant_id: str,
        access_controller: FileAccessControllerProtocol,
    ) -> None:
        self._session = session
        self._tenant_id = tenant_id
        self._access_controller = access_controller

    def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
        """Fetch authorized, tenant-scoped UploadFile rows, keyed by UUID."""
        base_stmt = select(UploadFile).where(
            UploadFile.id.in_(upload_file_ids),
            UploadFile.tenant_id == self._tenant_id,
        )
        rows = self._session.scalars(self._access_controller.apply_upload_file_filters(base_stmt))
        return {uuid.UUID(row.id): row for row in rows}

    def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
        """Fetch authorized, tenant-scoped ToolFile rows, keyed by UUID."""
        base_stmt = select(ToolFile).where(
            ToolFile.id.in_(tool_file_ids),
            ToolFile.tenant_id == self._tenant_id,
        )
        rows = self._session.scalars(self._access_controller.apply_tool_file_filters(base_stmt))
        return {uuid.UUID(row.id): row for row in rows}

    def load_storage_keys(self, files: Sequence[File]) -> None:
        """Hydrate storage keys by loading their backing file rows in batches.

        The sequence shape is preserved. Each file is updated in place with a
        canonical reference and storage key loaded from an authorized database
        row. Tenant scoping is enforced by this loader's context rather than by
        embedding tenant identity inside graph-layer ``File`` values.

        For best performance, prefer batches smaller than 1000 files.
        """
        # Parse every reference once and classify files by backing table.
        resolved: list[tuple[File, uuid.UUID]] = []
        upload_ids: list[uuid.UUID] = []
        tool_ids: list[uuid.UUID] = []
        for file in files:
            reference = parse_file_reference(file.reference)
            if reference is None:
                raise ValueError("file id should not be None.")
            record_uuid = uuid.UUID(reference.record_id)
            resolved.append((file, record_uuid))
            if file.transfer_method in self._UPLOAD_BACKED_METHODS:
                upload_ids.append(record_uuid)
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_ids.append(record_uuid)
        # One query per table regardless of batch size.
        tool_rows = self._load_tool_files(tool_ids)
        upload_rows = self._load_upload_files(upload_ids)
        for file, record_uuid in resolved:
            if file.transfer_method in self._UPLOAD_BACKED_METHODS:
                upload_row = upload_rows.get(record_uuid)
                if upload_row is None:
                    raise ValueError(f"Upload file not found for id: {record_uuid}")
                file.reference = build_file_reference(
                    record_id=str(upload_row.id),
                    storage_key=upload_row.key,
                )
                file.storage_key = upload_row.key
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_row = tool_rows.get(record_uuid)
                if tool_row is None:
                    raise ValueError(f"Tool file not found for id: {record_uuid}")
                file.reference = build_file_reference(
                    record_id=str(tool_row.id),
                    storage_key=tool_row.file_key,
                )
                file.storage_key = tool_row.file_key

View File

@ -0,0 +1,44 @@
"""Validation helpers for workflow file inputs."""
from __future__ import annotations
from dify_graph.file import FileTransferMethod, FileType, FileUploadConfig
def is_file_valid_with_config(
    *,
    input_file_type: str,
    file_extension: str,
    file_transfer_method: FileTransferMethod,
    config: FileUploadConfig,
) -> bool:
    """Check whether a user-supplied file satisfies the app's upload configuration."""
    # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
    # These are internally generated and should bypass user upload restrictions
    if file_transfer_method == FileTransferMethod.TOOL_FILE:
        return True
    allowed_types = config.allowed_file_types
    if allowed_types and input_file_type != FileType.CUSTOM and input_file_type not in allowed_types:
        return False
    if input_file_type == FileType.CUSTOM:
        allowed_extensions = config.allowed_file_extensions
        if allowed_extensions is not None and file_extension not in allowed_extensions:
            return False
    if input_file_type == FileType.IMAGE:
        # Images may carry their own, stricter transfer-method allowlist.
        image_config = config.image_config
        if (
            image_config
            and image_config.transfer_methods
            and file_transfer_method not in image_config.transfer_methods
        ):
            return False
    elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
        return False
    return True

View File

@ -7,6 +7,7 @@ from collections.abc import Callable, Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from uuid import uuid4
@ -53,6 +54,13 @@ if TYPE_CHECKING:
# --- TypedDict definitions for structured dict return types ---
@lru_cache(maxsize=1)
def _get_file_access_controller():
from core.app.file_access import DatabaseFileAccessController
return DatabaseFileAccessController()
def _resolve_app_tenant_id(app_id: str) -> str:
resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id))
if not resolved_tenant_id:
@ -1618,6 +1626,7 @@ class Message(Base):
"upload_file_id": message_file.upload_file_id,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL:
if message_file.url is None:
@ -1631,6 +1640,7 @@ class Message(Base):
"url": message_file.url,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE:
if message_file.upload_file_id is None:
@ -1645,6 +1655,7 @@ class Message(Base):
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
else:
raise ValueError(

View File

@ -1,12 +1,20 @@
from __future__ import annotations
from collections.abc import Callable, Mapping
from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
from dify_graph.file import File, FileTransferMethod
@lru_cache(maxsize=1)
def _get_file_access_controller():
from core.app.file_access import DatabaseFileAccessController
return DatabaseFileAccessController()
def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None:
reference = file_mapping.get("reference")
if isinstance(reference, str) and reference:
@ -61,4 +69,8 @@ def build_file_from_input_mapping(
mapping["upload_file_id"] = record_id
tenant_id = resolve_file_mapping_tenant_id(file_mapping=mapping, tenant_resolver=tenant_resolver)
return file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
return file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_get_file_access_controller(),
)

View File

@ -15,6 +15,7 @@ from werkzeug.exceptions import RequestEntityTooLarge
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.tools.tool_file_manager import ToolFileManager
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from core.workflow.nodes.trigger_webhook.entities import (
@ -46,6 +47,7 @@ except ImportError:
magic = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class WebhookService:
@ -422,6 +424,7 @@ class WebhookService:
return file_factory.build_from_mapping(
mapping=mapping,
tenant_id=webhook_trigger.tenant_id,
access_controller=_file_access_controller,
)
@classmethod

View File

@ -14,6 +14,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.trigger.constants import is_trigger_node_type
from core.workflow.system_variables import SystemVariableKey
from core.workflow.variable_prefixes import (
@ -126,7 +127,11 @@ class DraftVarLoader(VariableLoader):
elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
with Session(bind=self._engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
storage_key_loader = StorageKeyLoader(
session,
tenant_id=self._tenant_id,
access_controller=DatabaseFileAccessController(),
)
storage_key_loader.load_storage_keys(files)
offloaded_draft_vars = []

View File

@ -12,6 +12,7 @@ from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.file_access import DatabaseFileAccessController
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
@ -84,6 +85,8 @@ from .human_input_delivery_test_service import (
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
from .workflow_restore import apply_published_workflow_snapshot_to_draft
_file_access_controller = DatabaseFileAccessController()
class WorkflowService:
"""
@ -1583,7 +1586,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
if variable_entity_type == VariableEntityType.FILE:
if not isinstance(value, dict):
raise ValueError(f"expected dict for file object, got {type(value)}")
return build_from_mapping(mapping=value, tenant_id=tenant_id)
return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller)
elif variable_entity_type == VariableEntityType.FILE_LIST:
if not isinstance(value, list):
raise ValueError(f"expected list for file list object, got {type(value)}")
@ -1591,6 +1594,6 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
return []
if not isinstance(value[0], dict):
raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
return build_from_mappings(mappings=value, tenant_id=tenant_id)
return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller)
else:
raise Exception("unreachable")

View File

@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.app.file_access import DatabaseFileAccessController
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase):
self.test_tool_files = []
# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)
self.loader = StorageKeyLoader(
self.session,
self.tenant_id,
access_controller=DatabaseFileAccessController(),
)
def tearDown(self):
"""Clean up test data after each test method."""
@ -361,6 +366,10 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create loader with different session (same underlying connection)
with Session(bind=db.engine) as other_session:
other_loader = StorageKeyLoader(other_session, self.tenant_id)
other_loader = StorageKeyLoader(
other_session,
self.tenant_id,
access_controller=DatabaseFileAccessController(),
)
with pytest.raises(ValueError):
other_loader.load_storage_keys([file])

View File

@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.app.file_access import DatabaseFileAccessController
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase):
self.test_tool_files = []
# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)
self.loader = StorageKeyLoader(
self.session,
self.tenant_id,
access_controller=DatabaseFileAccessController(),
)
def tearDown(self):
"""Clean up test data after each test method."""
@ -362,6 +367,10 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create loader with different session (same underlying connection)
with Session(bind=db.engine) as other_session:
other_loader = StorageKeyLoader(other_session, self.tenant_id)
other_loader = StorageKeyLoader(
other_session,
self.tenant_id,
access_controller=DatabaseFileAccessController(),
)
with pytest.raises(ValueError):
other_loader.load_storage_keys([file])

View File

@ -401,11 +401,11 @@ class TestBaseAppGeneratorExtras:
monkeypatch.setattr(
"core.app.apps.base_app_generator.file_factory.build_from_mapping",
lambda mapping, tenant_id, config, strict_type_validation=False: "file-object",
lambda mapping, tenant_id, config, strict_type_validation=False, access_controller=None: "file-object",
)
monkeypatch.setattr(
"core.app.apps.base_app_generator.file_factory.build_from_mappings",
lambda mappings, tenant_id, config: ["file-1", "file-2"],
lambda mappings, tenant_id, config, access_controller=None: ["file-1", "file-2"],
)
user_inputs = {

View File

@ -9,10 +9,13 @@ from urllib.parse import parse_qs, urlparse
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope
from core.app.workflow import file_runtime
from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime
from core.workflow.file_reference import build_file_reference
from dify_graph.file import File, FileTransferMethod, FileType
from models import ToolFile, UploadFile
def _build_file(
@ -35,8 +38,12 @@ def _build_file(
)
def _build_runtime() -> DifyWorkflowFileRuntime:
return DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())
def test_resolve_file_url_returns_remote_url() -> None:
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
file = _build_file(
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/diagram.png",
@ -46,7 +53,7 @@ def test_resolve_file_url_returns_remote_url() -> None:
def test_resolve_file_url_requires_file_reference() -> None:
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
file = SimpleNamespace(transfer_method=FileTransferMethod.LOCAL_FILE, reference=None)
with pytest.raises(ValueError, match="Missing file reference"):
@ -54,7 +61,7 @@ def test_resolve_file_url_requires_file_reference() -> None:
def test_resolve_file_url_requires_extension_for_tool_files() -> None:
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
file = _build_file(
transfer_method=FileTransferMethod.TOOL_FILE,
reference=build_file_reference(record_id="tool-file-id"),
@ -70,7 +77,7 @@ def test_resolve_file_url_uses_tool_signatures_for_tool_and_datasource_files(
) -> None:
sign_tool_file = MagicMock(return_value="https://signed.example.com/file")
monkeypatch.setattr(file_runtime, "sign_tool_file", sign_tool_file)
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
tool_file = _build_file(
transfer_method=FileTransferMethod.TOOL_FILE,
@ -100,7 +107,7 @@ def test_resolve_upload_file_url_signs_internal_urls_and_supports_attachments(
"https://internal.example.com",
)
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
url = runtime.resolve_upload_file_url(
upload_file_id="upload-file-id",
as_attachment=True,
@ -119,7 +126,7 @@ def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch
monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000)
monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 60)
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
payload = "file-preview|upload-file-id|1700000000|nonce"
sign = base64.urlsafe_b64encode(hmac.new(b"unit-secret", payload.encode(), hashlib.sha256).digest()).decode()
@ -158,7 +165,7 @@ def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch
def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
file = _build_file(
transfer_method=FileTransferMethod.LOCAL_FILE,
reference=build_file_reference(record_id="upload-file-id", storage_key="storage-key"),
@ -173,7 +180,7 @@ def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest
def test_resolve_storage_key_prefers_encoded_reference() -> None:
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
file = _build_file(
transfer_method=FileTransferMethod.LOCAL_FILE,
reference=build_file_reference(record_id="upload-file-id", storage_key="storage-key"),
@ -182,6 +189,60 @@ def test_resolve_storage_key_prefers_encoded_reference() -> None:
assert runtime._resolve_storage_key(file=file) == "storage-key"
def test_resolve_storage_key_uses_canonical_record_when_scope_is_bound(monkeypatch: pytest.MonkeyPatch) -> None:
controller = MagicMock()
controller.current_scope.return_value = FileAccessScope(
tenant_id="tenant-id",
user_id="end-user-id",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
controller.get_upload_file.return_value = SimpleNamespace(key="canonical-storage-key")
runtime = DifyWorkflowFileRuntime(file_access_controller=controller)
file = _build_file(
transfer_method=FileTransferMethod.LOCAL_FILE,
reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"),
)
session = MagicMock()
class _SessionContext:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
assert runtime._resolve_storage_key(file=file) == "canonical-storage-key"
controller.get_upload_file.assert_called_once_with(session=session, file_id="upload-file-id")
def test_resolve_upload_file_url_rejects_unauthorized_scoped_access(monkeypatch: pytest.MonkeyPatch) -> None:
controller = MagicMock()
controller.current_scope.return_value = FileAccessScope(
tenant_id="tenant-id",
user_id="end-user-id",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
controller.get_upload_file.return_value = None
runtime = DifyWorkflowFileRuntime(file_access_controller=controller)
session = MagicMock()
class _SessionContext:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
with pytest.raises(ValueError, match="Upload file upload-file-id not found"):
runtime.resolve_upload_file_url(upload_file_id="upload-file-id")
@pytest.mark.parametrize(
("transfer_method", "record_id", "expected_storage_key"),
[
@ -196,7 +257,7 @@ def test_resolve_storage_key_loads_database_records(
record_id: str,
expected_storage_key: str,
) -> None:
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
file = _build_file(
transfer_method=transfer_method,
reference=build_file_reference(record_id=record_id),
@ -206,9 +267,9 @@ def test_resolve_storage_key_loads_database_records(
def get(model_class, value):
if transfer_method in {FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE}:
assert model_class is file_runtime.UploadFile
assert model_class is UploadFile
return SimpleNamespace(key="upload-storage-key")
assert model_class is file_runtime.ToolFile
assert model_class is ToolFile
return SimpleNamespace(file_key="tool-storage-key")
session.get.side_effect = get
@ -237,7 +298,7 @@ def test_resolve_storage_key_raises_when_records_are_missing(
transfer_method: FileTransferMethod,
expected_message: str,
) -> None:
runtime = DifyWorkflowFileRuntime()
runtime = _build_runtime()
record_id = "upload-file-id" if transfer_method == FileTransferMethod.LOCAL_FILE else "tool-file-id"
file = _build_file(
transfer_method=transfer_method,

View File

@ -115,6 +115,7 @@ def test_dify_file_reference_factory_passes_tenant_id(monkeypatch: pytest.Monkey
build_from_mapping.assert_called_once_with(
mapping={"id": "upload-file"},
tenant_id="tenant-id",
access_controller=node_runtime._file_access_controller,
)

View File

@ -4,13 +4,17 @@ from unittest.mock import MagicMock, patch
import pytest
from httpx import Response
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
from core.workflow.file_reference import build_file_reference, resolve_file_record_id
from factories.file_factory import (
File,
FileTransferMethod,
FileType,
FileUploadConfig,
build_from_mapping,
)
from factories.file_factory import (
build_from_mapping as _build_from_mapping,
)
from models import ToolFile, UploadFile
@ -19,6 +23,7 @@ TEST_TENANT_ID = "test_tenant_id"
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
TEST_TOOL_FILE_ID = str(uuid.uuid4())
TEST_REMOTE_URL = "http://example.com/test.jpg"
TEST_ACCESS_CONTROLLER = DatabaseFileAccessController()
# Test Config
TEST_CONFIG = FileUploadConfig(
@ -29,6 +34,16 @@ TEST_CONFIG = FileUploadConfig(
)
def build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
return _build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
config=config,
strict_type_validation=strict_type_validation,
access_controller=TEST_ACCESS_CONTROLLER,
)
# Fixtures
@pytest.fixture
def mock_upload_file():
@ -305,6 +320,48 @@ def test_tenant_mismatch():
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
def test_build_from_mapping_scopes_upload_file_to_end_user(mock_upload_file):
scope = FileAccessScope(
tenant_id=TEST_TENANT_ID,
user_id="end-user-id",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with bind_file_access_scope(scope):
build_from_mapping(mapping=local_file_mapping(), tenant_id=TEST_TENANT_ID)
stmt = mock_upload_file.call_args.args[0]
whereclause = str(stmt.whereclause)
assert "upload_files.created_by_role" in whereclause
assert "upload_files.created_by" in whereclause
def test_build_from_mapping_scopes_tool_file_to_end_user():
tool_file = MagicMock(spec=ToolFile)
tool_file.id = TEST_TOOL_FILE_ID
tool_file.tenant_id = TEST_TENANT_ID
tool_file.name = "tool_file.pdf"
tool_file.file_key = "tool_file.pdf"
tool_file.mimetype = "application/pdf"
tool_file.original_url = "http://example.com/tool.pdf"
tool_file.size = 2048
scope = FileAccessScope(
tenant_id=TEST_TENANT_ID,
user_id="end-user-id",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with patch("factories.file_factory.db.session.scalar", return_value=tool_file, autospec=True) as scalar:
with bind_file_access_scope(scope):
build_from_mapping(mapping=tool_file_mapping(), tenant_id=TEST_TENANT_ID)
stmt = scalar.call_args.args[0]
whereclause = str(stmt.whereclause)
assert "tool_files.user_id" in whereclause
def test_disallowed_file_types(mock_upload_file):
"""Test that disallowed file types are rejected."""
# Config that only allows image and document types

View File

@ -111,8 +111,8 @@ def test_inputs_resolve_owner_tenant_for_single_file_mapping(
monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
_ = config, strict_type_validation
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
_ = config, strict_type_validation, access_controller
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
@ -145,8 +145,8 @@ def test_inputs_resolve_owner_tenant_for_file_list_mapping(
monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
_ = config, strict_type_validation
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
_ = config, strict_type_validation, access_controller
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
@ -196,8 +196,8 @@ def test_inputs_prefer_serialized_tenant_id_when_present(
monkeypatch.setattr(model_module.db.session, "scalar", fail_if_called)
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
_ = config, strict_type_validation
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
_ = config, strict_type_validation, access_controller
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)