chore: inject account context in file handlers (#36655)

This commit is contained in:
Tianle
2026-05-26 00:43:57 -05:00
committed by GitHub
parent fd059720e5
commit 75d6511284
4 changed files with 59 additions and 62 deletions

View File

@ -22,10 +22,13 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.file_fields import FileResponse, UploadConfig
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from services.file_service import FileService
from . import console_ns
@ -62,8 +65,8 @@ class FileApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("documents")
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -107,10 +110,10 @@ class FilePreviewApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, file_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, file_id: UUID):
file_id_str = str(file_id)
_, tenant_id = current_account_with_tenant()
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
return {"content": text}

View File

@ -12,11 +12,13 @@ from controllers.common.errors import (
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import with_current_user
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from services.file_service import FileService
@ -49,7 +51,8 @@ class RemoteFileUpload(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
@login_required
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = payload.url
@ -74,12 +77,11 @@ class RemoteFileUpload(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
user=current_user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:

View File

@ -59,12 +59,8 @@ def mock_current_user():
@pytest.fixture
def mock_account_context(mock_current_user):
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
yield
def mock_current_tenant_id():
return "tenant-123"
@pytest.fixture
@ -95,15 +91,15 @@ class TestFileApiGet:
class TestFileApiPost:
def test_no_file_uploaded(self, app: Flask, mock_account_context):
def test_no_file_uploaded(self, app: Flask, mock_current_user):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST", data={}):
with pytest.raises(NoFileUploadedError):
post_method(api)
post_method(api, mock_current_user)
def test_too_many_files(self, app: Flask, mock_account_context):
def test_too_many_files(self, app: Flask, mock_current_user):
api = FileApi()
post_method = unwrap(api.post)
@ -118,9 +114,9 @@ class TestFileApiPost:
mock_request.form.get.return_value = None
with pytest.raises(TooManyFilesError):
post_method(api)
post_method(api, mock_current_user)
def test_filename_missing(self, app: Flask, mock_account_context):
def test_filename_missing(self, app: Flask, mock_current_user):
api = FileApi()
post_method = unwrap(api.post)
@ -130,28 +126,24 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(FilenameNotExistsError):
post_method(api)
post_method(api, mock_current_user)
def test_dataset_upload_without_permission(self, app: Flask, mock_current_user):
mock_current_user.is_dataset_editor = False
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
api = FileApi()
post_method = unwrap(api.post)
api = FileApi()
post_method = unwrap(api.post)
data = {
"file": (io.BytesIO(b"abc"), "test.txt"),
"source": "datasets",
}
data = {
"file": (io.BytesIO(b"abc"), "test.txt"),
"source": "datasets",
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(Forbidden):
post_method(api)
with app.test_request_context(method="POST", data=data):
with pytest.raises(Forbidden):
post_method(api, mock_current_user)
def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service):
def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -179,13 +171,13 @@ class TestFileApiPost:
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
response, status = post_method(api, mock_current_user)
assert status == 201
assert response["id"] == "file-id-123"
assert response["name"] == "test.txt"
def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service):
def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service):
"""Test that invalid source parameter gets normalized to None"""
api = FileApi()
post_method = unwrap(api.post)
@ -216,7 +208,7 @@ class TestFileApiPost:
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
response, status = post_method(api, mock_current_user)
assert status == 201
assert response["id"] == "file-id-456"
@ -225,7 +217,7 @@ class TestFileApiPost:
call_kwargs = mock_file_service.upload_file.call_args[1]
assert call_kwargs["source"] is None
def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service):
def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -240,9 +232,9 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(FileTooLargeError):
post_method(api)
post_method(api, mock_current_user)
def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service):
def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -257,9 +249,9 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(UnsupportedFileTypeError):
post_method(api)
post_method(api, mock_current_user)
def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service):
def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -274,17 +266,17 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(BlockedFileExtensionError):
post_method(api)
post_method(api, mock_current_user)
class TestFilePreviewApi:
def test_get_preview(self, app: Flask, mock_account_context, mock_file_service):
def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service):
api = FilePreviewApi()
get_method = unwrap(api.get)
mock_file_service.get_file_preview.return_value = "preview text"
with app.test_request_context():
result = get_method(api, "1234")
result = get_method(api, mock_current_tenant_id, "1234")
assert result == {"content": "preview text"}

View File

@ -48,6 +48,7 @@ def _mock_upload_dependencies(
*,
file_size_within_limit: bool = True,
):
current_user = SimpleNamespace(id="u1")
file_info = SimpleNamespace(
filename="report.txt",
extension=".txt",
@ -63,7 +64,6 @@ def _mock_upload_dependencies(
file_service_cls = MagicMock()
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
remote_files_module.file_helpers,
@ -71,7 +71,7 @@ def _mock_upload_dependencies(
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
)
return file_service_cls
return file_service_cls, current_user
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -147,7 +147,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
get_mock = MagicMock(return_value=get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-1",
name="report.txt",
@ -160,7 +160,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
payload, status = handler(api, current_user)
assert status == 201
assert payload["id"] == "file-1"
@ -170,7 +170,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
filename="report.txt",
content=b"fallback-content",
mimetype="text/plain",
user=SimpleNamespace(id="u1"),
user=current_user,
source_url=url,
)
@ -191,7 +191,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
get_mock = MagicMock(return_value=extra_get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-2",
name="photo.jpg",
@ -204,7 +204,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
payload, status = handler(api, current_user)
assert status == 201
assert payload["id"] == "file-2"
@ -226,7 +226,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
handler(api)
handler(api, SimpleNamespace(id="u1"))
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -243,7 +243,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
handler(api)
handler(api, SimpleNamespace(id="u1"))
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -258,11 +258,11 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
_, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError):
handler(api)
handler(api, current_user)
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -276,12 +276,12 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError, match="size exceeded"):
handler(api)
handler(api, current_user)
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -295,9 +295,9 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(UnsupportedFileTypeError):
handler(api)
handler(api, current_user)