Mirror of https://github.com/langgenius/dify.git (synced 2026-05-06 02:18:08 +08:00)
feat: add file access controller
Signed-off-by: -LAN- <laipz8200@outlook.com>
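The commit threads a `DatabaseFileAccessController` into `file_factory.build_from_mapping` / `build_from_mappings` (and related helpers) through a new `access_controller=` parameter, so that file references supplied by a caller are validated against the requesting tenant and user before file objects are built. The `core/app/file_access` module itself is not part of this diff; the sketch below is only a guess at the kind of interface the call sites imply. The `FileAccessScope` field names mirror the constructor call in `BaseAppGenerator` further down, while the method name `can_access` and the check it performs are assumptions.

```python
# Hypothetical sketch of the interface implied by the call sites in this commit.
# FileAccessScope's fields mirror the constructor call in BaseAppGenerator below;
# can_access() and its body are assumptions, not the actual implementation.
from dataclasses import dataclass


@dataclass(frozen=True)
class FileAccessScope:
    """Request-scoped ownership markers: who is asking, and from where."""
    tenant_id: str
    user_id: str
    user_from: str     # a UserFrom enum in the real code (account vs. end user)
    invoke_from: str   # an InvokeFrom enum in the real code (debugger, web app, service API, ...)


class DatabaseFileAccessController:
    """Decides whether a given scope may reference a given uploaded file."""

    def can_access(self, *, scope: FileAccessScope, upload_file_id: str) -> bool:
        # A real implementation would load the uploaded-file record and compare its
        # tenant (and, where relevant, its creator) against the bound scope.
        raise NotImplementedError
```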
@@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
@@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
@@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
return file_objs
@@ -15,6 +15,7 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.file import helpers as file_helpers
from dify_graph.variables.segment_group import SegmentGroup
@@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

@@ -389,13 +391,21 @@ class VariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
@@ -21,6 +21,7 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
@@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()

def _create_pagination_parser():
@@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
@@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ModelConfigWithCredentialsEntity,
)
from core.app.file_access import DatabaseFileAccessController
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -46,6 +47,7 @@ from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()

class BaseAgentRunner(AppRunner):
@@ -524,7 +526,10 @@ class BaseAgentRunner(AppRunner):
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
message_files=files,
tenant_id=self.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
if not file_objs:
return UserPromptMessage(content=message.query)
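The app-generator hunks that follow apply one shared pattern: each generate method now wraps its file parsing and input preparation in `with self._bind_file_access_scope(...)`, so that every `file_factory.build_from_mappings(...)` call inside the block is checked against the current tenant, user, and invoke source. Because the old and new sides of those hunks are listed back to back, here is a condensed sketch of the resulting call shape; names such as `app_model`, `user`, `invoke_from`, `args`, and `workflow` are taken from the diff, and nothing beyond them is added.

```python
# Condensed illustration of the pattern introduced in the generator hunks below.
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
    files = args.get("files") or []
    file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
    if file_extra_config:
        # File objects are only built once the access controller has approved
        # them for the scope bound above.
        file_objs = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            access_controller=self._file_access_controller,
        )
    else:
        file_objs = []
```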
@@ -147,85 +147,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
access_controller=self._file_access_controller,
)
else:
file_objs = []

# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
else:
file_objs = []

# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore

# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)

def resume(
self,
@ -457,94 +459,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
is_first_conversation = conversation is None
|
||||
with self._bind_file_access_scope(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
):
|
||||
is_first_conversation = conversation is None
|
||||
|
||||
if conversation is not None and message is not None:
|
||||
pass
|
||||
else:
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
if conversation is not None and message is not None:
|
||||
pass
|
||||
else:
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
# get conversation dialogue count
|
||||
# NOTE: dialogue_count should not start from 0,
|
||||
# because during the first conversation, dialogue_count should be 1.
|
||||
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
|
||||
# get conversation dialogue count
|
||||
# NOTE: dialogue_count should not start from 0,
|
||||
# because during the first conversation, dialogue_count should be 1.
|
||||
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
worker_thread.start()
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
# workflow_ = session.get(Workflow, workflow.id)
|
||||
# assert workflow_ is not None
|
||||
# workflow = workflow_
|
||||
# message_ = session.get(Message, message.id)
|
||||
# assert message_ is not None
|
||||
# message = message_
|
||||
# db.session.refresh(workflow)
|
||||
# db.session.refresh(message)
|
||||
# db.session.refresh(user)
|
||||
db.session.close()
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
|
||||
)
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
db.session.close()
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
# return response or stream generator
|
||||
response = self._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
|
||||
@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
files = args.get("files") or []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||
files = args.get("files") or []
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
override_model_config_dict or app_model_config.to_dict()
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
access_controller=self._file_access_controller,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=override_model_config_dict,
|
||||
)
|
||||
# convert to app config
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=override_model_config_dict,
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AgentChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
call_depth=0,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = AgentChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
call_depth=0,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Union, final

from sqlalchemy.orm import Session
@@ -8,7 +9,8 @@ from core.app.apps.draft_variable_saver import (
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
from dify_graph.enums import NodeType
from dify_graph.file import File, FileUploadConfig
from dify_graph.variables.input_entities import VariableEntityType
@@ -57,6 +59,31 @@ class _DebuggerDraftVariableSaver:

class BaseAppGenerator:
_file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController()

@staticmethod
def _bind_file_access_scope(
*,
tenant_id: str,
user: Account | EndUser,
invoke_from: InvokeFrom,
) -> AbstractContextManager[None]:
"""Bind request-scoped file ownership markers for downstream file lookups."""

user_id = getattr(user, "id", None)
if not isinstance(user_id, str) or not user_id:
return nullcontext()

user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER
return bind_file_access_scope(
FileAccessScope(
tenant_id=tenant_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
)
)
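`_bind_file_access_scope` above returns a plain context manager and falls back to `nullcontext()` when no usable user id is present, so callers can wrap their work in a `with` block unconditionally. `bind_file_access_scope` itself lives in `core.app.file_access` and is not part of this diff; one plausible implementation, assuming it simply stashes the `FileAccessScope` in a `contextvars.ContextVar` (which would also fit the generators handing a copied context to their worker threads), is sketched here. The helper name `current_file_access_scope` is an assumption.

```python
# Hypothetical sketch of core.app.file_access.bind_file_access_scope.
# Only the function name, FileAccessScope, and the context-manager contract are
# taken from this commit; the contextvars mechanism below is an assumption.
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, Optional

_current_scope: ContextVar[Optional["FileAccessScope"]] = ContextVar("file_access_scope", default=None)


@contextmanager
def bind_file_access_scope(scope: "FileAccessScope") -> Iterator[None]:
    """Expose `scope` to code inside the with-block (and to threads started with a copied context)."""
    token = _current_scope.set(scope)
    try:
        yield
    finally:
        _current_scope.reset(token)


def current_file_access_scope() -> Optional["FileAccessScope"]:
    # Assumed accessor: downstream checks would read the currently bound scope here.
    return _current_scope.get()
```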
def _prepare_user_inputs(
self,
*,
@@ -85,6 +112,7 @@ class BaseAppGenerator:
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
access_controller=self._file_access_controller,
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
@@ -99,6 +127,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
access_controller=self._file_access_controller,
)
for k, v in user_inputs.items()
if isinstance(v, list)
@@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
override_model_config_dict or app_model_config.to_dict()
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
access_controller=self._file_access_controller,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=override_model_config_dict,
|
||||
)
|
||||
# convert to app config
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=override_model_config_dict,
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = ChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
stream=streaming,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = ChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return self._generate_worker(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
context = contextvars.copy_context()
|
||||
|
||||
worker_thread.start()
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return context.run(
|
||||
self._generate_worker,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
|
||||
@@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
override_model_config_dict or app_model_config.to_dict()
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
access_controller=self._file_access_controller,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
|
||||
)
|
||||
# convert to app config
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = CompletionAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras={},
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = CompletionAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras={},
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return self._generate_worker(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
context = contextvars.copy_context()
|
||||
|
||||
worker_thread.start()
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return context.run(
|
||||
self._generate_worker,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
model_dict["completion_params"] = completion_params
|
||||
override_model_config_dict["model"] = model_dict
|
||||
|
||||
# parse files
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=message.message_files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||
# parse files
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=message.message_files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
access_controller=self._file_access_controller,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = CompletionAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
files=list(file_objs),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras={},
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = CompletionAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
files=list(file_objs),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras={},
|
||||
)
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return self._generate_worker(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
context = contextvars.copy_context()
|
||||
|
||||
worker_thread.start()
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return context.run(
|
||||
self._generate_worker,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
)
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
@ -128,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
system_files = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id,
|
||||
user_id=user.id if isinstance(user, Account) else user.session_id,
|
||||
)
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
system_files = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
access_controller=self._file_access_controller,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=inputs,
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
trace_manager=trace_manager,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
# Create repositories
|
||||
#
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
if triggered_from is not None:
|
||||
# Use explicitly provided triggered_from (for async triggers)
|
||||
workflow_triggered_from = triggered_from
|
||||
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=workflow_triggered_from,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id,
|
||||
user_id=user.id if isinstance(user, Account) else user.session_id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=inputs,
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
trace_manager=trace_manager,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
# Create repositories
|
||||
#
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
if triggered_from is not None:
|
||||
# Use explicitly provided triggered_from (for async triggers)
|
||||
workflow_triggered_from = triggered_from
|
||||
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=workflow_triggered_from,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(
|
||||
self,
|
||||
@ -291,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
"""
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
with self._bind_file_access_scope(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
):
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.close()
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.close()
|
||||
|
||||
worker_thread.start()
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=streaming,
|
||||
)
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
|
||||
|
||||
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def single_iteration_generate(
|
||||
self,
|
||||
|
||||
11
api/core/app/file_access/__init__.py
Normal file
@ -0,0 +1,11 @@
from .controller import DatabaseFileAccessController
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope

__all__ = [
    "DatabaseFileAccessController",
    "FileAccessControllerProtocol",
    "FileAccessScope",
    "bind_file_access_scope",
    "get_current_file_access_scope",
]
103
api/core/app/file_access/controller.py
Normal file
@ -0,0 +1,103 @@
from __future__ import annotations

from collections.abc import Callable

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from models import ToolFile, UploadFile
from models.enums import CreatorUserRole

from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, get_current_file_access_scope


class DatabaseFileAccessController(FileAccessControllerProtocol):
    """Workflow-layer authorization helper for database-backed file lookups.

    Tenant scoping remains mandatory. When the current execution belongs to an
    end user, the lookup is additionally constrained to that end user's file
    ownership markers.
    """

    _scope_getter: Callable[[], FileAccessScope | None]

    def __init__(
        self,
        *,
        scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope,
    ) -> None:
        self._scope_getter = scope_getter

    def current_scope(self) -> FileAccessScope | None:
        return self._scope_getter()

    def apply_upload_file_filters(
        self,
        stmt: Select[tuple[UploadFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[UploadFile]]:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return stmt

        scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id)
        if not resolved_scope.requires_user_ownership:
            return scoped_stmt

        return scoped_stmt.where(
            UploadFile.created_by_role == CreatorUserRole.END_USER,
            UploadFile.created_by == resolved_scope.user_id,
        )

    def apply_tool_file_filters(
        self,
        stmt: Select[tuple[ToolFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[ToolFile]]:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return stmt

        scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id)
        if not resolved_scope.requires_user_ownership:
            return scoped_stmt

        return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)

    def get_upload_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> UploadFile | None:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return session.get(UploadFile, file_id)

        stmt = self.apply_upload_file_filters(
            select(UploadFile).where(UploadFile.id == file_id),
            scope=resolved_scope,
        )
        return session.scalar(stmt)

    def get_tool_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> ToolFile | None:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return session.get(ToolFile, file_id)

        stmt = self.apply_tool_file_filters(
            select(ToolFile).where(ToolFile.id == file_id),
            scope=resolved_scope,
        )
        return session.scalar(stmt)
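
Note (not part of this commit): a minimal usage sketch of the controller, assuming the modules introduced above and a configured SQLAlchemy engine. The tenant, user, and file identifiers are hypothetical, and the explicit scope= argument is used instead of the context variable for clarity.

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope
from extensions.ext_database import db
from models import UploadFile

controller = DatabaseFileAccessController()

# Hypothetical identifiers, for illustration only.
scope = FileAccessScope(
    tenant_id="tenant-123",
    user_id="end-user-456",
    user_from=UserFrom.END_USER,
    invoke_from=InvokeFrom.WEB_APP,
)

with Session(db.engine) as session:
    # With an end-user scope, the statement is narrowed to the tenant and to
    # files created by that end user; with no scope it is returned unchanged.
    stmt = controller.apply_upload_file_filters(
        select(UploadFile).where(UploadFile.id == "upload-file-789"),
        scope=scope,
    )
    upload_file = session.scalar(stmt)
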
81
api/core/app/file_access/protocols.py
Normal file
@ -0,0 +1,81 @@
from __future__ import annotations

from typing import Protocol

from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from models import ToolFile, UploadFile

from .scope import FileAccessScope


class FileAccessControllerProtocol(Protocol):
    """Contract for applying access rules to file lookups.

    Implementations translate an optional execution scope into query constraints
    and authorized record retrieval. The contract is intentionally limited to
    ownership and tenancy rules for workflow-layer file access.
    """

    def current_scope(self) -> FileAccessScope | None:
        """Return the scope active for the current execution, if one exists.

        Callers use this to decide whether embedded file metadata may be trusted
        or whether a fresh authorized lookup is required.
        """
        ...

    def apply_upload_file_filters(
        self,
        stmt: Select[tuple[UploadFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[UploadFile]]:
        """Return an upload-file query constrained by the supplied access scope.

        The returned statement must preserve the caller's existing predicates and
        append only access-control conditions.
        """
        ...

    def apply_tool_file_filters(
        self,
        stmt: Select[tuple[ToolFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[ToolFile]]:
        """Return a tool-file query constrained by the supplied access scope.

        The returned statement must preserve the caller's existing predicates and
        append only access-control conditions.
        """
        ...

    def get_upload_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> UploadFile | None:
        """Load one authorized upload-file record for the given identifier.

        Returns ``None`` when the file does not exist or when the scope does not
        permit access to that record.
        """
        ...

    def get_tool_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> ToolFile | None:
        """Load one authorized tool-file record for the given identifier.

        Returns ``None`` when the file does not exist or when the scope does not
        permit access to that record.
        """
        ...
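
Note (not part of this commit): because the contract is a typing.Protocol, any structurally compatible class can stand in for the database-backed controller. A minimal sketch of a permissive implementation that could be used in unit tests, assuming the repository's models; it applies no tenant or ownership constraints.

from __future__ import annotations

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from core.app.file_access import FileAccessScope
from models import ToolFile, UploadFile


class PermissiveFileAccessController:
    """Illustrative stand-in that satisfies FileAccessControllerProtocol
    structurally while applying no extra constraints (e.g. for tests)."""

    def current_scope(self) -> FileAccessScope | None:
        return None

    def apply_upload_file_filters(
        self, stmt: Select[tuple[UploadFile]], *, scope: FileAccessScope | None = None
    ) -> Select[tuple[UploadFile]]:
        return stmt

    def apply_tool_file_filters(
        self, stmt: Select[tuple[ToolFile]], *, scope: FileAccessScope | None = None
    ) -> Select[tuple[ToolFile]]:
        return stmt

    def get_upload_file(
        self, *, session: Session, file_id: str, scope: FileAccessScope | None = None
    ) -> UploadFile | None:
        return session.scalar(select(UploadFile).where(UploadFile.id == file_id))

    def get_tool_file(
        self, *, session: Session, file_id: str, scope: FileAccessScope | None = None
    ) -> ToolFile | None:
        return session.scalar(select(ToolFile).where(ToolFile.id == file_id))
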
40
api/core/app/file_access/scope.py
Normal file
@ -0,0 +1,40 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom

_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
    "current_file_access_scope",
    default=None,
)


@dataclass(frozen=True, slots=True)
class FileAccessScope:
    """Request-scoped ownership context used by workflow-layer file lookups."""

    tenant_id: str
    user_id: str
    user_from: UserFrom
    invoke_from: InvokeFrom

    @property
    def requires_user_ownership(self) -> bool:
        return self.user_from == UserFrom.END_USER


def get_current_file_access_scope() -> FileAccessScope | None:
    return _current_file_access_scope.get()


@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
    token = _current_file_access_scope.set(scope)
    try:
        yield
    finally:
        _current_file_access_scope.reset(token)
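
Note (not part of this commit): a minimal sketch of how a scope is bound and read back, assuming the helpers defined above; the identifiers are hypothetical. The ContextVar keeps the binding isolated per thread and async task, and the token reset guarantees the scope does not leak past the with block.

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import FileAccessScope, bind_file_access_scope, get_current_file_access_scope

# Hypothetical identifiers, for illustration only.
scope = FileAccessScope(
    tenant_id="tenant-123",
    user_id="end-user-456",
    user_from=UserFrom.END_USER,
    invoke_from=InvokeFrom.WEB_APP,
)

assert get_current_file_access_scope() is None

with bind_file_access_scope(scope):
    # Inside the block, controllers that default to get_current_file_access_scope
    # pick up this scope without it being passed through every call.
    current = get_current_file_access_scope()
    assert current is not None and current.requires_user_ownership

# On exit the previous value is restored, so later work sees no scope.
assert get_current_file_access_scope() is None
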
@ -10,6 +10,7 @@ from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
@ -18,14 +19,22 @@ from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from dify_graph.file.runtime import set_workflow_file_runtime
|
||||
from extensions.ext_storage import storage
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.file.models import File
|
||||
|
||||
|
||||
class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
"""Production runtime wiring for ``dify_graph.file``."""
|
||||
"""Production runtime wiring for ``dify_graph.file``.
|
||||
|
||||
When a request-scoped file access scope is present, opaque file references are
|
||||
re-validated against the database before URLs are signed or storage keys are used.
|
||||
"""
|
||||
|
||||
_file_access_controller: FileAccessControllerProtocol
|
||||
|
||||
def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None:
|
||||
self._file_access_controller = file_access_controller
|
||||
|
||||
@property
|
||||
def multimodal_send_format(self) -> str:
|
||||
@ -55,7 +64,16 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
upload_file_id=parsed_reference.record_id,
|
||||
for_external=for_external,
|
||||
)
|
||||
if file.transfer_method in {FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE}:
|
||||
if file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
|
||||
if file.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
self._assert_upload_file_access(upload_file_id=parsed_reference.record_id)
|
||||
return sign_tool_file(
|
||||
tool_file_id=parsed_reference.record_id,
|
||||
extension=file.extension,
|
||||
for_external=for_external,
|
||||
)
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
if file.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return self.resolve_tool_file_url(
|
||||
@ -72,6 +90,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
as_attachment: bool = False,
|
||||
for_external: bool = True,
|
||||
) -> str:
|
||||
self._assert_upload_file_access(upload_file_id=upload_file_id)
|
||||
base_url = self._base_url(for_external=for_external)
|
||||
url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
query = self._sign_query(payload=f"file-preview|{upload_file_id}")
|
||||
@ -80,6 +99,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
return f"{url}?{urllib.parse.urlencode(query)}"
|
||||
|
||||
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
self._assert_tool_file_access(tool_file_id=tool_file_id)
|
||||
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
|
||||
|
||||
def verify_preview_signature(
|
||||
@ -121,7 +141,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
parsed_reference = parse_file_reference(file.reference)
|
||||
if parsed_reference is None:
|
||||
raise ValueError("Missing file reference")
|
||||
if parsed_reference.storage_key:
|
||||
if parsed_reference.storage_key and self._file_access_controller.current_scope() is None:
|
||||
return parsed_reference.storage_key
|
||||
|
||||
record_id = parsed_reference.record_id
|
||||
@ -131,16 +151,34 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
FileTransferMethod.REMOTE_URL,
|
||||
FileTransferMethod.DATASOURCE_FILE,
|
||||
}:
|
||||
upload_file = session.get(UploadFile, record_id)
|
||||
upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id)
|
||||
if upload_file is None:
|
||||
raise ValueError(f"Upload file {record_id} not found")
|
||||
return upload_file.key
|
||||
|
||||
tool_file = session.get(ToolFile, record_id)
|
||||
tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"Tool file {record_id} not found")
|
||||
return tool_file.file_key
|
||||
|
||||
def _assert_upload_file_access(self, *, upload_file_id: str) -> None:
|
||||
if self._file_access_controller.current_scope() is None:
|
||||
return
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id)
|
||||
if upload_file is None:
|
||||
raise ValueError(f"Upload file {upload_file_id} not found")
|
||||
|
||||
def _assert_tool_file_access(self, *, tool_file_id: str) -> None:
|
||||
if self._file_access_controller.current_scope() is None:
|
||||
return
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"Tool file {tool_file_id} not found")
|
||||
|
||||
|
||||
def bind_dify_workflow_file_runtime() -> None:
|
||||
set_workflow_file_runtime(DifyWorkflowFileRuntime())
|
||||
set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()))
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any, cast
|
||||
from sqlalchemy import select
|
||||
|
||||
import contexts
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
@ -37,6 +38,7 @@ from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class DatasourceManager:
|
||||
@ -284,7 +286,11 @@ class DatasourceManager:
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
"url": url,
|
||||
}
|
||||
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
|
||||
file_out = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif mtype == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
|
||||
|
||||
@ -4,6 +4,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from dify_graph.file import file_manager
|
||||
@ -23,6 +24,8 @@ from models.workflow import Workflow
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
def __init__(
|
||||
@ -85,7 +88,10 @@ class TokenBufferMemory:
|
||||
# Build files directly without filtering by belongs_to
|
||||
file_objs = [
|
||||
file_factory.build_from_message_file(
|
||||
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
message_file=message_file,
|
||||
tenant_id=app_record.tenant_id,
|
||||
config=file_extra_config,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
for message_file in message_files
|
||||
]
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
@ -49,6 +50,8 @@ from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
@ -556,6 +559,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
file_obj = build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
file_objects.append(file_obj)
|
||||
except Exception as e:
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -26,6 +27,7 @@ from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
@ -331,6 +333,7 @@ class WorkflowTool(Tool):
|
||||
file = build_from_mapping(
|
||||
mapping=item,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
@ -338,6 +341,7 @@ class WorkflowTool(Tool):
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
@ -82,6 +83,9 @@ if TYPE_CHECKING:
|
||||
from dify_graph.nodes.tool.entities import ToolNodeData
|
||||
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext:
|
||||
if isinstance(run_context, DifyRunContext):
|
||||
return run_context
|
||||
@ -127,6 +131,7 @@ class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
|
||||
return file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self._run_context.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any, cast
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -27,6 +28,8 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class AgentMessageTransformer:
|
||||
def transform(
|
||||
@ -93,6 +96,7 @@ class AgentMessageTransformer:
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
@ -116,6 +120,7 @@ class AgentMessageTransformer:
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
|
||||
@ -7,6 +7,7 @@ from configs import dify_config
|
||||
from context import capture_current_context
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class
|
||||
@ -35,6 +36,7 @@ from factories import file_factory
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
@ -508,13 +510,21 @@ class WorkflowEntry:
|
||||
continue
|
||||
|
||||
if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
|
||||
input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
|
||||
input_value = file_factory.build_from_mapping(
|
||||
mapping=input_value,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
if (
|
||||
isinstance(input_value, list)
|
||||
and all(isinstance(item, dict) for item in input_value)
|
||||
and all("type" in item and "transfer_method" in item for item in input_value)
|
||||
):
|
||||
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
|
||||
input_value = file_factory.build_from_mappings(
|
||||
mappings=input_value,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
|
||||
# append variable and value to variable pool
|
||||
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
|
||||
|
||||
@ -1,582 +0,0 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.http import parse_options_header
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id
|
||||
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
|
||||
from dify_graph.file.file_factory import standardize_file_type
|
||||
from extensions.ext_database import db
|
||||
from models import MessageFile, ToolFile, UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None:
|
||||
for key in (*keys, "reference", "related_id"):
|
||||
raw_value = mapping.get(key)
|
||||
if isinstance(raw_value, str) and raw_value:
|
||||
resolved_value = resolve_file_record_id(raw_value)
|
||||
if resolved_value:
|
||||
return resolved_value
|
||||
return None
|
||||
|
||||
|
||||
def build_from_message_files(
|
||||
*,
|
||||
message_files: Sequence["MessageFile"],
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None = None,
|
||||
) -> Sequence[File]:
|
||||
results = [
|
||||
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
|
||||
for file in message_files
|
||||
if file.belongs_to != FileBelongsTo.ASSISTANT
|
||||
]
|
||||
return results
|
||||
|
||||
|
||||
def build_from_message_file(
|
||||
*,
|
||||
message_file: "MessageFile",
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None,
|
||||
):
|
||||
mapping = {
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"url": message_file.url,
|
||||
"type": message_file.type,
|
||||
}
|
||||
|
||||
# Only include id if it exists (message_file has been committed to DB)
|
||||
if message_file.id:
|
||||
mapping["id"] = message_file.id
|
||||
|
||||
# Set the correct ID field based on transfer method
|
||||
if message_file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
mapping["tool_file_id"] = message_file.upload_file_id
|
||||
else:
|
||||
mapping["upload_file_id"] = message_file.upload_file_id
|
||||
|
||||
return build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def build_from_mapping(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None = None,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if not transfer_method_value:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
|
||||
build_functions: dict[FileTransferMethod, Callable] = {
|
||||
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
|
||||
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
|
||||
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
|
||||
FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
|
||||
}
|
||||
|
||||
build_func = build_functions.get(transfer_method)
|
||||
if not build_func:
|
||||
raise ValueError(f"Invalid file transfer method: {transfer_method}")
|
||||
|
||||
file: File = build_func(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
transfer_method=transfer_method,
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
|
||||
if config and not _is_file_valid_with_config(
|
||||
input_file_type=mapping.get("type", FileType.CUSTOM),
|
||||
file_extension=file.extension or "",
|
||||
file_transfer_method=file.transfer_method,
|
||||
config=config,
|
||||
):
|
||||
raise ValueError(f"File validation failed for file: {file.filename}")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def build_from_mappings(
|
||||
*,
|
||||
mappings: Sequence[Mapping[str, Any]],
|
||||
config: FileUploadConfig | None = None,
|
||||
tenant_id: str,
|
||||
strict_type_validation: bool = False,
|
||||
) -> Sequence[File]:
|
||||
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
|
||||
# Implement batch processing to reduce database load when handling multiple files.
|
||||
# Filter out None/empty mappings to avoid errors
|
||||
def is_valid_mapping(m: Mapping[str, Any]) -> bool:
|
||||
if not m or not m.get("transfer_method"):
|
||||
return False
|
||||
# For REMOTE_URL transfer method, ensure url or remote_url is provided and not None
|
||||
transfer_method = m.get("transfer_method")
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = m.get("url") or m.get("remote_url")
|
||||
if not url:
|
||||
return False
|
||||
return True
|
||||
|
||||
valid_mappings = [m for m in mappings if is_valid_mapping(m)]
|
||||
files = [
|
||||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
for mapping in valid_mappings
|
||||
]
|
||||
|
||||
if (
|
||||
config
|
||||
# If image config is set.
|
||||
and config.image_config
|
||||
# And the number of image files exceeds the maximum limit
|
||||
and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
|
||||
):
|
||||
raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
|
||||
if config and config.number_limits and len(files) > config.number_limits:
|
||||
raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def _build_from_local_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
upload_file_id = _resolve_mapping_file_id(mapping, "upload_file_id")
|
||||
if not upload_file_id:
|
||||
raise ValueError("Invalid upload file id")
|
||||
# check if upload_file_id is a valid uuid
|
||||
try:
|
||||
uuid.UUID(upload_file_id)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid upload file id format")
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
row = db.session.scalar(stmt)
|
||||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
specified_type = mapping.get("type", "custom")
|
||||
|
||||
if strict_type_validation and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension="." + row.extension,
|
||||
mime_type=row.mime_type,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=row.source_url,
|
||||
reference=build_file_reference(record_id=str(row.id), storage_key=row.key),
|
||||
size=row.size,
|
||||
)
|
||||
|
||||
|
||||
def _build_from_remote_url(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
upload_file_id = _resolve_mapping_file_id(mapping, "upload_file_id")
|
||||
if upload_file_id:
|
||||
try:
|
||||
uuid.UUID(upload_file_id)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid upload file id format")
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
upload_file = db.session.scalar(stmt)
|
||||
if upload_file is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = standardize_file_type(
|
||||
extension="." + upload_file.extension, mime_type=upload_file.mime_type
|
||||
)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||
reference=build_file_reference(record_id=str(upload_file.id), storage_key=upload_file.key),
|
||||
size=upload_file.size,
|
||||
)
|
||||
url = mapping.get("url") or mapping.get("remote_url")
|
||||
if not url:
|
||||
raise ValueError("Invalid file url")
|
||||
|
||||
mime_type, filename, file_size = _get_remote_file_info(url)
|
||||
extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
|
||||
|
||||
detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type)
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=filename,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
mime_type=mime_type,
|
||||
extension=extension,
|
||||
size=file_size,
|
||||
)
|
||||
|
||||
|
||||
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
|
||||
filename: str | None = None
|
||||
# Try to extract from Content-Disposition header first
|
||||
if content_disposition:
|
||||
# Manually extract filename* parameter since parse_options_header doesn't support it
|
||||
filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
|
||||
if filename_star_match:
|
||||
raw_star = filename_star_match.group(1).strip()
|
||||
# Remove trailing quotes if present
|
||||
raw_star = raw_star.removesuffix('"')
|
||||
# format: charset'lang'value
|
||||
try:
|
||||
parts = raw_star.split("'", 2)
|
||||
charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8"
|
||||
value = parts[2] if len(parts) == 3 else parts[-1]
|
||||
filename = urllib.parse.unquote(value, encoding=charset, errors="replace")
|
||||
except Exception:
|
||||
# Fallback: try to extract value after the last single quote
|
||||
if "''" in raw_star:
|
||||
filename = urllib.parse.unquote(raw_star.split("''")[-1])
|
||||
else:
|
||||
filename = urllib.parse.unquote(raw_star)
|
||||
|
||||
if not filename:
|
||||
# Fallback to regular filename parameter
|
||||
_, params = parse_options_header(content_disposition)
|
||||
raw = params.get("filename")
|
||||
if raw:
|
||||
# Strip surrounding quotes and percent-decode if present
|
||||
if len(raw) >= 2 and raw[0] == raw[-1] == '"':
|
||||
raw = raw[1:-1]
|
||||
filename = urllib.parse.unquote(raw)
|
||||
# Fallback to URL path if no filename from header
|
||||
if not filename:
|
||||
candidate = os.path.basename(url_path)
|
||||
filename = urllib.parse.unquote(candidate) if candidate else None
|
||||
# Defense-in-depth: ensure basename only
|
||||
if filename:
|
||||
filename = os.path.basename(filename)
|
||||
# Return None if filename is empty or only whitespace
|
||||
if not filename or not filename.strip():
|
||||
filename = None
|
||||
return filename or None
|
||||
|
||||
|
||||
def _guess_mime_type(filename: str) -> str:
|
||||
"""Guess MIME type from filename, returning empty string if None."""
|
||||
guessed_mime, _ = mimetypes.guess_type(filename)
|
||||
return guessed_mime or ""
|
||||
|
||||
|
||||
def _get_remote_file_info(url: str):
|
||||
file_size = -1
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
|
||||
# Initialize mime_type from filename as fallback
|
||||
mime_type = _guess_mime_type(filename)
|
||||
|
||||
resp = ssrf_proxy.head(url, follow_redirects=True)
|
||||
if resp.status_code == httpx.codes.OK:
|
||||
content_disposition = resp.headers.get("Content-Disposition")
|
||||
extracted_filename = _extract_filename(url_path, content_disposition)
|
||||
if extracted_filename:
|
||||
filename = extracted_filename
|
||||
mime_type = _guess_mime_type(filename)
|
||||
file_size = int(resp.headers.get("Content-Length", file_size))
|
||||
# Fallback to Content-Type header if mime_type is still empty
|
||||
if not mime_type:
|
||||
mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
|
||||
if not filename:
|
||||
extension = mimetypes.guess_extension(mime_type) or ".bin"
|
||||
filename = f"{uuid.uuid4().hex}{extension}"
|
||||
if not mime_type:
|
||||
mime_type = _guess_mime_type(filename)
|
||||
|
||||
return mime_type, filename, file_size
|
||||
|
||||
|
||||
def _build_from_tool_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
tool_file_id = _resolve_mapping_file_id(mapping, "tool_file_id")
|
||||
|
||||
if not tool_file_id:
|
||||
raise ValueError(f"ToolFile {tool_file_id} not found")
|
||||
tool_file = db.session.scalar(
|
||||
select(ToolFile).where(
|
||||
ToolFile.id == tool_file_id,
|
||||
ToolFile.tenant_id == tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_file is None:
|
||||
raise ValueError(f"ToolFile {tool_file_id} not found")
|
||||
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
|
||||
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=tool_file.name,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id), storage_key=tool_file.file_key),
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
|
||||
|
||||
def _build_from_datasource_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
datasource_file_id = _resolve_mapping_file_id(mapping, "datasource_file_id")
|
||||
if not datasource_file_id:
|
||||
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
|
||||
datasource_file = db.session.scalar(
|
||||
select(UploadFile).where(
|
||||
UploadFile.id == datasource_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if datasource_file is None:
|
||||
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
||||
|
||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||
|
||||
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("datasource_file_id"),
|
||||
filename=datasource_file.name,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=datasource_file.source_url,
|
||||
reference=build_file_reference(record_id=str(datasource_file.id), storage_key=datasource_file.key),
|
||||
extension=extension,
|
||||
mime_type=datasource_file.mime_type,
|
||||
size=datasource_file.size,
|
||||
url=datasource_file.source_url,
|
||||
)
|
||||
|
||||
|
||||
def _is_file_valid_with_config(
|
||||
*,
|
||||
input_file_type: str,
|
||||
file_extension: str,
|
||||
file_transfer_method: FileTransferMethod,
|
||||
config: FileUploadConfig,
|
||||
) -> bool:
|
||||
# FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
|
||||
# These are internally generated and should bypass user upload restrictions
|
||||
if file_transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
return True
|
||||
|
||||
if (
|
||||
config.allowed_file_types
|
||||
and input_file_type not in config.allowed_file_types
|
||||
and input_file_type != FileType.CUSTOM
|
||||
):
|
||||
return False
|
||||
|
||||
if (
|
||||
input_file_type == FileType.CUSTOM
|
||||
and config.allowed_file_extensions is not None
|
||||
and file_extension not in config.allowed_file_extensions
|
||||
):
|
||||
return False
|
||||
|
||||
if input_file_type == FileType.IMAGE:
|
||||
if (
|
||||
config.image_config
|
||||
and config.image_config.transfer_methods
|
||||
and file_transfer_method not in config.image_config.transfer_methods
|
||||
):
|
||||
return False
|
||||
elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class StorageKeyLoader:
|
||||
"""FileKeyLoader load the storage key from database for a list of files.
|
||||
This loader is batched, the database query count is constant regardless of the input size.
|
||||
"""
|
||||
|
||||
def __init__(self, session: Session, tenant_id: str):
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id.in_(upload_file_ids),
|
||||
UploadFile.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
|
||||
|
||||
def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
|
||||
stmt = select(ToolFile).where(
|
||||
ToolFile.id.in_(tool_file_ids),
|
||||
ToolFile.tenant_id == self._tenant_id,
|
||||
)
|
||||
return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
|
||||
|
||||
def load_storage_keys(self, files: Sequence[File]):
|
||||
"""Loads storage keys for a sequence of files by retrieving the corresponding
|
||||
`UploadFile` or `ToolFile` records from the database based on their transfer method.
|
||||
|
||||
This method doesn't modify the input sequence structure but updates the `_storage_key`
|
||||
property of each file object by extracting the relevant key from its database record.
|
||||
Tenant scoping is enforced by this loader's context rather than by embedding tenant
|
||||
identity inside graph-layer ``File`` values.
|
||||
|
||||
Performance note: This is a batched operation where database query count remains constant
|
||||
regardless of input size. However, for optimal performance, input sequences should contain
|
||||
fewer than 1000 files. For larger collections, split into smaller batches and process each
|
||||
batch separately.
|
||||
"""
|
||||
|
||||
upload_file_ids: list[uuid.UUID] = []
|
||||
tool_file_ids: list[uuid.UUID] = []
|
||||
for file in files:
|
||||
parsed_reference = parse_file_reference(file.reference)
|
||||
if parsed_reference is None:
|
||||
raise ValueError("file id should not be None.")
|
||||
model_id = uuid.UUID(parsed_reference.record_id)
|
||||
|
||||
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
|
||||
upload_file_ids.append(model_id)
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_ids.append(model_id)
|
||||
|
||||
tool_files = self._load_tool_files(tool_file_ids)
|
||||
upload_files = self._load_upload_files(upload_file_ids)
|
||||
for file in files:
|
||||
parsed_reference = parse_file_reference(file.reference)
|
||||
if parsed_reference is None:
|
||||
raise ValueError("file id should not be None.")
|
||||
model_id = uuid.UUID(parsed_reference.record_id)
|
||||
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
|
||||
upload_file_row = upload_files.get(model_id)
|
||||
if upload_file_row is None:
|
||||
raise ValueError(f"Upload file not found for id: {model_id}")
|
||||
file.reference = build_file_reference(
|
||||
record_id=str(upload_file_row.id),
|
||||
storage_key=upload_file_row.key,
|
||||
)
|
||||
file.storage_key = upload_file_row.key
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_row = tool_files.get(model_id)
|
||||
if tool_file_row is None:
|
||||
raise ValueError(f"Tool file not found for id: {model_id}")
|
||||
file.reference = build_file_reference(
|
||||
record_id=str(tool_file_row.id),
|
||||
storage_key=tool_file_row.file_key,
|
||||
)
|
||||
file.storage_key = tool_file_row.file_key
|
||||
31
api/factories/file_factory/__init__.py
Normal file
@ -0,0 +1,31 @@
"""Workflow file factory package.

This package normalizes workflow-layer file payloads into graph-layer ``File``
values. It keeps tenancy and ownership checks in the application layer and
preserves the historical ``factories.file_factory`` import surface for callers.
"""

from core.helper import ssrf_proxy
from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig
from extensions.ext_database import db

from .builders import build_from_mapping, build_from_mappings
from .message_files import build_from_message_file, build_from_message_files
from .remote import _extract_filename, _get_remote_file_info
from .storage_keys import StorageKeyLoader

__all__ = [
    "File",
    "FileTransferMethod",
    "FileType",
    "FileUploadConfig",
    "StorageKeyLoader",
    "_extract_filename",
    "_get_remote_file_info",
    "build_from_mapping",
    "build_from_mappings",
    "build_from_message_file",
    "build_from_message_files",
    "db",
    "ssrf_proxy",
]
325
api/factories/file_factory/builders.py
Normal file
@ -0,0 +1,325 @@
|
||||
"""Core builders for workflow file mappings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import mimetypes
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.file_access import FileAccessControllerProtocol
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers
|
||||
from dify_graph.file.file_factory import standardize_file_type
|
||||
from extensions.ext_database import db
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
from .common import resolve_mapping_file_id
|
||||
from .remote import _get_remote_file_info
|
||||
from .validation import is_file_valid_with_config
|
||||
|
||||
|
||||
def build_from_mapping(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None = None,
|
||||
strict_type_validation: bool = False,
|
||||
access_controller: FileAccessControllerProtocol,
|
||||
) -> File:
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if not transfer_method_value:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
build_func = _get_build_function(transfer_method)
|
||||
file = build_func(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
transfer_method=transfer_method,
|
||||
strict_type_validation=strict_type_validation,
|
||||
access_controller=access_controller,
|
||||
)
|
||||
|
||||
if config and not is_file_valid_with_config(
|
||||
input_file_type=mapping.get("type", FileType.CUSTOM),
|
||||
file_extension=file.extension or "",
|
||||
file_transfer_method=file.transfer_method,
|
||||
config=config,
|
||||
):
|
||||
raise ValueError(f"File validation failed for file: {file.filename}")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def build_from_mappings(
|
||||
*,
|
||||
mappings: Sequence[Mapping[str, Any]],
|
||||
config: FileUploadConfig | None = None,
|
||||
tenant_id: str,
|
||||
strict_type_validation: bool = False,
|
||||
access_controller: FileAccessControllerProtocol,
|
||||
) -> Sequence[File]:
|
||||
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
|
||||
# Implement batch processing to reduce database load when handling multiple files.
|
||||
valid_mappings = [mapping for mapping in mappings if _is_valid_mapping(mapping)]
|
||||
files = [
|
||||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
strict_type_validation=strict_type_validation,
|
||||
access_controller=access_controller,
|
||||
)
|
||||
for mapping in valid_mappings
|
||||
]
|
||||
|
||||
if (
|
||||
config
|
||||
and config.image_config
|
||||
and sum(1 for file in files if file.type == FileType.IMAGE) > config.image_config.number_limits
|
||||
):
|
||||
raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
|
||||
if config and config.number_limits and len(files) > config.number_limits:
|
||||
raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
|
||||
|
||||
return files
|
||||
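
Note (not part of this commit): a minimal call-site sketch for the new access_controller parameter, assuming the controller added in core/app/file_access. The mapping values and tenant id are hypothetical, and the lookup would fail unless a matching UploadFile row exists.

from core.app.file_access import DatabaseFileAccessController
from factories.file_factory import build_from_mapping

_file_access_controller = DatabaseFileAccessController()

# Hypothetical mapping and tenant id, for illustration only.
mapping = {
    "transfer_method": "local_file",
    "upload_file_id": "11111111-1111-1111-1111-111111111111",
    "type": "document",
}

file = build_from_mapping(
    mapping=mapping,
    tenant_id="tenant-123",
    access_controller=_file_access_controller,
)
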
|
||||
|
||||
def _get_build_function(transfer_method: FileTransferMethod):
    build_functions = {
        FileTransferMethod.LOCAL_FILE: _build_from_local_file,
        FileTransferMethod.REMOTE_URL: _build_from_remote_url,
        FileTransferMethod.TOOL_FILE: _build_from_tool_file,
        FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
    }
    build_func = build_functions.get(transfer_method)
    if build_func is None:
        raise ValueError(f"Invalid file transfer method: {transfer_method}")
    return build_func


def _resolve_file_type(
    *,
    detected_file_type: FileType,
    specified_type: str | None,
    strict_type_validation: bool,
) -> FileType:
    if strict_type_validation and specified_type and detected_file_type.value != specified_type:
        raise ValueError("Detected file type does not match the specified type. Please verify the file.")

    if specified_type and specified_type != "custom":
        return FileType(specified_type)
    return detected_file_type


def _build_from_local_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id")
    if not upload_file_id:
        raise ValueError("Invalid upload file id")

    try:
        uuid.UUID(upload_file_id)
    except ValueError as exc:
        raise ValueError("Invalid upload file id format") from exc

    stmt = select(UploadFile).where(
        UploadFile.id == upload_file_id,
        UploadFile.tenant_id == tenant_id,
    )
    row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
    if row is None:
        raise ValueError("Invalid upload file")

    detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
    file_type = _resolve_file_type(
        detected_file_type=detected_file_type,
        specified_type=mapping.get("type", "custom"),
        strict_type_validation=strict_type_validation,
    )

    return File(
        id=mapping.get("id"),
        filename=row.name,
        extension="." + row.extension,
        mime_type=row.mime_type,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=row.source_url,
        reference=build_file_reference(record_id=str(row.id), storage_key=row.key),
        size=row.size,
    )


def _build_from_remote_url(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id")
    if upload_file_id:
        try:
            uuid.UUID(upload_file_id)
        except ValueError as exc:
            raise ValueError("Invalid upload file id format") from exc

        stmt = select(UploadFile).where(
            UploadFile.id == upload_file_id,
            UploadFile.tenant_id == tenant_id,
        )
        upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
        if upload_file is None:
            raise ValueError("Invalid upload file")

        detected_file_type = standardize_file_type(
            extension="." + upload_file.extension,
            mime_type=upload_file.mime_type,
        )
        file_type = _resolve_file_type(
            detected_file_type=detected_file_type,
            specified_type=mapping.get("type"),
            strict_type_validation=strict_type_validation,
        )

        return File(
            id=mapping.get("id"),
            filename=upload_file.name,
            extension="." + upload_file.extension,
            mime_type=upload_file.mime_type,
            type=file_type,
            transfer_method=transfer_method,
            remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
            reference=build_file_reference(record_id=str(upload_file.id), storage_key=upload_file.key),
            size=upload_file.size,
        )

    url = mapping.get("url") or mapping.get("remote_url")
    if not url:
        raise ValueError("Invalid file url")

    mime_type, filename, file_size = _get_remote_file_info(url)
    extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
    detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type)
    file_type = _resolve_file_type(
        detected_file_type=detected_file_type,
        specified_type=mapping.get("type"),
        strict_type_validation=strict_type_validation,
    )

    return File(
        id=mapping.get("id"),
        filename=filename,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=url,
        mime_type=mime_type,
        extension=extension,
        size=file_size,
    )


def _build_from_tool_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    tool_file_id = resolve_mapping_file_id(mapping, "tool_file_id")
    if not tool_file_id:
        raise ValueError(f"ToolFile {tool_file_id} not found")

    stmt = select(ToolFile).where(
        ToolFile.id == tool_file_id,
        ToolFile.tenant_id == tenant_id,
    )
    tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
    if tool_file is None:
        raise ValueError(f"ToolFile {tool_file_id} not found")

    extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
    detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
    file_type = _resolve_file_type(
        detected_file_type=detected_file_type,
        specified_type=mapping.get("type"),
        strict_type_validation=strict_type_validation,
    )

    return File(
        id=mapping.get("id"),
        filename=tool_file.name,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=tool_file.original_url,
        reference=build_file_reference(record_id=str(tool_file.id), storage_key=tool_file.file_key),
        extension=extension,
        mime_type=tool_file.mimetype,
        size=tool_file.size,
    )
def _build_from_datasource_file(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
    strict_type_validation: bool = False,
    access_controller: FileAccessControllerProtocol,
) -> File:
    datasource_file_id = resolve_mapping_file_id(mapping, "datasource_file_id")
    if not datasource_file_id:
        raise ValueError(f"DatasourceFile {datasource_file_id} not found")

    stmt = select(UploadFile).where(
        UploadFile.id == datasource_file_id,
        UploadFile.tenant_id == tenant_id,
    )
    datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
    if datasource_file is None:
        raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")

    extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
    detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
    file_type = _resolve_file_type(
        detected_file_type=detected_file_type,
        specified_type=mapping.get("type"),
        strict_type_validation=strict_type_validation,
    )

    return File(
        id=mapping.get("datasource_file_id"),
        filename=datasource_file.name,
        type=file_type,
        transfer_method=FileTransferMethod.TOOL_FILE,
        remote_url=datasource_file.source_url,
        reference=build_file_reference(record_id=str(datasource_file.id), storage_key=datasource_file.key),
        extension=extension,
        mime_type=datasource_file.mime_type,
        size=datasource_file.size,
        url=datasource_file.source_url,
    )
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:
    if not mapping or not mapping.get("transfer_method"):
        return False

    if mapping.get("transfer_method") == FileTransferMethod.REMOTE_URL:
        url = mapping.get("url") or mapping.get("remote_url")
        if not url:
            return False

    return True
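For orientation, a hedged sketch of calling the updated entry point: the mapping keys mirror what the builders above read, while the IDs, tenant, and config are placeholders rather than values from this change. The import path mirrors the tests later in this diff; adjust it if the package layout differs.

# Hypothetical caller; IDs and values are placeholders.
from core.app.file_access import DatabaseFileAccessController
from factories.file_factory import build_from_mappings

files = build_from_mappings(
    mappings=[
        {
            "transfer_method": "local_file",
            "type": "image",
            "upload_file_id": "11111111-1111-1111-1111-111111111111",
        }
    ],
    tenant_id="tenant-id",
    access_controller=DatabaseFileAccessController(),
)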
27  api/factories/file_factory/common.py  Normal file
@ -0,0 +1,27 @@
"""Shared helpers for workflow file factory modules."""

from __future__ import annotations

from collections.abc import Mapping
from typing import Any

from core.workflow.file_reference import resolve_file_record_id


def resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None:
    """Resolve historical file identifiers from persisted mapping payloads.

    Workflow and model payloads can outlive file schema changes. Older rows may
    still carry concrete identifiers in legacy fields such as ``related_id``,
    while newer payloads use opaque references. Keep this compatibility lookup in
    the factory layer so historical data remains readable without reintroducing
    storage details into graph-layer ``File`` values.
    """

    for key in (*keys, "reference", "related_id"):
        raw_value = mapping.get(key)
        if isinstance(raw_value, str) and raw_value:
            resolved_value = resolve_file_record_id(raw_value)
            if resolved_value:
                return resolved_value
    return None
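A brief illustration of the fallback order described in the docstring above. The identifiers are placeholders, and the pass-through behavior of resolve_file_record_id for raw UUID strings is an assumption about core.workflow.file_reference.

# Placeholders; both mappings are assumed to resolve to the same record id.
legacy = {"transfer_method": "tool_file", "related_id": "0b6c5c36-8a32-4b4a-9e3f-1d2e3f4a5b6c"}
modern = {"transfer_method": "tool_file", "tool_file_id": "0b6c5c36-8a32-4b4a-9e3f-1d2e3f4a5b6c"}

resolve_mapping_file_id(modern, "tool_file_id")  # explicit key is tried first
resolve_mapping_file_id(legacy, "tool_file_id")  # falls back to the legacy "related_id" field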
59  api/factories/file_factory/message_files.py  Normal file
@ -0,0 +1,59 @@
"""Adapters from persisted message files to graph-layer file values."""

from __future__ import annotations

from collections.abc import Sequence

from core.app.file_access import FileAccessControllerProtocol
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig
from models import MessageFile

from .builders import build_from_mapping


def build_from_message_files(
    *,
    message_files: Sequence[MessageFile],
    tenant_id: str,
    config: FileUploadConfig | None = None,
    access_controller: FileAccessControllerProtocol,
) -> Sequence[File]:
    return [
        build_from_message_file(
            message_file=message_file,
            tenant_id=tenant_id,
            config=config,
            access_controller=access_controller,
        )
        for message_file in message_files
        if message_file.belongs_to != FileBelongsTo.ASSISTANT
    ]


def build_from_message_file(
    *,
    message_file: MessageFile,
    tenant_id: str,
    config: FileUploadConfig | None,
    access_controller: FileAccessControllerProtocol,
) -> File:
    mapping = {
        "transfer_method": message_file.transfer_method,
        "url": message_file.url,
        "type": message_file.type,
    }

    if message_file.id:
        mapping["id"] = message_file.id

    if message_file.transfer_method == FileTransferMethod.TOOL_FILE:
        mapping["tool_file_id"] = message_file.upload_file_id
    else:
        mapping["upload_file_id"] = message_file.upload_file_id

    return build_from_mapping(
        mapping=mapping,
        tenant_id=tenant_id,
        config=config,
        access_controller=access_controller,
    )
84  api/factories/file_factory/remote.py  Normal file
@ -0,0 +1,84 @@
"""Remote file metadata helpers used by workflow file normalization."""

from __future__ import annotations

import mimetypes
import os
import re
import urllib.parse
import uuid

import httpx
from werkzeug.http import parse_options_header

from core.helper import ssrf_proxy


def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
    filename: str | None = None
    if content_disposition:
        filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
        if filename_star_match:
            raw_star = filename_star_match.group(1).strip()
            raw_star = raw_star.removesuffix('"')
            try:
                parts = raw_star.split("'", 2)
                charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8"
                value = parts[2] if len(parts) == 3 else parts[-1]
                filename = urllib.parse.unquote(value, encoding=charset, errors="replace")
            except Exception:
                if "''" in raw_star:
                    filename = urllib.parse.unquote(raw_star.split("''")[-1])
                else:
                    filename = urllib.parse.unquote(raw_star)

        if not filename:
            _, params = parse_options_header(content_disposition)
            raw = params.get("filename")
            if raw:
                if len(raw) >= 2 and raw[0] == raw[-1] == '"':
                    raw = raw[1:-1]
                filename = urllib.parse.unquote(raw)

    if not filename:
        candidate = os.path.basename(url_path)
        filename = urllib.parse.unquote(candidate) if candidate else None

    if filename:
        filename = os.path.basename(filename)
        if not filename or not filename.strip():
            filename = None

    return filename or None


def _guess_mime_type(filename: str) -> str:
    guessed_mime, _ = mimetypes.guess_type(filename)
    return guessed_mime or ""


def _get_remote_file_info(url: str) -> tuple[str, str, int]:
    file_size = -1
    parsed_url = urllib.parse.urlparse(url)
    url_path = parsed_url.path
    filename = os.path.basename(url_path)
    mime_type = _guess_mime_type(filename)

    resp = ssrf_proxy.head(url, follow_redirects=True)
    if resp.status_code == httpx.codes.OK:
        content_disposition = resp.headers.get("Content-Disposition")
        extracted_filename = _extract_filename(url_path, content_disposition)
        if extracted_filename:
            filename = extracted_filename
            mime_type = _guess_mime_type(filename)
        file_size = int(resp.headers.get("Content-Length", file_size))
        if not mime_type:
            mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()

    if not filename:
        extension = mimetypes.guess_extension(mime_type) or ".bin"
        filename = f"{uuid.uuid4().hex}{extension}"
    if not mime_type:
        mime_type = _guess_mime_type(filename)

    return mime_type, filename, file_size
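For reference, two pure calls that exercise _extract_filename as written above: one RFC 5987 "filename*" header and one URL-path fallback. The example values are not from the change and need no network access.

# Exercises the filename* branch and the URL-path fallback shown above.
assert _extract_filename(
    "/download", "attachment; filename*=UTF-8''%E2%82%AC%20rates.pdf"
) == "€ rates.pdf"
assert _extract_filename("/files/report%20v2.pdf", None) == "report v2.pdf"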
107  api/factories/file_factory/storage_keys.py  Normal file
@ -0,0 +1,107 @@
"""Batched storage-key hydration for workflow files."""

from __future__ import annotations

import uuid
from collections.abc import Mapping, Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.file_access import FileAccessControllerProtocol
from core.workflow.file_reference import build_file_reference, parse_file_reference
from dify_graph.file import File, FileTransferMethod
from models import ToolFile, UploadFile


class StorageKeyLoader:
    """Load storage keys for files with a constant number of database queries."""

    _session: Session
    _tenant_id: str
    _access_controller: FileAccessControllerProtocol

    def __init__(
        self,
        session: Session,
        tenant_id: str,
        access_controller: FileAccessControllerProtocol,
    ) -> None:
        self._session = session
        self._tenant_id = tenant_id
        self._access_controller = access_controller

    def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
        stmt = select(UploadFile).where(
            UploadFile.id.in_(upload_file_ids),
            UploadFile.tenant_id == self._tenant_id,
        )
        scoped_stmt = self._access_controller.apply_upload_file_filters(stmt)
        return {uuid.UUID(upload_file.id): upload_file for upload_file in self._session.scalars(scoped_stmt)}

    def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
        stmt = select(ToolFile).where(
            ToolFile.id.in_(tool_file_ids),
            ToolFile.tenant_id == self._tenant_id,
        )
        scoped_stmt = self._access_controller.apply_tool_file_filters(stmt)
        return {uuid.UUID(tool_file.id): tool_file for tool_file in self._session.scalars(scoped_stmt)}

    def load_storage_keys(self, files: Sequence[File]) -> None:
        """Hydrate storage keys by loading their backing file rows in batches.

        The sequence shape is preserved. Each file is updated in place with a
        canonical reference and storage key loaded from an authorized database
        row. Tenant scoping is enforced by this loader's context rather than by
        embedding tenant identity inside graph-layer ``File`` values.

        For best performance, prefer batches smaller than 1000 files.
        """

        upload_file_ids: list[uuid.UUID] = []
        tool_file_ids: list[uuid.UUID] = []
        for file in files:
            parsed_reference = parse_file_reference(file.reference)
            if parsed_reference is None:
                raise ValueError("file id should not be None.")

            model_id = uuid.UUID(parsed_reference.record_id)
            if file.transfer_method in (
                FileTransferMethod.LOCAL_FILE,
                FileTransferMethod.REMOTE_URL,
                FileTransferMethod.DATASOURCE_FILE,
            ):
                upload_file_ids.append(model_id)
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_file_ids.append(model_id)

        tool_files = self._load_tool_files(tool_file_ids)
        upload_files = self._load_upload_files(upload_file_ids)
        for file in files:
            parsed_reference = parse_file_reference(file.reference)
            if parsed_reference is None:
                raise ValueError("file id should not be None.")

            model_id = uuid.UUID(parsed_reference.record_id)
            if file.transfer_method in (
                FileTransferMethod.LOCAL_FILE,
                FileTransferMethod.REMOTE_URL,
                FileTransferMethod.DATASOURCE_FILE,
            ):
                upload_file_row = upload_files.get(model_id)
                if upload_file_row is None:
                    raise ValueError(f"Upload file not found for id: {model_id}")
                file.reference = build_file_reference(
                    record_id=str(upload_file_row.id),
                    storage_key=upload_file_row.key,
                )
                file.storage_key = upload_file_row.key
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_file_row = tool_files.get(model_id)
                if tool_file_row is None:
                    raise ValueError(f"Tool file not found for id: {model_id}")
                file.reference = build_file_reference(
                    record_id=str(tool_file_row.id),
                    storage_key=tool_file_row.file_key,
                )
                file.storage_key = tool_file_row.file_key
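A hedged usage sketch of the batched loader: the tenant id and the `files` sequence are placeholders, and the wiring mirrors how DraftVarLoader constructs the loader later in this change.

# Sketch; `files` would be File values rebuilt by the factory above.
from sqlalchemy.orm import Session

from core.app.file_access import DatabaseFileAccessController
from extensions.ext_database import db

with Session(bind=db.engine) as session:
    loader = StorageKeyLoader(
        session,
        tenant_id="tenant-id",  # placeholder
        access_controller=DatabaseFileAccessController(),
    )
    loader.load_storage_keys(files)  # one query per table, not one per file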
44  api/factories/file_factory/validation.py  Normal file
@ -0,0 +1,44 @@
"""Validation helpers for workflow file inputs."""

from __future__ import annotations

from dify_graph.file import FileTransferMethod, FileType, FileUploadConfig


def is_file_valid_with_config(
    *,
    input_file_type: str,
    file_extension: str,
    file_transfer_method: FileTransferMethod,
    config: FileUploadConfig,
) -> bool:
    # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
    # These are internally generated and should bypass user upload restrictions
    if file_transfer_method == FileTransferMethod.TOOL_FILE:
        return True

    if (
        config.allowed_file_types
        and input_file_type not in config.allowed_file_types
        and input_file_type != FileType.CUSTOM
    ):
        return False

    if (
        input_file_type == FileType.CUSTOM
        and config.allowed_file_extensions is not None
        and file_extension not in config.allowed_file_extensions
    ):
        return False

    if input_file_type == FileType.IMAGE:
        if (
            config.image_config
            and config.image_config.transfer_methods
            and file_transfer_method not in config.image_config.transfer_methods
        ):
            return False
    elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
        return False

    return True
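To make the rules above concrete, a small hedged example. The FileUploadConfig construction is an assumption about how that model is built; the field name is the one the helper reads, and unset fields are assumed to default to empty.

# Assumed FileUploadConfig construction; placeholder values.
config = FileUploadConfig(allowed_file_types=[FileType.IMAGE])

is_file_valid_with_config(
    input_file_type=FileType.VIDEO,
    file_extension=".mp4",
    file_transfer_method=FileTransferMethod.LOCAL_FILE,
    config=config,
)  # False: "video" is neither allowed nor "custom"

is_file_valid_with_config(
    input_file_type=FileType.VIDEO,
    file_extension=".mp4",
    file_transfer_method=FileTransferMethod.TOOL_FILE,
    config=config,
)  # True: tool-generated files bypass the upload restrictions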
@ -7,6 +7,7 @@ from collections.abc import Callable, Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from uuid import uuid4

@ -53,6 +54,13 @@ if TYPE_CHECKING:
# --- TypedDict definitions for structured dict return types ---


@lru_cache(maxsize=1)
def _get_file_access_controller():
from core.app.file_access import DatabaseFileAccessController

return DatabaseFileAccessController()


def _resolve_app_tenant_id(app_id: str) -> str:
resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id))
if not resolved_tenant_id:
@ -1618,6 +1626,7 @@ class Message(Base):
"upload_file_id": message_file.upload_file_id,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL:
if message_file.url is None:
@ -1631,6 +1640,7 @@ class Message(Base):
"url": message_file.url,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE:
if message_file.upload_file_id is None:
@ -1645,6 +1655,7 @@ class Message(Base):
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
else:
raise ValueError(
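The @lru_cache(maxsize=1) helper above behaves as a lazily constructed, process-wide singleton: the first call imports and builds the controller, and every later call returns the same cached instance, with the import deferred until first use (presumably to avoid an import cycle at module load). A hand-rolled equivalent, shown only to make that behavior explicit; it is not part of the change:

# Equivalent effect, written out by hand for illustration.
_cached = None


def _get_file_access_controller_unrolled():
    global _cached
    if _cached is None:
        # Lazy import: only resolved on first use.
        from core.app.file_access import DatabaseFileAccessController

        _cached = DatabaseFileAccessController()
    return _cached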
@ -1,12 +1,20 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from functools import lru_cache
from typing import Any

from core.workflow.file_reference import parse_file_reference
from dify_graph.file import File, FileTransferMethod


@lru_cache(maxsize=1)
def _get_file_access_controller():
    from core.app.file_access import DatabaseFileAccessController

    return DatabaseFileAccessController()


def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None:
reference = file_mapping.get("reference")
if isinstance(reference, str) and reference:
@ -61,4 +69,8 @@ def build_file_from_input_mapping(
mapping["upload_file_id"] = record_id

tenant_id = resolve_file_mapping_tenant_id(file_mapping=mapping, tenant_resolver=tenant_resolver)
return file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
return file_factory.build_from_mapping(
    mapping=mapping,
    tenant_id=tenant_id,
    access_controller=_get_file_access_controller(),
)
@ -15,6 +15,7 @@ from werkzeug.exceptions import RequestEntityTooLarge

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.tools.tool_file_manager import ToolFileManager
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from core.workflow.nodes.trigger_webhook.entities import (
@ -46,6 +47,7 @@ except ImportError:
magic = None # type: ignore[assignment]

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()


class WebhookService:
@ -422,6 +424,7 @@ class WebhookService:
return file_factory.build_from_mapping(
    mapping=mapping,
    tenant_id=webhook_trigger.tenant_id,
    access_controller=_file_access_controller,
)

@classmethod
@ -14,6 +14,7 @@ from sqlalchemy.sql.expression import and_, or_

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.trigger.constants import is_trigger_node_type
from core.workflow.system_variables import SystemVariableKey
from core.workflow.variable_prefixes import (
@ -126,7 +127,11 @@ class DraftVarLoader(VariableLoader):
elif isinstance(value, ArrayFileSegment):
    files.extend(value.value)
with Session(bind=self._engine) as session:
    storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
    storage_key_loader = StorageKeyLoader(
        session,
        tenant_id=self._tenant_id,
        access_controller=DatabaseFileAccessController(),
    )
    storage_key_loader.load_storage_keys(files)

offloaded_draft_vars = []
@ -12,6 +12,7 @@ from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.file_access import DatabaseFileAccessController
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
@ -84,6 +85,8 @@ from .human_input_delivery_test_service import (
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
from .workflow_restore import apply_published_workflow_snapshot_to_draft

_file_access_controller = DatabaseFileAccessController()


class WorkflowService:
"""
@ -1583,7 +1586,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
if variable_entity_type == VariableEntityType.FILE:
    if not isinstance(value, dict):
        raise ValueError(f"expected dict for file object, got {type(value)}")
    return build_from_mapping(mapping=value, tenant_id=tenant_id)
    return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller)
elif variable_entity_type == VariableEntityType.FILE_LIST:
    if not isinstance(value, list):
        raise ValueError(f"expected list for file list object, got {type(value)}")
@ -1591,6 +1594,6 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
        return []
    if not isinstance(value[0], dict):
        raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
    return build_from_mappings(mappings=value, tenant_id=tenant_id)
    return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller)
else:
    raise Exception("unreachable")
@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from core.app.file_access import DatabaseFileAccessController
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase):
self.test_tool_files = []

# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)
self.loader = StorageKeyLoader(
    self.session,
    self.tenant_id,
    access_controller=DatabaseFileAccessController(),
)

def tearDown(self):
"""Clean up test data after each test method."""
@ -361,6 +366,10 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create loader with different session (same underlying connection)

with Session(bind=db.engine) as other_session:
    other_loader = StorageKeyLoader(other_session, self.tenant_id)
    other_loader = StorageKeyLoader(
        other_session,
        self.tenant_id,
        access_controller=DatabaseFileAccessController(),
    )
    with pytest.raises(ValueError):
        other_loader.load_storage_keys([file])

@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from core.app.file_access import DatabaseFileAccessController
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase):
self.test_tool_files = []

# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)
self.loader = StorageKeyLoader(
    self.session,
    self.tenant_id,
    access_controller=DatabaseFileAccessController(),
)

def tearDown(self):
"""Clean up test data after each test method."""
@ -362,6 +367,10 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create loader with different session (same underlying connection)

with Session(bind=db.engine) as other_session:
    other_loader = StorageKeyLoader(other_session, self.tenant_id)
    other_loader = StorageKeyLoader(
        other_session,
        self.tenant_id,
        access_controller=DatabaseFileAccessController(),
    )
    with pytest.raises(ValueError):
        other_loader.load_storage_keys([file])
@ -401,11 +401,11 @@ class TestBaseAppGeneratorExtras:

monkeypatch.setattr(
    "core.app.apps.base_app_generator.file_factory.build_from_mapping",
    lambda mapping, tenant_id, config, strict_type_validation=False: "file-object",
    lambda mapping, tenant_id, config, strict_type_validation=False, access_controller=None: "file-object",
)
monkeypatch.setattr(
    "core.app.apps.base_app_generator.file_factory.build_from_mappings",
    lambda mappings, tenant_id, config: ["file-1", "file-2"],
    lambda mappings, tenant_id, config, access_controller=None: ["file-1", "file-2"],
)

user_inputs = {
@ -9,10 +9,13 @@ from urllib.parse import parse_qs, urlparse

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope
from core.app.workflow import file_runtime
from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime
from core.workflow.file_reference import build_file_reference
from dify_graph.file import File, FileTransferMethod, FileType
from models import ToolFile, UploadFile


def _build_file(
@ -35,8 +38,12 @@ def _build_file(
)


def _build_runtime() -> DifyWorkflowFileRuntime:
    return DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())


def test_resolve_file_url_returns_remote_url() -> None:
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    file = _build_file(
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/diagram.png",
@ -46,7 +53,7 @@ def test_resolve_file_url_returns_remote_url() -> None:


def test_resolve_file_url_requires_file_reference() -> None:
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    file = SimpleNamespace(transfer_method=FileTransferMethod.LOCAL_FILE, reference=None)

    with pytest.raises(ValueError, match="Missing file reference"):
@ -54,7 +61,7 @@ def test_resolve_file_url_requires_file_reference() -> None:


def test_resolve_file_url_requires_extension_for_tool_files() -> None:
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    file = _build_file(
        transfer_method=FileTransferMethod.TOOL_FILE,
        reference=build_file_reference(record_id="tool-file-id"),
@ -70,7 +77,7 @@ def test_resolve_file_url_uses_tool_signatures_for_tool_and_datasource_files(
) -> None:
    sign_tool_file = MagicMock(return_value="https://signed.example.com/file")
    monkeypatch.setattr(file_runtime, "sign_tool_file", sign_tool_file)
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()

    tool_file = _build_file(
        transfer_method=FileTransferMethod.TOOL_FILE,
@ -100,7 +107,7 @@ def test_resolve_upload_file_url_signs_internal_urls_and_supports_attachments(
        "https://internal.example.com",
    )

    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    url = runtime.resolve_upload_file_url(
        upload_file_id="upload-file-id",
        as_attachment=True,
@ -119,7 +126,7 @@ def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch
    monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000)
    monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret")
    monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 60)
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    payload = "file-preview|upload-file-id|1700000000|nonce"
    sign = base64.urlsafe_b64encode(hmac.new(b"unit-secret", payload.encode(), hashlib.sha256).digest()).decode()

@ -158,7 +165,7 @@ def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch


def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest.MonkeyPatch) -> None:
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    file = _build_file(
        transfer_method=FileTransferMethod.LOCAL_FILE,
        reference=build_file_reference(record_id="upload-file-id", storage_key="storage-key"),
@ -173,7 +180,7 @@ def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest


def test_resolve_storage_key_prefers_encoded_reference() -> None:
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    file = _build_file(
        transfer_method=FileTransferMethod.LOCAL_FILE,
        reference=build_file_reference(record_id="upload-file-id", storage_key="storage-key"),
@ -182,6 +189,60 @@ def test_resolve_storage_key_prefers_encoded_reference() -> None:
    assert runtime._resolve_storage_key(file=file) == "storage-key"


def test_resolve_storage_key_uses_canonical_record_when_scope_is_bound(monkeypatch: pytest.MonkeyPatch) -> None:
    controller = MagicMock()
    controller.current_scope.return_value = FileAccessScope(
        tenant_id="tenant-id",
        user_id="end-user-id",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )
    controller.get_upload_file.return_value = SimpleNamespace(key="canonical-storage-key")
    runtime = DifyWorkflowFileRuntime(file_access_controller=controller)
    file = _build_file(
        transfer_method=FileTransferMethod.LOCAL_FILE,
        reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"),
    )
    session = MagicMock()

    class _SessionContext:
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())

    assert runtime._resolve_storage_key(file=file) == "canonical-storage-key"
    controller.get_upload_file.assert_called_once_with(session=session, file_id="upload-file-id")


def test_resolve_upload_file_url_rejects_unauthorized_scoped_access(monkeypatch: pytest.MonkeyPatch) -> None:
    controller = MagicMock()
    controller.current_scope.return_value = FileAccessScope(
        tenant_id="tenant-id",
        user_id="end-user-id",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )
    controller.get_upload_file.return_value = None
    runtime = DifyWorkflowFileRuntime(file_access_controller=controller)
    session = MagicMock()

    class _SessionContext:
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())

    with pytest.raises(ValueError, match="Upload file upload-file-id not found"):
        runtime.resolve_upload_file_url(upload_file_id="upload-file-id")


@pytest.mark.parametrize(
    ("transfer_method", "record_id", "expected_storage_key"),
    [
@ -196,7 +257,7 @@ def test_resolve_storage_key_loads_database_records(
    record_id: str,
    expected_storage_key: str,
) -> None:
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    file = _build_file(
        transfer_method=transfer_method,
        reference=build_file_reference(record_id=record_id),
@ -206,9 +267,9 @@

    def get(model_class, value):
        if transfer_method in {FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE}:
            assert model_class is file_runtime.UploadFile
            assert model_class is UploadFile
            return SimpleNamespace(key="upload-storage-key")
        assert model_class is file_runtime.ToolFile
        assert model_class is ToolFile
        return SimpleNamespace(file_key="tool-storage-key")

    session.get.side_effect = get
@ -237,7 +298,7 @@ def test_resolve_storage_key_raises_when_records_are_missing(
    transfer_method: FileTransferMethod,
    expected_message: str,
) -> None:
    runtime = DifyWorkflowFileRuntime()
    runtime = _build_runtime()
    record_id = "upload-file-id" if transfer_method == FileTransferMethod.LOCAL_FILE else "tool-file-id"
    file = _build_file(
        transfer_method=transfer_method,
@ -115,6 +115,7 @@ def test_dify_file_reference_factory_passes_tenant_id(monkeypatch: pytest.Monkey
build_from_mapping.assert_called_once_with(
    mapping={"id": "upload-file"},
    tenant_id="tenant-id",
    access_controller=node_runtime._file_access_controller,
)
@ -4,13 +4,17 @@ from unittest.mock import MagicMock, patch
import pytest
from httpx import Response

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
from core.workflow.file_reference import build_file_reference, resolve_file_record_id
from factories.file_factory import (
    File,
    FileTransferMethod,
    FileType,
    FileUploadConfig,
    build_from_mapping,
)
from factories.file_factory import (
    build_from_mapping as _build_from_mapping,
)
from models import ToolFile, UploadFile

@ -19,6 +23,7 @@ TEST_TENANT_ID = "test_tenant_id"
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
TEST_TOOL_FILE_ID = str(uuid.uuid4())
TEST_REMOTE_URL = "http://example.com/test.jpg"
TEST_ACCESS_CONTROLLER = DatabaseFileAccessController()

# Test Config
TEST_CONFIG = FileUploadConfig(
@ -29,6 +34,16 @@ TEST_CONFIG = FileUploadConfig(
)


def build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
    return _build_from_mapping(
        mapping=mapping,
        tenant_id=tenant_id,
        config=config,
        strict_type_validation=strict_type_validation,
        access_controller=TEST_ACCESS_CONTROLLER,
    )


# Fixtures
@pytest.fixture
def mock_upload_file():
@ -305,6 +320,48 @@ def test_tenant_mismatch():
    build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)


def test_build_from_mapping_scopes_upload_file_to_end_user(mock_upload_file):
    scope = FileAccessScope(
        tenant_id=TEST_TENANT_ID,
        user_id="end-user-id",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )

    with bind_file_access_scope(scope):
        build_from_mapping(mapping=local_file_mapping(), tenant_id=TEST_TENANT_ID)

    stmt = mock_upload_file.call_args.args[0]
    whereclause = str(stmt.whereclause)
    assert "upload_files.created_by_role" in whereclause
    assert "upload_files.created_by" in whereclause


def test_build_from_mapping_scopes_tool_file_to_end_user():
    tool_file = MagicMock(spec=ToolFile)
    tool_file.id = TEST_TOOL_FILE_ID
    tool_file.tenant_id = TEST_TENANT_ID
    tool_file.name = "tool_file.pdf"
    tool_file.file_key = "tool_file.pdf"
    tool_file.mimetype = "application/pdf"
    tool_file.original_url = "http://example.com/tool.pdf"
    tool_file.size = 2048
    scope = FileAccessScope(
        tenant_id=TEST_TENANT_ID,
        user_id="end-user-id",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )

    with patch("factories.file_factory.db.session.scalar", return_value=tool_file, autospec=True) as scalar:
        with bind_file_access_scope(scope):
            build_from_mapping(mapping=tool_file_mapping(), tenant_id=TEST_TENANT_ID)

    stmt = scalar.call_args.args[0]
    whereclause = str(stmt.whereclause)
    assert "tool_files.user_id" in whereclause


def test_disallowed_file_types(mock_upload_file):
    """Test that disallowed file types are rejected."""
    # Config that only allows image and document types
@ -111,8 +111,8 @@ def test_inputs_resolve_owner_tenant_for_single_file_mapping(

monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")

def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
    _ = config, strict_type_validation
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
    _ = config, strict_type_validation, access_controller
    build_calls.append((dict(mapping), tenant_id))
    return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

@ -145,8 +145,8 @@ def test_inputs_resolve_owner_tenant_for_file_list_mapping(

monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")

def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
    _ = config, strict_type_validation
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
    _ = config, strict_type_validation, access_controller
    build_calls.append((dict(mapping), tenant_id))
    return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

@ -196,8 +196,8 @@ def test_inputs_prefer_serialized_tenant_id_when_present(

monkeypatch.setattr(model_module.db.session, "scalar", fail_if_called)

def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False):
    _ = config, strict_type_validation
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
    _ = config, strict_type_validation, access_controller
    return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)