mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 04:43:33 +08:00
chore: inject account context in file handlers (#36655)
This commit is contained in:
@ -22,10 +22,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, UploadConfig
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
@ -62,8 +65,8 @@ class FileApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
source_str = request.form.get("source")
|
||||
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
||||
|
||||
@ -107,10 +110,10 @@ class FilePreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, file_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -12,11 +12,13 @@ from controllers.common.errors import (
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -49,7 +51,8 @@ class RemoteFileUpload(Resource):
|
||||
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
|
||||
@login_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
url = payload.url
|
||||
|
||||
@ -74,12 +77,11 @@ class RemoteFileUpload(Resource):
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=user,
|
||||
user=current_user,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
|
||||
@ -59,12 +59,8 @@ def mock_current_user():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account_context(mock_current_user):
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
yield
|
||||
def mock_current_tenant_id():
|
||||
return "tenant-123"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -95,15 +91,15 @@ class TestFileApiGet:
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
def test_no_file_uploaded(self, app: Flask, mock_account_context):
|
||||
def test_no_file_uploaded(self, app: Flask, mock_current_user):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_too_many_files(self, app: Flask, mock_account_context):
|
||||
def test_too_many_files(self, app: Flask, mock_current_user):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -118,9 +114,9 @@ class TestFileApiPost:
|
||||
mock_request.form.get.return_value = None
|
||||
|
||||
with pytest.raises(TooManyFilesError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_filename_missing(self, app: Flask, mock_account_context):
|
||||
def test_filename_missing(self, app: Flask, mock_current_user):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -130,28 +126,24 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_dataset_upload_without_permission(self, app: Flask, mock_current_user):
|
||||
mock_current_user.is_dataset_editor = False
|
||||
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), "test.txt"),
|
||||
"source": "datasets",
|
||||
}
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), "test.txt"),
|
||||
"source": "datasets",
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api)
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -179,13 +171,13 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
response, status = post_method(api, mock_current_user)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-123"
|
||||
assert response["name"] == "test.txt"
|
||||
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service):
|
||||
"""Test that invalid source parameter gets normalized to None"""
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
@ -216,7 +208,7 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
response, status = post_method(api, mock_current_user)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-456"
|
||||
@ -225,7 +217,7 @@ class TestFileApiPost:
|
||||
call_kwargs = mock_file_service.upload_file.call_args[1]
|
||||
assert call_kwargs["source"] is None
|
||||
|
||||
def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -240,9 +232,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -257,9 +249,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -274,17 +266,17 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
post_method(api)
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
def test_get_preview(self, app: Flask, mock_account_context, mock_file_service):
|
||||
def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service):
|
||||
api = FilePreviewApi()
|
||||
get_method = unwrap(api.get)
|
||||
mock_file_service.get_file_preview.return_value = "preview text"
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api, "1234")
|
||||
result = get_method(api, mock_current_tenant_id, "1234")
|
||||
|
||||
assert result == {"content": "preview text"}
|
||||
|
||||
|
||||
@ -48,6 +48,7 @@ def _mock_upload_dependencies(
|
||||
*,
|
||||
file_size_within_limit: bool = True,
|
||||
):
|
||||
current_user = SimpleNamespace(id="u1")
|
||||
file_info = SimpleNamespace(
|
||||
filename="report.txt",
|
||||
extension=".txt",
|
||||
@ -63,7 +64,6 @@ def _mock_upload_dependencies(
|
||||
file_service_cls = MagicMock()
|
||||
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
|
||||
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
|
||||
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
|
||||
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.file_helpers,
|
||||
@ -71,7 +71,7 @@ def _mock_upload_dependencies(
|
||||
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
|
||||
)
|
||||
|
||||
return file_service_cls
|
||||
return file_service_cls, current_user
|
||||
|
||||
|
||||
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -147,7 +147,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
|
||||
get_mock = MagicMock(return_value=get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="report.txt",
|
||||
@ -160,7 +160,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
payload, status = handler(api, current_user)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-1"
|
||||
@ -170,7 +170,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
|
||||
filename="report.txt",
|
||||
content=b"fallback-content",
|
||||
mimetype="text/plain",
|
||||
user=SimpleNamespace(id="u1"),
|
||||
user=current_user,
|
||||
source_url=url,
|
||||
)
|
||||
|
||||
@ -191,7 +191,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
get_mock = MagicMock(return_value=extra_get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-2",
|
||||
name="photo.jpg",
|
||||
@ -204,7 +204,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
payload, status = handler(api, current_user)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-2"
|
||||
@ -226,7 +226,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
|
||||
handler(api)
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -243,7 +243,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
|
||||
handler(api)
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -258,11 +258,11 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
|
||||
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
_, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
handler(api)
|
||||
handler(api, current_user)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -276,12 +276,12 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError, match="size exceeded"):
|
||||
handler(api)
|
||||
handler(api, current_user)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -295,9 +295,9 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
handler(api)
|
||||
handler(api, current_user)
|
||||
|
||||
Reference in New Issue
Block a user