feat(backend): create user + log in implementation (#31)
* feat(backend): create user table * build(backend): add argon2-cffi dependency * feat(backend): basic create user / login implementation * chore(backend): ignore needless lintrules * refactor(backend): user api+use cases clean up and docs * refactor(backend): reorganize into module * test(backend): login route coverage * refactor(backend): add request data schemas * test(backend): refactor client call fixtures * feat(backend): set up username uniqueness constraint * test(backend): update coverage for username uniqueness * chore(backend): missing dunderinit * chore(backend): linting
This commit is contained in:
parent
16bb6d3afe
commit
acdf1ca145
14 changed files with 365 additions and 10 deletions
|
@ -432,7 +432,8 @@ disable=raw-checker-failed,
|
|||
invalid-name,
|
||||
missing-function-docstring,
|
||||
missing-module-docstring,
|
||||
too-many-locals
|
||||
too-many-locals,
|
||||
line-too-long
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
|
|
|
@ -3,3 +3,6 @@ uvicorn[standard]
|
|||
python-multipart
|
||||
psycopg2
|
||||
typing_extensions
|
||||
pydantic ~= 2.0
|
||||
|
||||
argon2-cffi~=23.1
|
||||
|
|
|
@ -4,6 +4,12 @@ anyio==3.7.1
|
|||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
argon2-cffi==23.1.0
|
||||
# via -r requirements.in
|
||||
argon2-cffi-bindings==21.2.0
|
||||
# via argon2-cffi
|
||||
cffi==1.15.1
|
||||
# via argon2-cffi-bindings
|
||||
click==8.1.6
|
||||
# via uvicorn
|
||||
exceptiongroup==1.1.2
|
||||
|
@ -18,6 +24,8 @@ idna==3.4
|
|||
# via anyio
|
||||
psycopg2==2.9.7
|
||||
# via -r requirements.in
|
||||
pycparser==2.21
|
||||
# via cffi
|
||||
pydantic==2.1.1
|
||||
# via fastapi
|
||||
pydantic-core==2.4.0
|
||||
|
|
0
backend/rotini/auth/__init__.py
Normal file
0
backend/rotini/auth/__init__.py
Normal file
22
backend/rotini/auth/base.py
Normal file
22
backend/rotini/auth/base.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
"""
|
||||
Class declarations and constants for the auth module.
|
||||
"""
|
||||
import pydantic
|
||||
|
||||
|
||||
class LoginRequestData(pydantic.BaseModel):
|
||||
"""Payload for login requests"""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class CreateUserRequestData(pydantic.BaseModel):
|
||||
"""Payload for user creation"""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class UsernameAlreadyExists(Exception):
|
||||
"""Signals a unique constraint violation on username values"""
|
62
backend/rotini/auth/routes.py
Normal file
62
backend/rotini/auth/routes.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from use_cases.exceptions import DoesNotExist
|
||||
|
||||
import auth.use_cases as auth_use_cases
|
||||
import auth.base as auth_base
|
||||
|
||||
router = APIRouter(prefix="/auth")
|
||||
|
||||
|
||||
@router.post("/users/", status_code=201)
|
||||
async def create_user(payload: auth_base.CreateUserRequestData):
|
||||
"""
|
||||
POST /auth/users/
|
||||
|
||||
{
|
||||
username: string
|
||||
password: string
|
||||
}
|
||||
|
||||
201 { <UserData> }
|
||||
|
||||
If the user is created successfully, the user object is returned.
|
||||
|
||||
400 {}
|
||||
|
||||
If the username already exists, or the password is not adequate,
|
||||
400 is returned.
|
||||
"""
|
||||
try:
|
||||
user = auth_use_cases.create_new_user(
|
||||
username=payload.username, raw_password=payload.password
|
||||
)
|
||||
except auth_base.UsernameAlreadyExists as exc:
|
||||
raise HTTPException(status_code=400) from exc
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/sessions/")
|
||||
async def log_in(payload: auth_base.LoginRequestData):
|
||||
"""
|
||||
Attempts to log a user in.
|
||||
|
||||
200 { <User> }
|
||||
|
||||
If the supplied credentials are correct, the user is returned.
|
||||
|
||||
401 {}
|
||||
|
||||
If the credentials are incorrect, immediate failure.
|
||||
"""
|
||||
|
||||
try:
|
||||
user = auth_use_cases.get_user(username=payload.username)
|
||||
except DoesNotExist as exc:
|
||||
raise HTTPException(status_code=401) from exc
|
||||
|
||||
if not auth_use_cases.validate_password_for_user(user["id"], payload.password):
|
||||
raise HTTPException(status_code=401)
|
||||
|
||||
return user
|
122
backend/rotini/auth/use_cases.py
Normal file
122
backend/rotini/auth/use_cases.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
"""
|
||||
User-related use cases.
|
||||
|
||||
Functions in this file are focused on users and passwords.
|
||||
"""
|
||||
import datetime
|
||||
import typing_extensions as typing
|
||||
|
||||
import argon2
|
||||
|
||||
from db import get_connection
|
||||
from use_cases.exceptions import DoesNotExist
|
||||
|
||||
import auth.base as auth_base
|
||||
|
||||
password_hasher = argon2.PasswordHasher()
|
||||
|
||||
|
||||
class User(typing.TypedDict):
|
||||
"""
|
||||
User representation.
|
||||
|
||||
The password hash is never included in these records and should
|
||||
not leave the database.
|
||||
"""
|
||||
|
||||
id: int
|
||||
username: str
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
password_updated_at: datetime.datetime
|
||||
|
||||
|
||||
def create_new_user(*, username: str, raw_password: str) -> User:
|
||||
"""
|
||||
Creates a new user record given a username and password.
|
||||
|
||||
The password is hashed (see `_hash_secret`) and the hash is stored.
|
||||
|
||||
If successful, returns a dictionary representing the user.
|
||||
"""
|
||||
password_hash = _hash_secret(raw_password)
|
||||
|
||||
with get_connection() as connection, connection.cursor() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO users (username, password_hash) VALUES (%s, %s) RETURNING id, username",
|
||||
(username, password_hash),
|
||||
)
|
||||
returned = cursor.fetchone()
|
||||
except Exception as exc:
|
||||
raise auth_base.UsernameAlreadyExists() from exc
|
||||
|
||||
inserted_id = returned[0]
|
||||
created_username = returned[1]
|
||||
|
||||
return User(
|
||||
id=inserted_id,
|
||||
username=created_username,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
password_updated_at=datetime.datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
def _hash_secret(secret: str) -> str:
|
||||
"""
|
||||
Produces a hash of the given secret.
|
||||
"""
|
||||
return password_hasher.hash(secret)
|
||||
|
||||
|
||||
def get_user(
|
||||
*, username: str = None, user_id: int = None
|
||||
) -> typing.Union[typing.NoReturn, User]:
|
||||
"""
|
||||
Retrieves a user record, if one exists, for the given user.
|
||||
|
||||
Querying can be done via username or user ID. The first one supplied, in this
|
||||
order, is used and any other values are ignored.
|
||||
"""
|
||||
with get_connection() as connection, connection.cursor() as cursor:
|
||||
if username is not None:
|
||||
cursor.execute(
|
||||
"SELECT id, username, created_at, updated_at, password_updated_at FROM users WHERE username = %s;",
|
||||
(username,),
|
||||
)
|
||||
elif user_id is not None:
|
||||
cursor.execute(
|
||||
"SELECT id, username, created_at, updated_at, password_updated_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
fetched = cursor.fetchone()
|
||||
|
||||
if fetched is None:
|
||||
raise DoesNotExist()
|
||||
|
||||
return User(
|
||||
id=fetched[0],
|
||||
username=fetched[1],
|
||||
created_at=fetched[2],
|
||||
updated_at=fetched[3],
|
||||
password_updated_at=fetched[4],
|
||||
)
|
||||
|
||||
|
||||
def validate_password_for_user(user_id: int, raw_password: str) -> bool:
|
||||
"""
|
||||
Validates whether a password is correct for the given user.
|
||||
|
||||
Always returns a boolean representing whether it was a match or not.
|
||||
"""
|
||||
try:
|
||||
with get_connection() as connection, connection.cursor() as cursor:
|
||||
cursor.execute("SELECT password_hash FROM users WHERE id = %s", (user_id,))
|
||||
fetched = cursor.fetchone()
|
||||
|
||||
current_secret_hash = fetched[0]
|
||||
return password_hasher.verify(current_secret_hash, raw_password)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return False
|
|
@ -4,7 +4,8 @@ Rotini: a self-hosted cloud storage & productivity app.
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
import api.files
|
||||
import auth.routes as auth_routes
|
||||
import api.files as files_routes
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
@ -18,7 +19,10 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(api.files.router)
|
||||
routers = [files_routes.router, auth_routes.router]
|
||||
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@app.get("/", status_code=204)
|
||||
|
|
25
backend/rotini/migrations/migration_1_user_table.py
Normal file
25
backend/rotini/migrations/migration_1_user_table.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
"""
|
||||
Generated: 2023-08-19T23:04:28.163820
|
||||
|
||||
Message: None
|
||||
"""
|
||||
UID = "141faa0b-6868-4d07-a24b-b45f98d2809d"
|
||||
|
||||
PARENT = "06f02980-864d-4832-a894-2e9d2543a79a"
|
||||
|
||||
MESSAGE = "Creates the user table."
|
||||
|
||||
UP_SQL = """CREATE TABLE
|
||||
users
|
||||
(
|
||||
id bigserial PRIMARY KEY,
|
||||
username varchar(64) NOT NULL,
|
||||
password_hash varchar(128) NOT NULL,
|
||||
created_at timestamp DEFAULT now(),
|
||||
updated_at timestamp DEFAULT now(),
|
||||
password_updated_at timestamp DEFAULT now(),
|
||||
CONSTRAINT unique_username UNIQUE(username)
|
||||
)
|
||||
"""
|
||||
|
||||
DOWN_SQL = """DROP TABLE users;"""
|
0
backend/rotini/use_cases/__init__.py
Normal file
0
backend/rotini/use_cases/__init__.py
Normal file
4
backend/rotini/use_cases/exceptions.py
Normal file
4
backend/rotini/use_cases/exceptions.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
class DoesNotExist(Exception):
|
||||
"""
|
||||
General purpose exception signalling a failure to find a database record.
|
||||
"""
|
|
@ -11,12 +11,7 @@ import typing_extensions as typing
|
|||
|
||||
from db import get_connection
|
||||
from settings import settings
|
||||
|
||||
|
||||
class DoesNotExist(Exception):
|
||||
"""
|
||||
General purpose exception signalling a failure to find a database record.
|
||||
"""
|
||||
from use_cases.exceptions import DoesNotExist
|
||||
|
||||
|
||||
class FileRecord(typing.TypedDict):
|
||||
|
|
|
@ -18,8 +18,11 @@ def reset_database():
|
|||
"""
|
||||
Empties all user tables between tests.
|
||||
"""
|
||||
tables = ["files", "users"]
|
||||
|
||||
with get_connection() as conn, conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM files;")
|
||||
for table in tables:
|
||||
cursor.execute("DELETE FROM " + table + ";")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -33,3 +36,19 @@ def set_storage_path(tmp_path, monkeypatch):
|
|||
files_dir.mkdir()
|
||||
|
||||
monkeypatch.setattr(settings, "STORAGE_ROOT", str(files_dir))
|
||||
|
||||
|
||||
@pytest.fixture(name="client_log_in")
|
||||
def fixture_client_log_in(client):
|
||||
def _client_log_in(credentials):
|
||||
return client.post("/auth/sessions/", json=credentials)
|
||||
|
||||
return _client_log_in
|
||||
|
||||
|
||||
@pytest.fixture(name="client_create_user")
|
||||
def fixture_client_create_user(client):
|
||||
def _client_create_user(credentials):
|
||||
return client.post("/auth/users/", json=credentials)
|
||||
|
||||
return _client_create_user
|
||||
|
|
90
backend/tests/test_auth_routes.py
Normal file
90
backend/tests/test_auth_routes.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(name="test_user_credentials")
|
||||
def fixture_test_user_creds():
|
||||
"""
|
||||
Test user credentials.
|
||||
"""
|
||||
return {"username": "testuser", "password": "testpassword"}
|
||||
|
||||
|
||||
@pytest.fixture(name="test_user", autouse=True)
|
||||
def fixture_test_user(client_create_user, test_user_credentials):
|
||||
"""
|
||||
Sets up a test user using the `test_user_credentials` data.
|
||||
"""
|
||||
yield client_create_user(test_user_credentials)
|
||||
|
||||
|
||||
def test_create_user_returns_201_on_success(client_create_user):
|
||||
credentials = {"username": "newuser", "password": "test"}
|
||||
response = client_create_user(credentials)
|
||||
|
||||
assert response.status_code == 201
|
||||
|
||||
|
||||
def test_create_user_with_nonunique_username_fails(client_create_user):
|
||||
credentials = {"username": "newuser", "password": "test"}
|
||||
client_create_user(credentials)
|
||||
|
||||
# Recreate the same user, name collision.
|
||||
response = client_create_user(credentials)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"credentials",
|
||||
[
|
||||
pytest.param({"username": "test"}, id="username_only"),
|
||||
pytest.param({"password": "test"}, id="password_only"),
|
||||
pytest.param({}, id="no_data"),
|
||||
],
|
||||
)
|
||||
def test_create_user_requires_username_and_password_supplied(
|
||||
client_create_user, credentials
|
||||
):
|
||||
response = client_create_user(credentials)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_log_in_returns_200_and_user_on_success(client_log_in, test_user_credentials):
|
||||
# The `test_user` fixture creates a user.
|
||||
|
||||
response = client_log_in(test_user_credentials)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
returned = response.json()
|
||||
|
||||
assert returned["username"] == test_user_credentials["username"]
|
||||
|
||||
|
||||
def test_log_in_returns_401_on_wrong_password(client_log_in, test_user_credentials):
|
||||
response = client_log_in(
|
||||
{"username": test_user_credentials["username"], "password": "sillystring"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_log_in_returns_401_on_nonexistent_user(client_log_in):
|
||||
response = client_log_in({"username": "notauser", "password": "sillystring"})
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"credentials",
|
||||
[
|
||||
pytest.param({"username": "test"}, id="username_only"),
|
||||
pytest.param({"password": "test"}, id="password_only"),
|
||||
pytest.param({}, id="no_data"),
|
||||
],
|
||||
)
|
||||
def test_log_in_returns_422_on_invalid_input(client_log_in, credentials):
|
||||
response = client_log_in(credentials)
|
||||
|
||||
assert response.status_code == 422
|
Reference in a new issue