feat(backend): file permissions and ownership (#38)
* feat(backend): file permissions and utilities * feat(backend): file upload/list permissions * refactor(backend): class-based middleware to play nicer with tests * test(backend): async fixtures + proper foreign key cascade * fix(backend): temporarily bypass auth
This commit is contained in:
parent
d31f73c66a
commit
23248d0277
14 changed files with 278 additions and 130 deletions
|
@ -433,7 +433,9 @@ disable=raw-checker-failed,
|
|||
missing-function-docstring,
|
||||
missing-module-docstring,
|
||||
too-many-locals,
|
||||
line-too-long
|
||||
line-too-long,
|
||||
too-few-public-methods,
|
||||
fixme
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
-c requirements.txt
|
||||
|
||||
anyio~=3.7.0
|
||||
black~=23.7.0
|
||||
pylint~=2.17.0
|
||||
pytest
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
anyio==3.7.1
|
||||
# via
|
||||
# -c requirements.txt
|
||||
# -r requirements_dev.in
|
||||
# httpcore
|
||||
astroid==2.15.6
|
||||
# via pylint
|
||||
|
|
|
@ -1,23 +1,18 @@
|
|||
"""
|
||||
Authentication & authorization middleware logic.
|
||||
|
||||
This module is imported dynamically from `settings` to set up
|
||||
middlewares with the `FastAPI` singleton.
|
||||
"""
|
||||
import logging
|
||||
|
||||
import jwt.exceptions
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
import auth.use_cases as auth_use_cases
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def authentication_middleware(request: Request, call_next):
|
||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Decodes Authorization headers if present on the request and sets
|
||||
identifying fields in the request state.
|
||||
|
@ -25,22 +20,23 @@ async def authentication_middleware(request: Request, call_next):
|
|||
This information is then leveraged by individual routes to determine
|
||||
authorization.
|
||||
"""
|
||||
auth_header = request.headers.get("authorization")
|
||||
decoded_token = None
|
||||
|
||||
if auth_header is not None:
|
||||
_, token = auth_header.split(" ")
|
||||
try:
|
||||
decoded_token = auth_use_cases.decode_token(token)
|
||||
except jwt.exceptions.ExpiredSignatureError as exc:
|
||||
logger.exception(exc)
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
auth_header = request.headers.get("authorization")
|
||||
decoded_token = None
|
||||
|
||||
if decoded_token is not None:
|
||||
logger.info(decoded_token)
|
||||
request.state.user = {
|
||||
"username": decoded_token["username"],
|
||||
"user_id": decoded_token["user_id"],
|
||||
}
|
||||
if auth_header is not None:
|
||||
_, token = auth_header.split(" ")
|
||||
try:
|
||||
decoded_token = auth_use_cases.decode_token(token)
|
||||
except jwt.exceptions.ExpiredSignatureError as exc:
|
||||
logger.exception(exc)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
if decoded_token is not None:
|
||||
logger.info(decoded_token)
|
||||
request.state.user = {
|
||||
"username": decoded_token["username"],
|
||||
"user_id": decoded_token["user_id"],
|
||||
}
|
||||
|
||||
return await call_next(request)
|
||||
|
|
13
backend/rotini/files/base.py
Normal file
13
backend/rotini/files/base.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
import typing_extensions as typing
|
||||
|
||||
|
||||
class FileRecord(typing.TypedDict):
|
||||
"""
|
||||
Database record associated with a file tracked
|
||||
by the system.
|
||||
"""
|
||||
|
||||
id: str
|
||||
size: int
|
||||
path: str
|
||||
filename: str
|
|
@ -7,23 +7,39 @@ files that live in the system.
|
|||
|
||||
import pathlib
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, Request
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
import files.use_cases as files_use_cases
|
||||
|
||||
from settings import settings
|
||||
|
||||
router = APIRouter(prefix="/files")
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def list_files():
|
||||
return files_use_cases.get_all_file_records()
|
||||
@router.get("/", status_code=200)
|
||||
async def list_files(request: Request):
|
||||
"""
|
||||
Fetches all files owned by the logged-in user.
|
||||
|
||||
200 { [<FileRecord>, ...] }
|
||||
|
||||
If the user is logged in, file records that they
|
||||
own are returned.
|
||||
|
||||
401 {}
|
||||
|
||||
If the request is not authenticated, it fails.
|
||||
"""
|
||||
# FIXME: Temporarily fetching files belonging to the base user.
|
||||
# to be resolved once users can log in.
|
||||
current_user_id = (
|
||||
request.state.user["user_id"] if hasattr(request.state, "user") else 1
|
||||
)
|
||||
return files_use_cases.get_all_files_owned_by_user(current_user_id)
|
||||
|
||||
|
||||
@router.post("/", status_code=201)
|
||||
async def upload_file(file: UploadFile) -> files_use_cases.FileRecord:
|
||||
async def upload_file(request: Request, file: UploadFile) -> files_use_cases.FileRecord:
|
||||
"""
|
||||
Receives files uploaded by the user, saving them to disk and
|
||||
recording their existence in the database.
|
||||
|
@ -39,8 +55,13 @@ async def upload_file(file: UploadFile) -> files_use_cases.FileRecord:
|
|||
|
||||
with open(dest_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
created_record = files_use_cases.create_file_record(str(dest_path), size)
|
||||
# FIXME: Temporarily fetching files belonging to the base user.
|
||||
# to be resolved once users can log in.
|
||||
created_record = files_use_cases.create_file_record(
|
||||
str(dest_path),
|
||||
size,
|
||||
request.state.user["user_id"] if hasattr(request.state, "user") else 1,
|
||||
)
|
||||
|
||||
return created_record
|
||||
|
||||
|
|
|
@ -12,22 +12,15 @@ import typing_extensions as typing
|
|||
from db import get_connection
|
||||
from settings import settings
|
||||
|
||||
from permissions.base import Permissions
|
||||
from permissions.files import set_file_permission
|
||||
|
||||
from exceptions import DoesNotExist
|
||||
|
||||
|
||||
class FileRecord(typing.TypedDict):
|
||||
"""
|
||||
Database record associated with a file tracked
|
||||
by the system.
|
||||
"""
|
||||
|
||||
id: str
|
||||
size: int
|
||||
path: str
|
||||
filename: str
|
||||
from files.base import FileRecord
|
||||
|
||||
|
||||
def create_file_record(path: str, size: int) -> FileRecord:
|
||||
def create_file_record(path: str, size: int, owner_id: int) -> FileRecord:
|
||||
"""
|
||||
Creates a record representing an uploaded file in the database.
|
||||
|
||||
|
@ -43,20 +36,34 @@ def create_file_record(path: str, size: int) -> FileRecord:
|
|||
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
|
||||
set_file_permission(inserted_id, owner_id, list(Permissions))
|
||||
|
||||
filename = pathlib.Path(path).name
|
||||
|
||||
return FileRecord(id=inserted_id, size=size, path=path, filename=filename)
|
||||
|
||||
|
||||
def get_all_file_records() -> typing.Tuple[FileRecord]:
|
||||
"""
|
||||
Fetches all availables files from the database.
|
||||
def get_all_files_owned_by_user(user_id: int) -> typing.Tuple[FileRecord]:
|
||||
"""
|
||||
Gets all the file records owned by the user.
|
||||
|
||||
A file is considered owned if the user has all permissions on a given file. There
|
||||
can be more than one owner to a file, but all files must have an owner.
|
||||
"""
|
||||
rows = None
|
||||
|
||||
with get_connection() as connection, connection.cursor() as cursor:
|
||||
cursor.execute("SELECT * FROM files;")
|
||||
cursor.execute(
|
||||
"""SELECT
|
||||
f.*
|
||||
from files f
|
||||
join permissions_files pf
|
||||
on f.id = pf.file_id
|
||||
where
|
||||
pf.user_id = %s
|
||||
and pf.value = %s;""",
|
||||
(user_id, sum(p.value for p in Permissions)),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if rows is None:
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
"""
|
||||
Rotini: a self-hosted cloud storage & productivity app.
|
||||
"""
|
||||
import importlib
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
import auth.middleware as auth_middleware
|
||||
import auth.routes as auth_routes
|
||||
|
||||
import files.routes as files_routes
|
||||
|
@ -21,16 +21,13 @@ app.add_middleware(
|
|||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.add_middleware(auth_middleware.AuthenticationMiddleware)
|
||||
|
||||
routers = [files_routes.router, auth_routes.router]
|
||||
middlewares = ["auth.middleware"]
|
||||
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
|
||||
for middleware in middlewares:
|
||||
importlib.import_module(middleware)
|
||||
|
||||
|
||||
@app.get("/", status_code=204)
|
||||
def healthcheck():
|
||||
|
|
27
backend/rotini/migrations/migration_2_permissions.py
Normal file
27
backend/rotini/migrations/migration_2_permissions.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
"""
|
||||
Generated: 2023-08-27T11:56:17.800102
|
||||
|
||||
Message: Sets up permission-tracking on files
|
||||
"""
|
||||
UID = "3c755dd8-e02d-4a29-b4ee-2afa4d9b30d6"
|
||||
|
||||
PARENT = "141faa0b-6868-4d07-a24b-b45f98d2809d"
|
||||
|
||||
MESSAGE = "Sets up permission-tracking on files"
|
||||
|
||||
UP_SQL = """CREATE TABLE
|
||||
permissions_files
|
||||
(
|
||||
id bigserial PRIMARY KEY,
|
||||
file_id uuid NOT NULL,
|
||||
user_id bigint NOT NULL,
|
||||
value bigint NOT NULL,
|
||||
created_at timestamp DEFAULT now(),
|
||||
updated_at timestamp DEFAULT now(),
|
||||
CONSTRAINT file_fk FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE,
|
||||
CONSTRAINT user_fk FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
|
||||
CONSTRAINT unique_permission_per_file_per_user UNIQUE(file_id, user_id)
|
||||
);
|
||||
"""
|
||||
|
||||
DOWN_SQL = """DROP TABLE permissions_files;"""
|
23
backend/rotini/permissions/base.py
Normal file
23
backend/rotini/permissions/base.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import enum
|
||||
|
||||
import typing_extensions as typing
|
||||
|
||||
|
||||
class Permissions(enum.Enum):
|
||||
"""
|
||||
Enumeration of individual permission bits.
|
||||
|
||||
Complex permissions are composed by combining these
|
||||
bits.
|
||||
"""
|
||||
|
||||
CAN_VIEW = 1 << 0
|
||||
CAN_DELETE = 1 << 1
|
||||
|
||||
|
||||
class FilePermission(typing.TypedDict):
|
||||
"""Representation of a permission applicable to a file+user pair"""
|
||||
|
||||
file: str
|
||||
user: int
|
||||
value: typing.List[Permissions]
|
26
backend/rotini/permissions/files.py
Normal file
26
backend/rotini/permissions/files.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
import typing_extensions as typing
|
||||
|
||||
from permissions.base import Permissions, FilePermission
|
||||
from db import get_connection
|
||||
|
||||
|
||||
def set_file_permission(
|
||||
file_id: str, user_id: int, permissions: typing.List[Permissions]
|
||||
) -> FilePermission:
|
||||
"""
|
||||
Given a file+user pair, creates a permission record with the
|
||||
provided permission list.
|
||||
"""
|
||||
permission_value = sum(permission.value for permission in permissions)
|
||||
|
||||
with get_connection() as connection, connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"INSERT INTO permissions_files (user_id, file_id, value) VALUES (%s, %s, %s) RETURNING id;",
|
||||
(user_id, file_id, permission_value),
|
||||
)
|
||||
inserted_row = cursor.fetchone()
|
||||
|
||||
if inserted_row is None:
|
||||
raise RuntimeError("uh")
|
||||
|
||||
return FilePermission(file=file_id, user=user_id, value=permissions)
|
|
@ -1,24 +1,26 @@
|
|||
"""
|
||||
Global fixtures
|
||||
|
||||
|
||||
"""
|
||||
from fastapi.testclient import TestClient
|
||||
import httpx
|
||||
import pytest
|
||||
import unittest.mock
|
||||
|
||||
from rotini.main import app
|
||||
from rotini.db import get_connection
|
||||
|
||||
from main import app
|
||||
from db import get_connection
|
||||
from settings import settings
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
def fixture_client():
|
||||
return TestClient(app)
|
||||
@pytest.fixture
|
||||
def anyio_backend():
|
||||
return "asyncio"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@pytest.fixture(autouse=False)
|
||||
def reset_database():
|
||||
"""
|
||||
Empties all user tables between tests.
|
||||
"""
|
||||
tables = ["files", "users"]
|
||||
"""Empties all user tables between tests."""
|
||||
tables = ["files", "users", "permissions_files"]
|
||||
|
||||
with get_connection() as conn, conn.cursor() as cursor:
|
||||
for table in tables:
|
||||
|
@ -26,7 +28,7 @@ def reset_database():
|
|||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_storage_path(tmp_path, monkeypatch):
|
||||
async def set_storage_path(tmp_path, monkeypatch):
|
||||
"""
|
||||
Ensures that files stored by tests are stored
|
||||
in temporary directories.
|
||||
|
@ -38,17 +40,56 @@ def set_storage_path(tmp_path, monkeypatch):
|
|||
monkeypatch.setattr(settings, "STORAGE_ROOT", str(files_dir))
|
||||
|
||||
|
||||
@pytest.fixture(name="test_user_credentials")
|
||||
def fixture_test_user_creds():
|
||||
"""
|
||||
Test user credentials.
|
||||
"""
|
||||
return {"username": "testuser", "password": "testpassword"}
|
||||
|
||||
|
||||
@pytest.fixture(name="test_user", autouse=True)
|
||||
async def fixture_test_user(client_create_user, test_user_credentials):
|
||||
"""
|
||||
Sets up a test user using the `test_user_credentials` data.
|
||||
"""
|
||||
yield await client_create_user(test_user_credentials)
|
||||
|
||||
|
||||
@pytest.fixture(name="no_auth_client")
|
||||
async def fixture_no_auth_client():
|
||||
"""HTTP client without any authentication"""
|
||||
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(name="jwt_client")
|
||||
async def fixture_jwt_client(client_log_in, test_user_credentials):
|
||||
"""HTTP client with test user authentication via JWT"""
|
||||
response = await client_log_in(test_user_credentials)
|
||||
auth_header = response.headers["authorization"]
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
app=app, base_url="http://test", headers={"Authorization": auth_header}
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(name="client_log_in")
|
||||
def fixture_client_log_in(client):
|
||||
def _client_log_in(credentials):
|
||||
return client.post("/auth/sessions/", json=credentials)
|
||||
def fixture_client_log_in(no_auth_client):
|
||||
"""Logs in as the provided user"""
|
||||
|
||||
async def _client_log_in(credentials):
|
||||
return await no_auth_client.post("/auth/sessions/", json=credentials)
|
||||
|
||||
return _client_log_in
|
||||
|
||||
|
||||
@pytest.fixture(name="client_create_user")
|
||||
def fixture_client_create_user(client):
|
||||
def _client_create_user(credentials):
|
||||
return client.post("/auth/users/", json=credentials)
|
||||
def fixture_client_create_user(no_auth_client):
|
||||
"""Creates a new user given credentials"""
|
||||
|
||||
async def _client_create_user(credentials):
|
||||
return await no_auth_client.post("/auth/users/", json=credentials)
|
||||
|
||||
return _client_create_user
|
||||
|
|
|
@ -1,36 +1,21 @@
|
|||
import pytest
|
||||
import jwt
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
|
||||
@pytest.fixture(name="test_user_credentials")
|
||||
def fixture_test_user_creds():
|
||||
"""
|
||||
Test user credentials.
|
||||
"""
|
||||
return {"username": "testuser", "password": "testpassword"}
|
||||
|
||||
|
||||
@pytest.fixture(name="test_user", autouse=True)
|
||||
def fixture_test_user(client_create_user, test_user_credentials):
|
||||
"""
|
||||
Sets up a test user using the `test_user_credentials` data.
|
||||
"""
|
||||
yield client_create_user(test_user_credentials)
|
||||
|
||||
|
||||
def test_create_user_returns_201_on_success(client_create_user):
|
||||
async def test_create_user_returns_201_on_success(client_create_user):
|
||||
credentials = {"username": "newuser", "password": "test"}
|
||||
response = client_create_user(credentials)
|
||||
response = await client_create_user(credentials)
|
||||
|
||||
assert response.status_code == 201
|
||||
|
||||
|
||||
def test_create_user_with_nonunique_username_fails(client_create_user):
|
||||
async def test_create_user_with_nonunique_username_fails(client_create_user):
|
||||
credentials = {"username": "newuser", "password": "test"}
|
||||
client_create_user(credentials)
|
||||
await client_create_user(credentials)
|
||||
|
||||
# Recreate the same user, name collision.
|
||||
response = client_create_user(credentials)
|
||||
response = await client_create_user(credentials)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
@ -43,18 +28,20 @@ def test_create_user_with_nonunique_username_fails(client_create_user):
|
|||
pytest.param({}, id="no_data"),
|
||||
],
|
||||
)
|
||||
def test_create_user_requires_username_and_password_supplied(
|
||||
async def test_create_user_requires_username_and_password_supplied(
|
||||
client_create_user, credentials
|
||||
):
|
||||
response = client_create_user(credentials)
|
||||
response = await client_create_user(credentials)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_log_in_returns_200_and_user_on_success(client_log_in, test_user_credentials):
|
||||
async def test_log_in_returns_200_and_user_on_success(
|
||||
client_log_in, test_user_credentials
|
||||
):
|
||||
# The `test_user` fixture creates a user.
|
||||
|
||||
response = client_log_in(test_user_credentials)
|
||||
response = await client_log_in(test_user_credentials)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
@ -63,7 +50,7 @@ def test_log_in_returns_200_and_user_on_success(client_log_in, test_user_credent
|
|||
assert returned["username"] == test_user_credentials["username"]
|
||||
|
||||
|
||||
def test_log_in_attaches_identity_token_to_response_on_success(
|
||||
async def test_log_in_attaches_identity_token_to_response_on_success(
|
||||
client_log_in, test_user_credentials
|
||||
):
|
||||
# This test specifically needs to inspect the JWT, hence the need to access
|
||||
|
@ -71,7 +58,7 @@ def test_log_in_attaches_identity_token_to_response_on_success(
|
|||
|
||||
import auth.use_cases as auth_use_cases
|
||||
|
||||
response = client_log_in(test_user_credentials)
|
||||
response = await client_log_in(test_user_credentials)
|
||||
|
||||
returned_auth = response.headers.get("authorization")
|
||||
token = returned_auth.split(" ")[1] # Header of the form "Bearer <token>"
|
||||
|
@ -82,16 +69,18 @@ def test_log_in_attaches_identity_token_to_response_on_success(
|
|||
)
|
||||
|
||||
|
||||
def test_log_in_returns_401_on_wrong_password(client_log_in, test_user_credentials):
|
||||
response = client_log_in(
|
||||
async def test_log_in_returns_401_on_wrong_password(
|
||||
client_log_in, test_user_credentials
|
||||
):
|
||||
response = await client_log_in(
|
||||
{"username": test_user_credentials["username"], "password": "sillystring"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_log_in_returns_401_on_nonexistent_user(client_log_in):
|
||||
response = client_log_in({"username": "notauser", "password": "sillystring"})
|
||||
async def test_log_in_returns_401_on_nonexistent_user(client_log_in):
|
||||
response = await client_log_in({"username": "notauser", "password": "sillystring"})
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
@ -104,7 +93,7 @@ def test_log_in_returns_401_on_nonexistent_user(client_log_in):
|
|||
pytest.param({}, id="no_data"),
|
||||
],
|
||||
)
|
||||
def test_log_in_returns_422_on_invalid_input(client_log_in, credentials):
|
||||
response = client_log_in(credentials)
|
||||
async def test_log_in_returns_422_on_invalid_input(client_log_in, credentials):
|
||||
response = await client_log_in(credentials)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
def test_list_files_returns_registered_files_and_200(client, tmp_path):
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
|
||||
async def test_list_files_returns_registered_files_and_200(jwt_client, tmp_path):
|
||||
mock_file_1 = tmp_path / "test1.txt"
|
||||
mock_file_1.write_text("testtest")
|
||||
|
||||
with open(str(mock_file_1), "rb") as mock_file_stream:
|
||||
response = client.post("/files/", files={"file": mock_file_stream})
|
||||
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
|
||||
|
||||
mock_file_1_data = response.json()
|
||||
|
||||
|
@ -14,104 +18,104 @@ def test_list_files_returns_registered_files_and_200(client, tmp_path):
|
|||
mock_file_2.write_text("testtest")
|
||||
|
||||
with open(str(mock_file_2), "rb") as mock_file_stream:
|
||||
response = client.post("/files/", files={"file": mock_file_stream})
|
||||
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
|
||||
|
||||
mock_file_2_data = response.json()
|
||||
|
||||
response = client.get("/files/")
|
||||
response = await jwt_client.get("/files/")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [mock_file_1_data, mock_file_2_data]
|
||||
|
||||
|
||||
def test_file_details_returns_specified_file_and_200(client, tmp_path):
|
||||
async def test_file_details_returns_specified_file_and_200(jwt_client, tmp_path):
|
||||
mock_file = tmp_path / "test.txt"
|
||||
mock_file.write_text("testtest")
|
||||
|
||||
with open(str(mock_file), "rb") as mock_file_stream:
|
||||
response = client.post("/files/", files={"file": mock_file_stream})
|
||||
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
|
||||
|
||||
response_data = response.json()
|
||||
created_file_id = response_data["id"]
|
||||
|
||||
response = client.get(f"/files/{created_file_id}/")
|
||||
response = await jwt_client.get(f"/files/{created_file_id}/")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == response_data
|
||||
|
||||
|
||||
def test_file_details_returns_404_if_does_not_exist(client):
|
||||
async def test_file_details_returns_404_if_does_not_exist(jwt_client):
|
||||
non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a"
|
||||
response = client.get(f"/files/{non_existent_id}/")
|
||||
response = await jwt_client.get(f"/files/{non_existent_id}/")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_file_deletion_returns_404_if_does_not_exist(client):
|
||||
async def test_file_deletion_returns_404_if_does_not_exist(jwt_client):
|
||||
non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a"
|
||||
response = client.delete(f"/files/{non_existent_id}/")
|
||||
response = await jwt_client.delete(f"/files/{non_existent_id}/")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_file_deletion_deletes_record_and_file(client, tmp_path):
|
||||
async def test_file_deletion_deletes_record_and_file(jwt_client, tmp_path):
|
||||
mock_file = tmp_path / "test.txt"
|
||||
mock_file.write_text("testtest")
|
||||
|
||||
with open(str(mock_file), "rb") as mock_file_stream:
|
||||
response = client.post("/files/", files={"file": mock_file_stream})
|
||||
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
|
||||
|
||||
response_data = response.json()
|
||||
file_id = response_data["id"]
|
||||
file_path = response_data["path"]
|
||||
|
||||
assert pathlib.Path(file_path).exists()
|
||||
response = client.get(f"/files/{file_id}/")
|
||||
response = await jwt_client.get(f"/files/{file_id}/")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
client.delete(f"/files/{file_id}/")
|
||||
await jwt_client.delete(f"/files/{file_id}/")
|
||||
assert not pathlib.Path(file_path).exists()
|
||||
|
||||
response = client.get(f"/files/{file_id}/")
|
||||
response = await jwt_client.get(f"/files/{file_id}/")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_file_deletion_200_and_return_deleted_resource(client, tmp_path):
|
||||
async def test_file_deletion_200_and_return_deleted_resource(jwt_client, tmp_path):
|
||||
mock_file = tmp_path / "test.txt"
|
||||
mock_file.write_text("testtest")
|
||||
|
||||
with open(str(mock_file), "rb") as mock_file_stream:
|
||||
response = client.post("/files/", files={"file": mock_file_stream})
|
||||
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
|
||||
|
||||
response_data = response.json()
|
||||
file_id = response_data["id"]
|
||||
|
||||
response = client.delete(f"/files/{file_id}/")
|
||||
response = await jwt_client.delete(f"/files/{file_id}/")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == response_data
|
||||
|
||||
|
||||
def test_file_downloads_200_and_return_file(client, tmp_path):
|
||||
async def test_file_downloads_200_and_return_file(jwt_client, tmp_path):
|
||||
mock_file = tmp_path / "test.txt"
|
||||
mock_file.write_text("testtest")
|
||||
|
||||
with open(str(mock_file), "rb") as mock_file_stream:
|
||||
response = client.post("/files/", files={"file": mock_file_stream})
|
||||
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
|
||||
|
||||
response_data = response.json()
|
||||
file_id = response_data["id"]
|
||||
|
||||
response = client.get(f"/files/{file_id}/content/")
|
||||
response = await jwt_client.get(f"/files/{file_id}/content/")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == mock_file.read_text()
|
||||
|
||||
|
||||
def test_file_downloads_404_if_does_not_exist(client):
|
||||
async def test_file_downloads_404_if_does_not_exist(jwt_client):
|
||||
non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a"
|
||||
response = client.get(f"/files/{non_existent_id}/content/")
|
||||
response = await jwt_client.get(f"/files/{non_existent_id}/content/")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
|
Reference in a new issue