mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
chore: inject account context in file handlers (#36655)
This commit is contained in:
@ -59,12 +59,8 @@ def mock_current_user():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account_context(mock_current_user):
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
yield
|
||||
def mock_current_tenant_id():
|
||||
return "tenant-123"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -95,15 +91,15 @@ class TestFileApiGet:
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
def test_no_file_uploaded(self, app: Flask, mock_account_context):
|
||||
def test_no_file_uploaded(self, app: Flask, mock_current_user):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_too_many_files(self, app: Flask, mock_account_context):
|
||||
def test_too_many_files(self, app: Flask, mock_current_user):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -118,9 +114,9 @@ class TestFileApiPost:
|
||||
mock_request.form.get.return_value = None
|
||||
|
||||
with pytest.raises(TooManyFilesError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_filename_missing(self, app: Flask, mock_account_context):
|
||||
def test_filename_missing(self, app: Flask, mock_current_user):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -130,28 +126,24 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_dataset_upload_without_permission(self, app: Flask, mock_current_user):
|
||||
mock_current_user.is_dataset_editor = False
|
||||
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), "test.txt"),
|
||||
"source": "datasets",
|
||||
}
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), "test.txt"),
|
||||
"source": "datasets",
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api)
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -179,13 +171,13 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
response, status = post_method(api, mock_current_user)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-123"
|
||||
assert response["name"] == "test.txt"
|
||||
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service):
|
||||
"""Test that invalid source parameter gets normalized to None"""
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
@ -216,7 +208,7 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
response, status = post_method(api, mock_current_user)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-456"
|
||||
@ -225,7 +217,7 @@ class TestFileApiPost:
|
||||
call_kwargs = mock_file_service.upload_file.call_args[1]
|
||||
assert call_kwargs["source"] is None
|
||||
|
||||
def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -240,9 +232,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -257,9 +249,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -274,17 +266,17 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
def test_get_preview(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service):
|
||||
api = FilePreviewApi()
|
||||
get_method = unwrap(api.get)
|
||||
mock_file_service.get_file_preview.return_value = "preview text"
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api, "1234")
|
||||
result = get_method(api, mock_current_tenant_id, "1234")
|
||||
|
||||
assert result == {"content": "preview text"}
|
||||
|
||||
|
||||
@ -48,6 +48,7 @@ def _mock_upload_dependencies(
|
||||
*,
|
||||
file_size_within_limit: bool = True,
|
||||
):
|
||||
current_user = SimpleNamespace(id="u1")
|
||||
file_info = SimpleNamespace(
|
||||
filename="report.txt",
|
||||
extension=".txt",
|
||||
@ -63,7 +64,6 @@ def _mock_upload_dependencies(
|
||||
file_service_cls = MagicMock()
|
||||
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
|
||||
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
|
||||
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
|
||||
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.file_helpers,
|
||||
@ -71,7 +71,7 @@ def _mock_upload_dependencies(
|
||||
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
|
||||
)
|
||||
|
||||
return file_service_cls
|
||||
return file_service_cls, current_user
|
||||
|
||||
|
||||
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -147,7 +147,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
|
||||
get_mock = MagicMock(return_value=get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="report.txt",
|
||||
@ -160,7 +160,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
payload, status = handler(api, current_user)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-1"
|
||||
@ -170,7 +170,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
|
||||
filename="report.txt",
|
||||
content=b"fallback-content",
|
||||
mimetype="text/plain",
|
||||
user=SimpleNamespace(id="u1"),
|
||||
user=current_user,
|
||||
source_url=url,
|
||||
)
|
||||
|
||||
@ -191,7 +191,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
get_mock = MagicMock(return_value=extra_get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-2",
|
||||
name="photo.jpg",
|
||||
@ -204,7 +204,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
payload, status = handler(api, current_user)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-2"
|
||||
@ -226,7 +226,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
|
||||
handler(api)
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -243,7 +243,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
|
||||
handler(api)
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -258,11 +258,11 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
|
||||
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
_, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
handler(api)
|
||||
handler(api, current_user)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -276,12 +276,12 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError, match="size exceeded"):
|
||||
handler(api)
|
||||
handler(api, current_user)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -295,9 +295,9 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
handler(api)
|
||||
handler(api, current_user)
|
||||
|
||||
Reference in New Issue
Block a user