feat(backend): file permissions and ownership (#38)

* feat(backend): file permissions and utilities

* feat(backend): file upload/list permissions

* refactor(backend): class-based middleware to play nicer with tests

* test(backend): async fixtures + proper foreign key cascade

* fix(backend): temporarily bypass auth
This commit is contained in:
Marc 2023-08-28 01:21:49 -04:00 committed by GitHub
parent d31f73c66a
commit 23248d0277
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 278 additions and 130 deletions

View file

@ -433,7 +433,9 @@ disable=raw-checker-failed,
missing-function-docstring,
missing-module-docstring,
too-many-locals,
line-too-long
line-too-long,
too-few-public-methods,
fixme
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option

View file

@ -1,5 +1,6 @@
-c requirements.txt
anyio~=3.7.0
black~=23.7.0
pylint~=2.17.0
pytest

View file

@ -1,6 +1,7 @@
anyio==3.7.1
# via
# -c requirements.txt
# -r requirements_dev.in
# httpcore
astroid==2.15.6
# via pylint

View file

@ -1,23 +1,18 @@
"""
Authentication & authorization middleware logic.
This module is imported dynamically from `settings` to set up
middlewares with the `FastAPI` singleton.
"""
import logging
import jwt.exceptions
from fastapi import Request
from main import app
from starlette.middleware.base import BaseHTTPMiddleware
import auth.use_cases as auth_use_cases
logger = logging.getLogger(__name__)
@app.middleware("http")
async def authentication_middleware(request: Request, call_next):
class AuthenticationMiddleware(BaseHTTPMiddleware):
"""
Decodes Authorization headers if present on the request and sets
identifying fields in the request state.
@ -25,22 +20,23 @@ async def authentication_middleware(request: Request, call_next):
This information is then leveraged by individual routes to determine
authorization.
"""
auth_header = request.headers.get("authorization")
decoded_token = None
if auth_header is not None:
_, token = auth_header.split(" ")
try:
decoded_token = auth_use_cases.decode_token(token)
except jwt.exceptions.ExpiredSignatureError as exc:
logger.exception(exc)
async def dispatch(self, request: Request, call_next):
auth_header = request.headers.get("authorization")
decoded_token = None
if decoded_token is not None:
logger.info(decoded_token)
request.state.user = {
"username": decoded_token["username"],
"user_id": decoded_token["user_id"],
}
if auth_header is not None:
_, token = auth_header.split(" ")
try:
decoded_token = auth_use_cases.decode_token(token)
except jwt.exceptions.ExpiredSignatureError as exc:
logger.exception(exc)
response = await call_next(request)
return response
if decoded_token is not None:
logger.info(decoded_token)
request.state.user = {
"username": decoded_token["username"],
"user_id": decoded_token["user_id"],
}
return await call_next(request)

View file

@ -0,0 +1,13 @@
import typing_extensions as typing
class FileRecord(typing.TypedDict):
"""
Database record associated with a file tracked
by the system.
"""
id: str
size: int
path: str
filename: str

View file

@ -7,23 +7,39 @@ files that live in the system.
import pathlib
from fastapi import APIRouter, HTTPException, UploadFile
from fastapi import APIRouter, HTTPException, UploadFile, Request
from fastapi.responses import FileResponse
import files.use_cases as files_use_cases
from settings import settings
router = APIRouter(prefix="/files")
@router.get("/")
def list_files():
return files_use_cases.get_all_file_records()
@router.get("/", status_code=200)
async def list_files(request: Request):
"""
Fetches all files owned by the logged-in user.
200 { [<FileRecord>, ...] }
If the user is logged in, file records that they
own are returned.
401 {}
If the request is not authenticated, it fails.
"""
# FIXME: Temporarily fetching files belonging to the base user.
# to be resolved once users can log in.
current_user_id = (
request.state.user["user_id"] if hasattr(request.state, "user") else 1
)
return files_use_cases.get_all_files_owned_by_user(current_user_id)
@router.post("/", status_code=201)
async def upload_file(file: UploadFile) -> files_use_cases.FileRecord:
async def upload_file(request: Request, file: UploadFile) -> files_use_cases.FileRecord:
"""
Receives files uploaded by the user, saving them to disk and
recording their existence in the database.
@ -39,8 +55,13 @@ async def upload_file(file: UploadFile) -> files_use_cases.FileRecord:
with open(dest_path, "wb") as f:
f.write(content)
created_record = files_use_cases.create_file_record(str(dest_path), size)
# FIXME: Temporarily fetching files belonging to the base user.
# to be resolved once users can log in.
created_record = files_use_cases.create_file_record(
str(dest_path),
size,
request.state.user["user_id"] if hasattr(request.state, "user") else 1,
)
return created_record

View file

@ -12,22 +12,15 @@ import typing_extensions as typing
from db import get_connection
from settings import settings
from permissions.base import Permissions
from permissions.files import set_file_permission
from exceptions import DoesNotExist
class FileRecord(typing.TypedDict):
"""
Database record associated with a file tracked
by the system.
"""
id: str
size: int
path: str
filename: str
from files.base import FileRecord
def create_file_record(path: str, size: int) -> FileRecord:
def create_file_record(path: str, size: int, owner_id: int) -> FileRecord:
"""
Creates a record representing an uploaded file in the database.
@ -43,20 +36,34 @@ def create_file_record(path: str, size: int) -> FileRecord:
inserted_id = cursor.fetchone()[0]
set_file_permission(inserted_id, owner_id, list(Permissions))
filename = pathlib.Path(path).name
return FileRecord(id=inserted_id, size=size, path=path, filename=filename)
def get_all_file_records() -> typing.Tuple[FileRecord]:
"""
Fetches all availables files from the database.
def get_all_files_owned_by_user(user_id: int) -> typing.Tuple[FileRecord]:
"""
Gets all the file records owned by the user.
A file is considered owned if the user has all permissions on a given file. There
can be more than one owner to a file, but all files must have an owner.
"""
rows = None
with get_connection() as connection, connection.cursor() as cursor:
cursor.execute("SELECT * FROM files;")
cursor.execute(
"""SELECT
f.*
from files f
join permissions_files pf
on f.id = pf.file_id
where
pf.user_id = %s
and pf.value = %s;""",
(user_id, sum(p.value for p in Permissions)),
)
rows = cursor.fetchall()
if rows is None:

View file

@ -1,11 +1,11 @@
"""
Rotini: a self-hosted cloud storage & productivity app.
"""
import importlib
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import auth.middleware as auth_middleware
import auth.routes as auth_routes
import files.routes as files_routes
@ -21,16 +21,13 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(auth_middleware.AuthenticationMiddleware)
routers = [files_routes.router, auth_routes.router]
middlewares = ["auth.middleware"]
for router in routers:
app.include_router(router)
for middleware in middlewares:
importlib.import_module(middleware)
@app.get("/", status_code=204)
def healthcheck():

View file

@ -0,0 +1,27 @@
"""
Generated: 2023-08-27T11:56:17.800102
Message: Sets up permission-tracking on files
"""
UID = "3c755dd8-e02d-4a29-b4ee-2afa4d9b30d6"
PARENT = "141faa0b-6868-4d07-a24b-b45f98d2809d"
MESSAGE = "Sets up permission-tracking on files"
UP_SQL = """CREATE TABLE
permissions_files
(
id bigserial PRIMARY KEY,
file_id uuid NOT NULL,
user_id bigint NOT NULL,
value bigint NOT NULL,
created_at timestamp DEFAULT now(),
updated_at timestamp DEFAULT now(),
CONSTRAINT file_fk FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE,
CONSTRAINT user_fk FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
CONSTRAINT unique_permission_per_file_per_user UNIQUE(file_id, user_id)
);
"""
DOWN_SQL = """DROP TABLE permissions_files;"""

View file

@ -0,0 +1,23 @@
import enum
import typing_extensions as typing
class Permissions(enum.Enum):
"""
Enumeration of individual permission bits.
Complex permissions are composed by combining these
bits.
"""
CAN_VIEW = 1 << 0
CAN_DELETE = 1 << 1
class FilePermission(typing.TypedDict):
"""Representation of a permission applicable to a file+user pair"""
file: str
user: int
value: typing.List[Permissions]

View file

@ -0,0 +1,26 @@
import typing_extensions as typing
from permissions.base import Permissions, FilePermission
from db import get_connection
def set_file_permission(
file_id: str, user_id: int, permissions: typing.List[Permissions]
) -> FilePermission:
"""
Given a file+user pair, creates a permission record with the
provided permission list.
"""
permission_value = sum(permission.value for permission in permissions)
with get_connection() as connection, connection.cursor() as cursor:
cursor.execute(
"INSERT INTO permissions_files (user_id, file_id, value) VALUES (%s, %s, %s) RETURNING id;",
(user_id, file_id, permission_value),
)
inserted_row = cursor.fetchone()
if inserted_row is None:
raise RuntimeError("uh")
return FilePermission(file=file_id, user=user_id, value=permissions)

View file

@ -1,24 +1,26 @@
"""
Global fixtures
"""
from fastapi.testclient import TestClient
import httpx
import pytest
import unittest.mock
from rotini.main import app
from rotini.db import get_connection
from main import app
from db import get_connection
from settings import settings
@pytest.fixture(name="client")
def fixture_client():
return TestClient(app)
@pytest.fixture
def anyio_backend():
return "asyncio"
@pytest.fixture(autouse=True)
@pytest.fixture(autouse=False)
def reset_database():
"""
Empties all user tables between tests.
"""
tables = ["files", "users"]
"""Empties all user tables between tests."""
tables = ["files", "users", "permissions_files"]
with get_connection() as conn, conn.cursor() as cursor:
for table in tables:
@ -26,7 +28,7 @@ def reset_database():
@pytest.fixture(autouse=True)
def set_storage_path(tmp_path, monkeypatch):
async def set_storage_path(tmp_path, monkeypatch):
"""
Ensures that files stored by tests are stored
in temporary directories.
@ -38,17 +40,56 @@ def set_storage_path(tmp_path, monkeypatch):
monkeypatch.setattr(settings, "STORAGE_ROOT", str(files_dir))
@pytest.fixture(name="test_user_credentials")
def fixture_test_user_creds():
"""
Test user credentials.
"""
return {"username": "testuser", "password": "testpassword"}
@pytest.fixture(name="test_user", autouse=True)
async def fixture_test_user(client_create_user, test_user_credentials):
"""
Sets up a test user using the `test_user_credentials` data.
"""
yield await client_create_user(test_user_credentials)
@pytest.fixture(name="no_auth_client")
async def fixture_no_auth_client():
"""HTTP client without any authentication"""
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
yield client
@pytest.fixture(name="jwt_client")
async def fixture_jwt_client(client_log_in, test_user_credentials):
"""HTTP client with test user authentication via JWT"""
response = await client_log_in(test_user_credentials)
auth_header = response.headers["authorization"]
async with httpx.AsyncClient(
app=app, base_url="http://test", headers={"Authorization": auth_header}
) as client:
yield client
@pytest.fixture(name="client_log_in")
def fixture_client_log_in(client):
def _client_log_in(credentials):
return client.post("/auth/sessions/", json=credentials)
def fixture_client_log_in(no_auth_client):
"""Logs in as the provided user"""
async def _client_log_in(credentials):
return await no_auth_client.post("/auth/sessions/", json=credentials)
return _client_log_in
@pytest.fixture(name="client_create_user")
def fixture_client_create_user(client):
def _client_create_user(credentials):
return client.post("/auth/users/", json=credentials)
def fixture_client_create_user(no_auth_client):
"""Creates a new user given credentials"""
async def _client_create_user(credentials):
return await no_auth_client.post("/auth/users/", json=credentials)
return _client_create_user

View file

@ -1,36 +1,21 @@
import pytest
import jwt
pytestmark = pytest.mark.anyio
@pytest.fixture(name="test_user_credentials")
def fixture_test_user_creds():
"""
Test user credentials.
"""
return {"username": "testuser", "password": "testpassword"}
@pytest.fixture(name="test_user", autouse=True)
def fixture_test_user(client_create_user, test_user_credentials):
"""
Sets up a test user using the `test_user_credentials` data.
"""
yield client_create_user(test_user_credentials)
def test_create_user_returns_201_on_success(client_create_user):
async def test_create_user_returns_201_on_success(client_create_user):
credentials = {"username": "newuser", "password": "test"}
response = client_create_user(credentials)
response = await client_create_user(credentials)
assert response.status_code == 201
def test_create_user_with_nonunique_username_fails(client_create_user):
async def test_create_user_with_nonunique_username_fails(client_create_user):
credentials = {"username": "newuser", "password": "test"}
client_create_user(credentials)
await client_create_user(credentials)
# Recreate the same user, name collision.
response = client_create_user(credentials)
response = await client_create_user(credentials)
assert response.status_code == 400
@ -43,18 +28,20 @@ def test_create_user_with_nonunique_username_fails(client_create_user):
pytest.param({}, id="no_data"),
],
)
def test_create_user_requires_username_and_password_supplied(
async def test_create_user_requires_username_and_password_supplied(
client_create_user, credentials
):
response = client_create_user(credentials)
response = await client_create_user(credentials)
assert response.status_code == 422
def test_log_in_returns_200_and_user_on_success(client_log_in, test_user_credentials):
async def test_log_in_returns_200_and_user_on_success(
client_log_in, test_user_credentials
):
# The `test_user` fixture creates a user.
response = client_log_in(test_user_credentials)
response = await client_log_in(test_user_credentials)
assert response.status_code == 200
@ -63,7 +50,7 @@ def test_log_in_returns_200_and_user_on_success(client_log_in, test_user_credent
assert returned["username"] == test_user_credentials["username"]
def test_log_in_attaches_identity_token_to_response_on_success(
async def test_log_in_attaches_identity_token_to_response_on_success(
client_log_in, test_user_credentials
):
# This test specifically needs to inspect the JWT, hence the need to access
@ -71,7 +58,7 @@ def test_log_in_attaches_identity_token_to_response_on_success(
import auth.use_cases as auth_use_cases
response = client_log_in(test_user_credentials)
response = await client_log_in(test_user_credentials)
returned_auth = response.headers.get("authorization")
token = returned_auth.split(" ")[1] # Header of the form "Bearer <token>"
@ -82,16 +69,18 @@ def test_log_in_attaches_identity_token_to_response_on_success(
)
def test_log_in_returns_401_on_wrong_password(client_log_in, test_user_credentials):
response = client_log_in(
async def test_log_in_returns_401_on_wrong_password(
client_log_in, test_user_credentials
):
response = await client_log_in(
{"username": test_user_credentials["username"], "password": "sillystring"}
)
assert response.status_code == 401
def test_log_in_returns_401_on_nonexistent_user(client_log_in):
response = client_log_in({"username": "notauser", "password": "sillystring"})
async def test_log_in_returns_401_on_nonexistent_user(client_log_in):
response = await client_log_in({"username": "notauser", "password": "sillystring"})
assert response.status_code == 401
@ -104,7 +93,7 @@ def test_log_in_returns_401_on_nonexistent_user(client_log_in):
pytest.param({}, id="no_data"),
],
)
def test_log_in_returns_422_on_invalid_input(client_log_in, credentials):
response = client_log_in(credentials)
async def test_log_in_returns_422_on_invalid_input(client_log_in, credentials):
response = await client_log_in(credentials)
assert response.status_code == 422

View file

@ -1,12 +1,16 @@
import pathlib
import pytest
def test_list_files_returns_registered_files_and_200(client, tmp_path):
pytestmark = pytest.mark.anyio
async def test_list_files_returns_registered_files_and_200(jwt_client, tmp_path):
mock_file_1 = tmp_path / "test1.txt"
mock_file_1.write_text("testtest")
with open(str(mock_file_1), "rb") as mock_file_stream:
response = client.post("/files/", files={"file": mock_file_stream})
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
mock_file_1_data = response.json()
@ -14,104 +18,104 @@ def test_list_files_returns_registered_files_and_200(client, tmp_path):
mock_file_2.write_text("testtest")
with open(str(mock_file_2), "rb") as mock_file_stream:
response = client.post("/files/", files={"file": mock_file_stream})
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
mock_file_2_data = response.json()
response = client.get("/files/")
response = await jwt_client.get("/files/")
assert response.status_code == 200
assert response.json() == [mock_file_1_data, mock_file_2_data]
def test_file_details_returns_specified_file_and_200(client, tmp_path):
async def test_file_details_returns_specified_file_and_200(jwt_client, tmp_path):
mock_file = tmp_path / "test.txt"
mock_file.write_text("testtest")
with open(str(mock_file), "rb") as mock_file_stream:
response = client.post("/files/", files={"file": mock_file_stream})
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
response_data = response.json()
created_file_id = response_data["id"]
response = client.get(f"/files/{created_file_id}/")
response = await jwt_client.get(f"/files/{created_file_id}/")
assert response.status_code == 200
assert response.json() == response_data
def test_file_details_returns_404_if_does_not_exist(client):
async def test_file_details_returns_404_if_does_not_exist(jwt_client):
non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a"
response = client.get(f"/files/{non_existent_id}/")
response = await jwt_client.get(f"/files/{non_existent_id}/")
assert response.status_code == 404
def test_file_deletion_returns_404_if_does_not_exist(client):
async def test_file_deletion_returns_404_if_does_not_exist(jwt_client):
non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a"
response = client.delete(f"/files/{non_existent_id}/")
response = await jwt_client.delete(f"/files/{non_existent_id}/")
assert response.status_code == 404
def test_file_deletion_deletes_record_and_file(client, tmp_path):
async def test_file_deletion_deletes_record_and_file(jwt_client, tmp_path):
mock_file = tmp_path / "test.txt"
mock_file.write_text("testtest")
with open(str(mock_file), "rb") as mock_file_stream:
response = client.post("/files/", files={"file": mock_file_stream})
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
response_data = response.json()
file_id = response_data["id"]
file_path = response_data["path"]
assert pathlib.Path(file_path).exists()
response = client.get(f"/files/{file_id}/")
response = await jwt_client.get(f"/files/{file_id}/")
assert response.status_code == 200
client.delete(f"/files/{file_id}/")
await jwt_client.delete(f"/files/{file_id}/")
assert not pathlib.Path(file_path).exists()
response = client.get(f"/files/{file_id}/")
response = await jwt_client.get(f"/files/{file_id}/")
assert response.status_code == 404
def test_file_deletion_200_and_return_deleted_resource(client, tmp_path):
async def test_file_deletion_200_and_return_deleted_resource(jwt_client, tmp_path):
mock_file = tmp_path / "test.txt"
mock_file.write_text("testtest")
with open(str(mock_file), "rb") as mock_file_stream:
response = client.post("/files/", files={"file": mock_file_stream})
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
response_data = response.json()
file_id = response_data["id"]
response = client.delete(f"/files/{file_id}/")
response = await jwt_client.delete(f"/files/{file_id}/")
assert response.status_code == 200
assert response.json() == response_data
def test_file_downloads_200_and_return_file(client, tmp_path):
async def test_file_downloads_200_and_return_file(jwt_client, tmp_path):
mock_file = tmp_path / "test.txt"
mock_file.write_text("testtest")
with open(str(mock_file), "rb") as mock_file_stream:
response = client.post("/files/", files={"file": mock_file_stream})
response = await jwt_client.post("/files/", files={"file": mock_file_stream})
response_data = response.json()
file_id = response_data["id"]
response = client.get(f"/files/{file_id}/content/")
response = await jwt_client.get(f"/files/{file_id}/content/")
assert response.status_code == 200
assert response.text == mock_file.read_text()
def test_file_downloads_404_if_does_not_exist(client):
async def test_file_downloads_404_if_does_not_exist(jwt_client):
non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a"
response = client.get(f"/files/{non_existent_id}/content/")
response = await jwt_client.get(f"/files/{non_existent_id}/content/")
assert response.status_code == 404