From 75d6511284c76e3a9ce2d4d5660ff7ce20d021cd Mon Sep 17 00:00:00 2001 From: Tianle <40735546+Tianlel@users.noreply.github.com> Date: Tue, 26 May 2026 00:43:57 -0500 Subject: [PATCH] chore: inject account context in file handlers (#36655) --- api/controllers/console/files.py | 15 +++-- api/controllers/console/remote_files.py | 10 +-- .../controllers/console/test_files.py | 66 ++++++++----------- .../controllers/console/test_remote_files.py | 30 ++++----- 4 files changed, 59 insertions(+), 62 deletions(-) diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 499a623872..3ef006c051 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -22,10 +22,13 @@ from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, setup_required, + with_current_tenant_id, + with_current_user, ) from extensions.ext_database import db from fields.file_fields import FileResponse, UploadConfig -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models.account import Account from services.file_service import FileService from . import console_ns @@ -62,8 +65,8 @@ class FileApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("documents") @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): source_str = request.form.get("source") source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None @@ -107,10 +110,10 @@ class FilePreviewApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__]) - def get(self, file_id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, file_id: UUID): file_id_str = str(file_id) - _, tenant_id = current_account_with_tenant() - text = FileService(db.engine).get_file_preview(file_id_str, tenant_id) + text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id) return {"content": text} diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 19f1fd8aab..93435d1151 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -12,11 +12,13 @@ from controllers.common.errors import ( ) from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns +from controllers.console.wraps import with_current_user from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from graphon.file import helpers as file_helpers -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models.account import Account from services.file_service import FileService @@ -49,7 +51,8 @@ class RemoteFileUpload(Resource): @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) @console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__]) @login_required - def post(self): + @with_current_user + def post(self, current_user: Account): payload = RemoteFileUploadPayload.model_validate(console_ns.payload) url = payload.url @@ -74,12 +77,11 @@ class RemoteFileUpload(Resource): content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: - user, _ = current_account_with_tenant() upload_file = FileService(db.engine).upload_file( filename=file_info.filename, content=content, mimetype=file_info.mimetype, - user=user, + user=current_user, source_url=url, ) except services.errors.file.FileTooLargeError as file_too_large_error: diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py index 9274f6cf61..d566486664 100644 --- a/api/tests/unit_tests/controllers/console/test_files.py +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -59,12 +59,8 @@ def mock_current_user(): @pytest.fixture -def mock_account_context(mock_current_user): - with patch( - "controllers.console.files.current_account_with_tenant", - return_value=(mock_current_user, None), - ): - yield +def mock_current_tenant_id(): + return "tenant-123" @pytest.fixture @@ -95,15 +91,15 @@ class TestFileApiGet: class TestFileApiPost: - def test_no_file_uploaded(self, app: Flask, mock_account_context): + def test_no_file_uploaded(self, app: Flask, mock_current_user): api = FileApi() post_method = unwrap(api.post) with app.test_request_context(method="POST", data={}): with pytest.raises(NoFileUploadedError): - post_method(api) + post_method(api, mock_current_user) - def test_too_many_files(self, app: Flask, mock_account_context): + def test_too_many_files(self, app: Flask, mock_current_user): api = FileApi() post_method = unwrap(api.post) @@ -118,9 +114,9 @@ class TestFileApiPost: mock_request.form.get.return_value = None with pytest.raises(TooManyFilesError): - post_method(api) + post_method(api, mock_current_user) - def test_filename_missing(self, app: Flask, mock_account_context): + def test_filename_missing(self, app: Flask, mock_current_user): api = FileApi() post_method = unwrap(api.post) @@ -130,28 +126,24 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(FilenameNotExistsError): - post_method(api) + post_method(api, mock_current_user) def test_dataset_upload_without_permission(self, app: Flask, mock_current_user): mock_current_user.is_dataset_editor = False - with patch( - "controllers.console.files.current_account_with_tenant", - return_value=(mock_current_user, None), - ): - api = FileApi() - post_method = unwrap(api.post) + api = FileApi() + post_method = unwrap(api.post) - data = { - "file": (io.BytesIO(b"abc"), "test.txt"), - "source": "datasets", - } + data = { + "file": (io.BytesIO(b"abc"), "test.txt"), + "source": "datasets", + } - with app.test_request_context(method="POST", data=data): - with pytest.raises(Forbidden): - post_method(api) + with app.test_request_context(method="POST", data=data): + with pytest.raises(Forbidden): + post_method(api, mock_current_user) - def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service): + def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -179,13 +171,13 @@ class TestFileApiPost: } with app.test_request_context(method="POST", data=data): - response, status = post_method(api) + response, status = post_method(api, mock_current_user) assert status == 201 assert response["id"] == "file-id-123" assert response["name"] == "test.txt" - def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service): + def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service): """Test that invalid source parameter gets normalized to None""" api = FileApi() post_method = unwrap(api.post) @@ -216,7 +208,7 @@ class TestFileApiPost: } with app.test_request_context(method="POST", data=data): - response, status = post_method(api) + response, status = post_method(api, mock_current_user) assert status == 201 assert response["id"] == "file-id-456" @@ -225,7 +217,7 @@ class TestFileApiPost: call_kwargs = mock_file_service.upload_file.call_args[1] assert call_kwargs["source"] is None - def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service): + def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -240,9 +232,9 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(FileTooLargeError): - post_method(api) + post_method(api, mock_current_user) - def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service): + def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -257,9 +249,9 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(UnsupportedFileTypeError): - post_method(api) + post_method(api, mock_current_user) - def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service): + def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -274,17 +266,17 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(BlockedFileExtensionError): - post_method(api) + post_method(api, mock_current_user) class TestFilePreviewApi: - def test_get_preview(self, app: Flask, mock_account_context, mock_file_service): + def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service): api = FilePreviewApi() get_method = unwrap(api.get) mock_file_service.get_file_preview.return_value = "preview text" with app.test_request_context(): - result = get_method(api, "1234") + result = get_method(api, mock_current_tenant_id, "1234") assert result == {"content": "preview text"} diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py index 8e86709b66..ae620b1e52 100644 --- a/api/tests/unit_tests/controllers/console/test_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -48,6 +48,7 @@ def _mock_upload_dependencies( *, file_size_within_limit: bool = True, ): + current_user = SimpleNamespace(id="u1") file_info = SimpleNamespace( filename="report.txt", extension=".txt", @@ -63,7 +64,6 @@ def _mock_upload_dependencies( file_service_cls = MagicMock() file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit monkeypatch.setattr(remote_files_module, "FileService", file_service_cls) - monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None)) monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr( remote_files_module.file_helpers, @@ -71,7 +71,7 @@ def _mock_upload_dependencies( lambda upload_file_id: f"https://signed.example/{upload_file_id}", ) - return file_service_cls + return file_service_cls, current_user def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -147,7 +147,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc get_mock = MagicMock(return_value=get_resp) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) - file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) upload_file = SimpleNamespace( id="file-1", name="report.txt", @@ -160,7 +160,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc file_service_cls.return_value.upload_file.return_value = upload_file with app.test_request_context(method="POST", json={"url": url}): - payload, status = handler(api) + payload, status = handler(api, current_user) assert status == 201 assert payload["id"] == "file-1" @@ -170,7 +170,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc filename="report.txt", content=b"fallback-content", mimetype="text/plain", - user=SimpleNamespace(id="u1"), + user=current_user, source_url=url, ) @@ -191,7 +191,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( get_mock = MagicMock(return_value=extra_get_resp) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) - file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) upload_file = SimpleNamespace( id="file-2", name="photo.jpg", @@ -204,7 +204,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( file_service_cls.return_value.upload_file.return_value = upload_file with app.test_request_context(method="POST", json={"url": url}): - payload, status = handler(api) + payload, status = handler(api, current_user) assert status == 201 assert payload["id"] == "file-2" @@ -226,7 +226,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"): - handler(api) + handler(api, SimpleNamespace(id="u1")) def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -243,7 +243,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"): - handler(api) + handler(api, SimpleNamespace(id="u1")) def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -258,11 +258,11 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk ) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) - _mock_upload_dependencies(monkeypatch, file_size_within_limit=False) + _, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False) with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(FileTooLargeError): - handler(api) + handler(api, current_user) def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -276,12 +276,12 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), ) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) - file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded") with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(FileTooLargeError, match="size exceeded"): - handler(api) + handler(api, current_user) def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -295,9 +295,9 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), ) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) - file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError() with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(UnsupportedFileTypeError): - handler(api) + handler(api, current_user)