diff --git a/backend/.pylintrc b/backend/.pylintrc index d4d2b07..dd2f095 100644 --- a/backend/.pylintrc +++ b/backend/.pylintrc @@ -433,7 +433,9 @@ disable=raw-checker-failed, missing-function-docstring, missing-module-docstring, too-many-locals, - line-too-long + line-too-long, + too-few-public-methods, + fixme # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/backend/requirements_dev.in b/backend/requirements_dev.in index b06cca9..13b7301 100644 --- a/backend/requirements_dev.in +++ b/backend/requirements_dev.in @@ -1,5 +1,6 @@ -c requirements.txt +anyio~=3.7.0 black~=23.7.0 pylint~=2.17.0 pytest diff --git a/backend/requirements_dev.txt b/backend/requirements_dev.txt index 30d2963..a2b0bf2 100644 --- a/backend/requirements_dev.txt +++ b/backend/requirements_dev.txt @@ -1,6 +1,7 @@ anyio==3.7.1 # via # -c requirements.txt + # -r requirements_dev.in # httpcore astroid==2.15.6 # via pylint diff --git a/backend/rotini/auth/middleware.py b/backend/rotini/auth/middleware.py index 0b80869..f1f8eb7 100644 --- a/backend/rotini/auth/middleware.py +++ b/backend/rotini/auth/middleware.py @@ -1,23 +1,18 @@ """ Authentication & authorization middleware logic. - -This module is imported dynamically from `settings` to set up -middlewares with the `FastAPI` singleton. """ import logging import jwt.exceptions from fastapi import Request - -from main import app +from starlette.middleware.base import BaseHTTPMiddleware import auth.use_cases as auth_use_cases logger = logging.getLogger(__name__) -@app.middleware("http") -async def authentication_middleware(request: Request, call_next): +class AuthenticationMiddleware(BaseHTTPMiddleware): """ Decodes Authorization headers if present on the request and sets identifying fields in the request state. @@ -25,22 +20,23 @@ async def authentication_middleware(request: Request, call_next): This information is then leveraged by individual routes to determine authorization. """ - auth_header = request.headers.get("authorization") - decoded_token = None - if auth_header is not None: - _, token = auth_header.split(" ") - try: - decoded_token = auth_use_cases.decode_token(token) - except jwt.exceptions.ExpiredSignatureError as exc: - logger.exception(exc) + async def dispatch(self, request: Request, call_next): + auth_header = request.headers.get("authorization") + decoded_token = None - if decoded_token is not None: - logger.info(decoded_token) - request.state.user = { - "username": decoded_token["username"], - "user_id": decoded_token["user_id"], - } + if auth_header is not None: + _, token = auth_header.split(" ") + try: + decoded_token = auth_use_cases.decode_token(token) + except jwt.exceptions.ExpiredSignatureError as exc: + logger.exception(exc) - response = await call_next(request) - return response + if decoded_token is not None: + logger.info(decoded_token) + request.state.user = { + "username": decoded_token["username"], + "user_id": decoded_token["user_id"], + } + + return await call_next(request) diff --git a/backend/rotini/files/base.py b/backend/rotini/files/base.py new file mode 100644 index 0000000..ae8eed3 --- /dev/null +++ b/backend/rotini/files/base.py @@ -0,0 +1,13 @@ +import typing_extensions as typing + + +class FileRecord(typing.TypedDict): + """ + Database record associated with a file tracked + by the system. + """ + + id: str + size: int + path: str + filename: str diff --git a/backend/rotini/files/routes.py b/backend/rotini/files/routes.py index 420ca7a..8c59bc3 100644 --- a/backend/rotini/files/routes.py +++ b/backend/rotini/files/routes.py @@ -7,23 +7,39 @@ files that live in the system. import pathlib -from fastapi import APIRouter, HTTPException, UploadFile +from fastapi import APIRouter, HTTPException, UploadFile, Request from fastapi.responses import FileResponse import files.use_cases as files_use_cases - from settings import settings router = APIRouter(prefix="/files") -@router.get("/") -def list_files(): - return files_use_cases.get_all_file_records() +@router.get("/", status_code=200) +async def list_files(request: Request): + """ + Fetches all files owned by the logged-in user. + + 200 { [, ...] } + + If the user is logged in, file records that they + own are returned. + + 401 {} + + If the request is not authenticated, it fails. + """ + # FIXME: Temporarily fetching files belonging to the base user. + # to be resolved once users can log in. + current_user_id = ( + request.state.user["user_id"] if hasattr(request.state, "user") else 1 + ) + return files_use_cases.get_all_files_owned_by_user(current_user_id) @router.post("/", status_code=201) -async def upload_file(file: UploadFile) -> files_use_cases.FileRecord: +async def upload_file(request: Request, file: UploadFile) -> files_use_cases.FileRecord: """ Receives files uploaded by the user, saving them to disk and recording their existence in the database. @@ -39,8 +55,13 @@ async def upload_file(file: UploadFile) -> files_use_cases.FileRecord: with open(dest_path, "wb") as f: f.write(content) - - created_record = files_use_cases.create_file_record(str(dest_path), size) + # FIXME: Temporarily fetching files belonging to the base user. + # to be resolved once users can log in. + created_record = files_use_cases.create_file_record( + str(dest_path), + size, + request.state.user["user_id"] if hasattr(request.state, "user") else 1, + ) return created_record diff --git a/backend/rotini/files/use_cases.py b/backend/rotini/files/use_cases.py index 011a204..15903a6 100644 --- a/backend/rotini/files/use_cases.py +++ b/backend/rotini/files/use_cases.py @@ -12,22 +12,15 @@ import typing_extensions as typing from db import get_connection from settings import settings +from permissions.base import Permissions +from permissions.files import set_file_permission + from exceptions import DoesNotExist - -class FileRecord(typing.TypedDict): - """ - Database record associated with a file tracked - by the system. - """ - - id: str - size: int - path: str - filename: str +from files.base import FileRecord -def create_file_record(path: str, size: int) -> FileRecord: +def create_file_record(path: str, size: int, owner_id: int) -> FileRecord: """ Creates a record representing an uploaded file in the database. @@ -43,20 +36,34 @@ def create_file_record(path: str, size: int) -> FileRecord: inserted_id = cursor.fetchone()[0] + set_file_permission(inserted_id, owner_id, list(Permissions)) + filename = pathlib.Path(path).name return FileRecord(id=inserted_id, size=size, path=path, filename=filename) -def get_all_file_records() -> typing.Tuple[FileRecord]: - """ - Fetches all availables files from the database. +def get_all_files_owned_by_user(user_id: int) -> typing.Tuple[FileRecord]: """ + Gets all the file records owned by the user. + A file is considered owned if the user has all permissions on a given file. There + can be more than one owner to a file, but all files must have an owner. + """ rows = None with get_connection() as connection, connection.cursor() as cursor: - cursor.execute("SELECT * FROM files;") + cursor.execute( + """SELECT + f.* + from files f + join permissions_files pf + on f.id = pf.file_id + where + pf.user_id = %s + and pf.value = %s;""", + (user_id, sum(p.value for p in Permissions)), + ) rows = cursor.fetchall() if rows is None: diff --git a/backend/rotini/main.py b/backend/rotini/main.py index bb4fd17..a3d362f 100644 --- a/backend/rotini/main.py +++ b/backend/rotini/main.py @@ -1,11 +1,11 @@ """ Rotini: a self-hosted cloud storage & productivity app. """ -import importlib from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +import auth.middleware as auth_middleware import auth.routes as auth_routes import files.routes as files_routes @@ -21,16 +21,13 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"], ) +app.add_middleware(auth_middleware.AuthenticationMiddleware) routers = [files_routes.router, auth_routes.router] -middlewares = ["auth.middleware"] for router in routers: app.include_router(router) -for middleware in middlewares: - importlib.import_module(middleware) - @app.get("/", status_code=204) def healthcheck(): diff --git a/backend/rotini/migrations/migration_2_permissions.py b/backend/rotini/migrations/migration_2_permissions.py new file mode 100644 index 0000000..7ed470c --- /dev/null +++ b/backend/rotini/migrations/migration_2_permissions.py @@ -0,0 +1,27 @@ +""" +Generated: 2023-08-27T11:56:17.800102 + +Message: Sets up permission-tracking on files +""" +UID = "3c755dd8-e02d-4a29-b4ee-2afa4d9b30d6" + +PARENT = "141faa0b-6868-4d07-a24b-b45f98d2809d" + +MESSAGE = "Sets up permission-tracking on files" + +UP_SQL = """CREATE TABLE + permissions_files +( + id bigserial PRIMARY KEY, + file_id uuid NOT NULL, + user_id bigint NOT NULL, + value bigint NOT NULL, + created_at timestamp DEFAULT now(), + updated_at timestamp DEFAULT now(), + CONSTRAINT file_fk FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE, + CONSTRAINT user_fk FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT unique_permission_per_file_per_user UNIQUE(file_id, user_id) +); +""" + +DOWN_SQL = """DROP TABLE permissions_files;""" diff --git a/backend/rotini/permissions/base.py b/backend/rotini/permissions/base.py new file mode 100644 index 0000000..8375221 --- /dev/null +++ b/backend/rotini/permissions/base.py @@ -0,0 +1,23 @@ +import enum + +import typing_extensions as typing + + +class Permissions(enum.Enum): + """ + Enumeration of individual permission bits. + + Complex permissions are composed by combining these + bits. + """ + + CAN_VIEW = 1 << 0 + CAN_DELETE = 1 << 1 + + +class FilePermission(typing.TypedDict): + """Representation of a permission applicable to a file+user pair""" + + file: str + user: int + value: typing.List[Permissions] diff --git a/backend/rotini/permissions/files.py b/backend/rotini/permissions/files.py new file mode 100644 index 0000000..6639a6a --- /dev/null +++ b/backend/rotini/permissions/files.py @@ -0,0 +1,26 @@ +import typing_extensions as typing + +from permissions.base import Permissions, FilePermission +from db import get_connection + + +def set_file_permission( + file_id: str, user_id: int, permissions: typing.List[Permissions] +) -> FilePermission: + """ + Given a file+user pair, creates a permission record with the + provided permission list. + """ + permission_value = sum(permission.value for permission in permissions) + + with get_connection() as connection, connection.cursor() as cursor: + cursor.execute( + "INSERT INTO permissions_files (user_id, file_id, value) VALUES (%s, %s, %s) RETURNING id;", + (user_id, file_id, permission_value), + ) + inserted_row = cursor.fetchone() + + if inserted_row is None: + raise RuntimeError("uh") + + return FilePermission(file=file_id, user=user_id, value=permissions) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 252f160..0038eae 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,24 +1,26 @@ +""" +Global fixtures + + +""" from fastapi.testclient import TestClient +import httpx import pytest -import unittest.mock - -from rotini.main import app -from rotini.db import get_connection +from main import app +from db import get_connection from settings import settings -@pytest.fixture(name="client") -def fixture_client(): - return TestClient(app) +@pytest.fixture +def anyio_backend(): + return "asyncio" -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=False) def reset_database(): - """ - Empties all user tables between tests. - """ - tables = ["files", "users"] + """Empties all user tables between tests.""" + tables = ["files", "users", "permissions_files"] with get_connection() as conn, conn.cursor() as cursor: for table in tables: @@ -26,7 +28,7 @@ def reset_database(): @pytest.fixture(autouse=True) -def set_storage_path(tmp_path, monkeypatch): +async def set_storage_path(tmp_path, monkeypatch): """ Ensures that files stored by tests are stored in temporary directories. @@ -38,17 +40,56 @@ def set_storage_path(tmp_path, monkeypatch): monkeypatch.setattr(settings, "STORAGE_ROOT", str(files_dir)) +@pytest.fixture(name="test_user_credentials") +def fixture_test_user_creds(): + """ + Test user credentials. + """ + return {"username": "testuser", "password": "testpassword"} + + +@pytest.fixture(name="test_user", autouse=True) +async def fixture_test_user(client_create_user, test_user_credentials): + """ + Sets up a test user using the `test_user_credentials` data. + """ + yield await client_create_user(test_user_credentials) + + +@pytest.fixture(name="no_auth_client") +async def fixture_no_auth_client(): + """HTTP client without any authentication""" + async with httpx.AsyncClient(app=app, base_url="http://test") as client: + yield client + + +@pytest.fixture(name="jwt_client") +async def fixture_jwt_client(client_log_in, test_user_credentials): + """HTTP client with test user authentication via JWT""" + response = await client_log_in(test_user_credentials) + auth_header = response.headers["authorization"] + + async with httpx.AsyncClient( + app=app, base_url="http://test", headers={"Authorization": auth_header} + ) as client: + yield client + + @pytest.fixture(name="client_log_in") -def fixture_client_log_in(client): - def _client_log_in(credentials): - return client.post("/auth/sessions/", json=credentials) +def fixture_client_log_in(no_auth_client): + """Logs in as the provided user""" + + async def _client_log_in(credentials): + return await no_auth_client.post("/auth/sessions/", json=credentials) return _client_log_in @pytest.fixture(name="client_create_user") -def fixture_client_create_user(client): - def _client_create_user(credentials): - return client.post("/auth/users/", json=credentials) +def fixture_client_create_user(no_auth_client): + """Creates a new user given credentials""" + + async def _client_create_user(credentials): + return await no_auth_client.post("/auth/users/", json=credentials) return _client_create_user diff --git a/backend/tests/test_auth_routes.py b/backend/tests/test_auth_routes.py index 5a752f1..683529f 100644 --- a/backend/tests/test_auth_routes.py +++ b/backend/tests/test_auth_routes.py @@ -1,36 +1,21 @@ import pytest -import jwt + +pytestmark = pytest.mark.anyio -@pytest.fixture(name="test_user_credentials") -def fixture_test_user_creds(): - """ - Test user credentials. - """ - return {"username": "testuser", "password": "testpassword"} - - -@pytest.fixture(name="test_user", autouse=True) -def fixture_test_user(client_create_user, test_user_credentials): - """ - Sets up a test user using the `test_user_credentials` data. - """ - yield client_create_user(test_user_credentials) - - -def test_create_user_returns_201_on_success(client_create_user): +async def test_create_user_returns_201_on_success(client_create_user): credentials = {"username": "newuser", "password": "test"} - response = client_create_user(credentials) + response = await client_create_user(credentials) assert response.status_code == 201 -def test_create_user_with_nonunique_username_fails(client_create_user): +async def test_create_user_with_nonunique_username_fails(client_create_user): credentials = {"username": "newuser", "password": "test"} - client_create_user(credentials) + await client_create_user(credentials) # Recreate the same user, name collision. - response = client_create_user(credentials) + response = await client_create_user(credentials) assert response.status_code == 400 @@ -43,18 +28,20 @@ def test_create_user_with_nonunique_username_fails(client_create_user): pytest.param({}, id="no_data"), ], ) -def test_create_user_requires_username_and_password_supplied( +async def test_create_user_requires_username_and_password_supplied( client_create_user, credentials ): - response = client_create_user(credentials) + response = await client_create_user(credentials) assert response.status_code == 422 -def test_log_in_returns_200_and_user_on_success(client_log_in, test_user_credentials): +async def test_log_in_returns_200_and_user_on_success( + client_log_in, test_user_credentials +): # The `test_user` fixture creates a user. - response = client_log_in(test_user_credentials) + response = await client_log_in(test_user_credentials) assert response.status_code == 200 @@ -63,7 +50,7 @@ def test_log_in_returns_200_and_user_on_success(client_log_in, test_user_credent assert returned["username"] == test_user_credentials["username"] -def test_log_in_attaches_identity_token_to_response_on_success( +async def test_log_in_attaches_identity_token_to_response_on_success( client_log_in, test_user_credentials ): # This test specifically needs to inspect the JWT, hence the need to access @@ -71,7 +58,7 @@ def test_log_in_attaches_identity_token_to_response_on_success( import auth.use_cases as auth_use_cases - response = client_log_in(test_user_credentials) + response = await client_log_in(test_user_credentials) returned_auth = response.headers.get("authorization") token = returned_auth.split(" ")[1] # Header of the form "Bearer " @@ -82,16 +69,18 @@ def test_log_in_attaches_identity_token_to_response_on_success( ) -def test_log_in_returns_401_on_wrong_password(client_log_in, test_user_credentials): - response = client_log_in( +async def test_log_in_returns_401_on_wrong_password( + client_log_in, test_user_credentials +): + response = await client_log_in( {"username": test_user_credentials["username"], "password": "sillystring"} ) assert response.status_code == 401 -def test_log_in_returns_401_on_nonexistent_user(client_log_in): - response = client_log_in({"username": "notauser", "password": "sillystring"}) +async def test_log_in_returns_401_on_nonexistent_user(client_log_in): + response = await client_log_in({"username": "notauser", "password": "sillystring"}) assert response.status_code == 401 @@ -104,7 +93,7 @@ def test_log_in_returns_401_on_nonexistent_user(client_log_in): pytest.param({}, id="no_data"), ], ) -def test_log_in_returns_422_on_invalid_input(client_log_in, credentials): - response = client_log_in(credentials) +async def test_log_in_returns_422_on_invalid_input(client_log_in, credentials): + response = await client_log_in(credentials) assert response.status_code == 422 diff --git a/backend/tests/test_files_routes.py b/backend/tests/test_files_routes.py index 6f76b26..5f26ad0 100644 --- a/backend/tests/test_files_routes.py +++ b/backend/tests/test_files_routes.py @@ -1,12 +1,16 @@ import pathlib +import pytest -def test_list_files_returns_registered_files_and_200(client, tmp_path): +pytestmark = pytest.mark.anyio + + +async def test_list_files_returns_registered_files_and_200(jwt_client, tmp_path): mock_file_1 = tmp_path / "test1.txt" mock_file_1.write_text("testtest") with open(str(mock_file_1), "rb") as mock_file_stream: - response = client.post("/files/", files={"file": mock_file_stream}) + response = await jwt_client.post("/files/", files={"file": mock_file_stream}) mock_file_1_data = response.json() @@ -14,104 +18,104 @@ def test_list_files_returns_registered_files_and_200(client, tmp_path): mock_file_2.write_text("testtest") with open(str(mock_file_2), "rb") as mock_file_stream: - response = client.post("/files/", files={"file": mock_file_stream}) + response = await jwt_client.post("/files/", files={"file": mock_file_stream}) mock_file_2_data = response.json() - response = client.get("/files/") + response = await jwt_client.get("/files/") assert response.status_code == 200 assert response.json() == [mock_file_1_data, mock_file_2_data] -def test_file_details_returns_specified_file_and_200(client, tmp_path): +async def test_file_details_returns_specified_file_and_200(jwt_client, tmp_path): mock_file = tmp_path / "test.txt" mock_file.write_text("testtest") with open(str(mock_file), "rb") as mock_file_stream: - response = client.post("/files/", files={"file": mock_file_stream}) + response = await jwt_client.post("/files/", files={"file": mock_file_stream}) response_data = response.json() created_file_id = response_data["id"] - response = client.get(f"/files/{created_file_id}/") + response = await jwt_client.get(f"/files/{created_file_id}/") assert response.status_code == 200 assert response.json() == response_data -def test_file_details_returns_404_if_does_not_exist(client): +async def test_file_details_returns_404_if_does_not_exist(jwt_client): non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a" - response = client.get(f"/files/{non_existent_id}/") + response = await jwt_client.get(f"/files/{non_existent_id}/") assert response.status_code == 404 -def test_file_deletion_returns_404_if_does_not_exist(client): +async def test_file_deletion_returns_404_if_does_not_exist(jwt_client): non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a" - response = client.delete(f"/files/{non_existent_id}/") + response = await jwt_client.delete(f"/files/{non_existent_id}/") assert response.status_code == 404 -def test_file_deletion_deletes_record_and_file(client, tmp_path): +async def test_file_deletion_deletes_record_and_file(jwt_client, tmp_path): mock_file = tmp_path / "test.txt" mock_file.write_text("testtest") with open(str(mock_file), "rb") as mock_file_stream: - response = client.post("/files/", files={"file": mock_file_stream}) + response = await jwt_client.post("/files/", files={"file": mock_file_stream}) response_data = response.json() file_id = response_data["id"] file_path = response_data["path"] assert pathlib.Path(file_path).exists() - response = client.get(f"/files/{file_id}/") + response = await jwt_client.get(f"/files/{file_id}/") assert response.status_code == 200 - client.delete(f"/files/{file_id}/") + await jwt_client.delete(f"/files/{file_id}/") assert not pathlib.Path(file_path).exists() - response = client.get(f"/files/{file_id}/") + response = await jwt_client.get(f"/files/{file_id}/") assert response.status_code == 404 -def test_file_deletion_200_and_return_deleted_resource(client, tmp_path): +async def test_file_deletion_200_and_return_deleted_resource(jwt_client, tmp_path): mock_file = tmp_path / "test.txt" mock_file.write_text("testtest") with open(str(mock_file), "rb") as mock_file_stream: - response = client.post("/files/", files={"file": mock_file_stream}) + response = await jwt_client.post("/files/", files={"file": mock_file_stream}) response_data = response.json() file_id = response_data["id"] - response = client.delete(f"/files/{file_id}/") + response = await jwt_client.delete(f"/files/{file_id}/") assert response.status_code == 200 assert response.json() == response_data -def test_file_downloads_200_and_return_file(client, tmp_path): +async def test_file_downloads_200_and_return_file(jwt_client, tmp_path): mock_file = tmp_path / "test.txt" mock_file.write_text("testtest") with open(str(mock_file), "rb") as mock_file_stream: - response = client.post("/files/", files={"file": mock_file_stream}) + response = await jwt_client.post("/files/", files={"file": mock_file_stream}) response_data = response.json() file_id = response_data["id"] - response = client.get(f"/files/{file_id}/content/") + response = await jwt_client.get(f"/files/{file_id}/content/") assert response.status_code == 200 assert response.text == mock_file.read_text() -def test_file_downloads_404_if_does_not_exist(client): +async def test_file_downloads_404_if_does_not_exist(jwt_client): non_existent_id = "06f02980-864d-4832-a894-2e9d2543a79a" - response = client.get(f"/files/{non_existent_id}/content/") + response = await jwt_client.get(f"/files/{non_existent_id}/content/") assert response.status_code == 404