Merge pull request #113 from mcataford/feat/auth-token-refresh
Feat/auth token refresh
This commit is contained in:
commit
2fb7fa3a59
15 changed files with 553 additions and 22 deletions
|
@ -32,7 +32,7 @@ tasks:
|
||||||
test:
|
test:
|
||||||
desc: "Run the test suites."
|
desc: "Run the test suites."
|
||||||
deps: [bootstrap]
|
deps: [bootstrap]
|
||||||
cmd: $SHELL script/test
|
cmd: $SHELL script/test {{ .CLI_ARGS }}
|
||||||
dotenv:
|
dotenv:
|
||||||
- ../backend-test.env
|
- ../backend-test.env
|
||||||
lock-deps:
|
lock-deps:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import datetime
|
||||||
|
|
||||||
import django.contrib.auth
|
import django.contrib.auth
|
||||||
from rest_framework import authentication
|
from rest_framework import authentication
|
||||||
|
@ -15,15 +16,13 @@ class RevokedTokenException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class JwtAuthentication(authentication.BaseAuthentication):
|
class JwtAuthenticationBase(authentication.BaseAuthentication):
|
||||||
"""
|
@property
|
||||||
Authentication class handling JWTs attached to requests via cookies.
|
def allow_expired(self):
|
||||||
|
return False
|
||||||
|
|
||||||
A JWT is only accepted if it's not expired (i.e. can be decoded) and if
|
def _get_decoded_token(self, token: str):
|
||||||
it has not been revoked (as per AuthenticationToken records). A revoked
|
return identity.jwt.decode_token(token, allow_expired=self.allow_expired)
|
||||||
token is declined even if the token itself has not expired yet and would
|
|
||||||
otherwise be valid.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def authenticate(self, request):
|
def authenticate(self, request):
|
||||||
jwt_cookie = request.COOKIES.get("jwt")
|
jwt_cookie = request.COOKIES.get("jwt")
|
||||||
|
@ -33,7 +32,7 @@ class JwtAuthentication(authentication.BaseAuthentication):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoded_token = identity.jwt.decode_token(jwt_cookie)
|
decoded_token = self._get_decoded_token(jwt_cookie)
|
||||||
|
|
||||||
logger.info("Token: %s\nDecoded token: %s", jwt_cookie, decoded_token)
|
logger.info("Token: %s\nDecoded token: %s", jwt_cookie, decoded_token)
|
||||||
|
|
||||||
|
@ -45,9 +44,42 @@ class JwtAuthentication(authentication.BaseAuthentication):
|
||||||
raise RevokedTokenException("Revoked tokens cannot be used")
|
raise RevokedTokenException("Revoked tokens cannot be used")
|
||||||
|
|
||||||
request.session["token_id"] = decoded_token["token_id"]
|
request.session["token_id"] = decoded_token["token_id"]
|
||||||
|
request.session["expired"] = (
|
||||||
|
decoded_token["exp"]
|
||||||
|
< datetime.datetime.now(datetime.timezone.utc).timestamp()
|
||||||
|
)
|
||||||
return user, None
|
return user, None
|
||||||
|
|
||||||
except Exception as e: # pylint: disable=broad-exception-caught
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
logger.exception(e, extra={"authorization_provided": jwt_cookie})
|
logger.exception(e, extra={"authorization_provided": jwt_cookie})
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
class JwtAuthenticationAllowExpired(JwtAuthenticationBase):
|
||||||
|
"""
|
||||||
|
Authentication class handling JWTs attached to requests via cookies.
|
||||||
|
|
||||||
|
A JWT is only accepted if it's not expired (i.e. can be decoded) and if
|
||||||
|
it has not been revoked (as per AuthenticationToken records). A revoked
|
||||||
|
token is declined even if the token itself has not expired yet and would
|
||||||
|
otherwise be valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_expired(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class JwtAuthentication(JwtAuthenticationBase):
|
||||||
|
"""
|
||||||
|
Authentication class handling JWTs attached to requests via cookies.
|
||||||
|
|
||||||
|
A JWT is only accepted if it's not expired (i.e. can be decoded) and if
|
||||||
|
it has not been revoked (as per AuthenticationToken records). A revoked
|
||||||
|
token is declined even if the token itself has not expired yet and would
|
||||||
|
otherwise be valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_expired(self):
|
||||||
|
return False
|
||||||
|
|
|
@ -26,7 +26,7 @@ def generate_token_for_user(user_id: int) -> tuple[str, TokenData]:
|
||||||
|
|
||||||
token_data = {
|
token_data = {
|
||||||
"exp": (
|
"exp": (
|
||||||
datetime.datetime.now()
|
datetime.datetime.now(datetime.timezone.utc)
|
||||||
+ datetime.timedelta(seconds=django.conf.settings.JWT_EXPIRATION)
|
+ datetime.timedelta(seconds=django.conf.settings.JWT_EXPIRATION)
|
||||||
).timestamp(),
|
).timestamp(),
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -41,16 +41,25 @@ def generate_token_for_user(user_id: int) -> tuple[str, TokenData]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_token(
|
def decode_token(token: str, allow_expired: bool = False):
|
||||||
token: str,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Decodes the given token.
|
Decodes the given token.
|
||||||
|
|
||||||
This may raise if the token is expired or invalid.
|
This may raise if the token is expired or invalid.
|
||||||
|
|
||||||
|
If the `allow_expired` flag is truthy, the token is decoded even
|
||||||
|
if the expiration claim is in the past.
|
||||||
"""
|
"""
|
||||||
|
options = {}
|
||||||
|
|
||||||
|
if allow_expired:
|
||||||
|
options["verify_exp"] = False
|
||||||
|
|
||||||
token_data = jwt.decode(
|
token_data = jwt.decode(
|
||||||
token, django.conf.settings.JWT_SIGNING_SECRET, algorithms=["HS256"]
|
token,
|
||||||
|
django.conf.settings.JWT_SIGNING_SECRET,
|
||||||
|
algorithms=["HS256"],
|
||||||
|
options=options,
|
||||||
)
|
)
|
||||||
|
|
||||||
return token_data
|
return token_data
|
||||||
|
|
|
@ -27,3 +27,17 @@ def test_token_decode_fails_if_expired():
|
||||||
|
|
||||||
with pytest.raises(jwt.ExpiredSignatureError):
|
with pytest.raises(jwt.ExpiredSignatureError):
|
||||||
identity.jwt.decode_token(token)
|
identity.jwt.decode_token(token)
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_decode_succeeds_if_expired_and_allow_expired_truthy():
|
||||||
|
MOCK_USER_ID = 1
|
||||||
|
|
||||||
|
with freezegun.freeze_time("2012-01-01"):
|
||||||
|
token, token_data = identity.jwt.generate_token_for_user(MOCK_USER_ID)
|
||||||
|
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
|
decoded_token = identity.jwt.decode_token(token, allow_expired=True)
|
||||||
|
|
||||||
|
assert decoded_token is not None
|
||||||
|
assert decoded_token == token_data
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
from identity.models import AuthenticationToken
|
from identity.models import AuthenticationToken
|
||||||
|
from identity.jwt import decode_token, generate_token_for_user, TokenData
|
||||||
|
|
||||||
|
|
||||||
class UnregisteredTokenException(Exception):
|
class UnregisteredTokenException(Exception):
|
||||||
|
@ -9,6 +12,14 @@ class TokenAlreadyRevokedException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCannotBeRefreshedException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRefreshTokenException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def revoke_token_by_id(token_id: str):
|
def revoke_token_by_id(token_id: str):
|
||||||
"""
|
"""
|
||||||
Revokes a token given its identifier.
|
Revokes a token given its identifier.
|
||||||
|
@ -22,10 +33,48 @@ def revoke_token_by_id(token_id: str):
|
||||||
try:
|
try:
|
||||||
token_record = AuthenticationToken.objects.get(id=token_id)
|
token_record = AuthenticationToken.objects.get(id=token_id)
|
||||||
except AuthenticationToken.DoesNotExist as e:
|
except AuthenticationToken.DoesNotExist as e:
|
||||||
raise UnregisteredTokenException("Token {token_id} not registered.") from e
|
raise UnregisteredTokenException(f"Token {token_id} not registered.") from e
|
||||||
|
|
||||||
if token_record.revoked:
|
if token_record.revoked:
|
||||||
raise TokenAlreadyRevokedException(f"Token {token_id} already revoked.")
|
raise TokenAlreadyRevokedException(f"Token {token_id} already revoked.")
|
||||||
|
|
||||||
token_record.revoked = True
|
token_record.revoked = True
|
||||||
token_record.save()
|
token_record.save()
|
||||||
|
|
||||||
|
|
||||||
|
def renew_token(token: str, refresh_token: str) -> tuple[str, TokenData, str]:
|
||||||
|
"""
|
||||||
|
Given a token (expired or not) and its refresh token, creates a new
|
||||||
|
token for the same user. The old token is invalidated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
token_data = decode_token(token, allow_expired=True)
|
||||||
|
|
||||||
|
token_id = token_data["token_id"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
token_record = AuthenticationToken.objects.get(id=token_id)
|
||||||
|
except AuthenticationToken.DoesNotExist as e:
|
||||||
|
raise UnregisteredTokenException(f"Token {token_id} not registered.") from e
|
||||||
|
|
||||||
|
if token_record.revoked:
|
||||||
|
raise TokenCannotBeRefreshedException(
|
||||||
|
f"Token {token_id} is revoked and cannot be refreshed."
|
||||||
|
)
|
||||||
|
|
||||||
|
if refresh_token != str(token_record.refresh_token):
|
||||||
|
raise InvalidRefreshTokenException(
|
||||||
|
f"Refresh token {refresh_token} does not match records for {token_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
token_record.revoked = True
|
||||||
|
token_record.save()
|
||||||
|
|
||||||
|
new_token, token_data = generate_token_for_user(user_id=token_record.user_id)
|
||||||
|
new_token_record = AuthenticationToken.objects.create(
|
||||||
|
id=token_data["token_id"],
|
||||||
|
user_id=token_data["user_id"],
|
||||||
|
expires_at=datetime.datetime.fromtimestamp(token_data["exp"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_token, token_data, str(new_token_record.refresh_token)
|
||||||
|
|
111
backend/rotini/identity/token_management_test.py
Normal file
111
backend/rotini/identity/token_management_test.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
import uuid
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from identity.token_management import (
|
||||||
|
revoke_token_by_id,
|
||||||
|
renew_token,
|
||||||
|
UnregisteredTokenException,
|
||||||
|
TokenAlreadyRevokedException,
|
||||||
|
TokenCannotBeRefreshedException,
|
||||||
|
InvalidRefreshTokenException,
|
||||||
|
)
|
||||||
|
from identity.models import AuthenticationToken
|
||||||
|
from identity.jwt import generate_token_for_user
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.djangodb
|
||||||
|
|
||||||
|
|
||||||
|
def test_revoke_token_by_id_fails_if_token_not_on_record():
|
||||||
|
with pytest.raises(UnregisteredTokenException):
|
||||||
|
revoke_token_by_id(str(uuid.uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
def test_revoke_token_by_id_fails_if_token_already_revoked(test_user):
|
||||||
|
token_record = AuthenticationToken.objects.create(
|
||||||
|
user_id=test_user.id,
|
||||||
|
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
revoked=True,
|
||||||
|
)
|
||||||
|
with pytest.raises(TokenAlreadyRevokedException):
|
||||||
|
revoke_token_by_id(str(token_record.id))
|
||||||
|
|
||||||
|
|
||||||
|
def test_revoke_token_by_id_sets_token_as_revoked(test_user):
|
||||||
|
token_record = AuthenticationToken.objects.create(
|
||||||
|
user_id=test_user.id, expires_at=datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
)
|
||||||
|
revoke_token_by_id(str(token_record.id))
|
||||||
|
|
||||||
|
token_record.refresh_from_db()
|
||||||
|
|
||||||
|
assert token_record.revoked
|
||||||
|
|
||||||
|
|
||||||
|
def test_renew_token_fails_if_not_on_record(test_user):
|
||||||
|
token, _ = generate_token_for_user(user_id=test_user.id)
|
||||||
|
with pytest.raises(UnregisteredTokenException):
|
||||||
|
renew_token(token, str(uuid.uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
def test_renew_token_fails_if_revoked(test_user):
|
||||||
|
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||||
|
|
||||||
|
token_record = AuthenticationToken.objects.create(
|
||||||
|
id=token_data["token_id"],
|
||||||
|
user_id=test_user.id,
|
||||||
|
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
revoked=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TokenCannotBeRefreshedException):
|
||||||
|
renew_token(token, str(token_record.refresh_token))
|
||||||
|
|
||||||
|
|
||||||
|
def test_renew_token_fails_if_refresh_token_wrong(test_user):
|
||||||
|
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||||
|
AuthenticationToken.objects.create(
|
||||||
|
id=token_data["token_id"],
|
||||||
|
user_id=test_user.id,
|
||||||
|
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidRefreshTokenException):
|
||||||
|
renew_token(token, str(uuid.uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
def test_renew_token_generates_a_new_token_for_the_same_user(test_user):
|
||||||
|
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||||
|
|
||||||
|
token_record = AuthenticationToken.objects.create(
|
||||||
|
id=token_data["token_id"],
|
||||||
|
user_id=test_user.id,
|
||||||
|
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, token_data, refresh_token = renew_token(token, str(token_record.refresh_token))
|
||||||
|
|
||||||
|
new_token_record = AuthenticationToken.objects.get(id=token_data["token_id"])
|
||||||
|
|
||||||
|
assert new_token_record.user_id == token_record.user_id
|
||||||
|
assert str(new_token_record.refresh_token) == refresh_token
|
||||||
|
|
||||||
|
|
||||||
|
def test_renew_token_revokes_previous_token(test_user):
|
||||||
|
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||||
|
|
||||||
|
token_record = AuthenticationToken.objects.create(
|
||||||
|
id=token_data["token_id"],
|
||||||
|
user_id=test_user.id,
|
||||||
|
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, token_data, _ = renew_token(token, str(token_record.refresh_token))
|
||||||
|
|
||||||
|
new_token_record = AuthenticationToken.objects.get(id=token_data["token_id"])
|
||||||
|
|
||||||
|
token_record.refresh_from_db()
|
||||||
|
|
||||||
|
assert token_record.revoked
|
||||||
|
assert new_token_record.revoked is False
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
from django.http import HttpResponse, JsonResponse, HttpRequest
|
from django.http import HttpResponse, JsonResponse, HttpRequest
|
||||||
import django.contrib.auth
|
import django.contrib.auth
|
||||||
|
@ -8,8 +9,9 @@ import rest_framework.status
|
||||||
|
|
||||||
import identity.jwt
|
import identity.jwt
|
||||||
from identity.models import AuthenticationToken
|
from identity.models import AuthenticationToken
|
||||||
from identity.token_management import revoke_token_by_id
|
from identity.token_management import revoke_token_by_id, renew_token
|
||||||
from identity.serializers import UserSerializer
|
from identity.serializers import UserSerializer
|
||||||
|
from identity.authentication_classes import JwtAuthenticationAllowExpired
|
||||||
|
|
||||||
AuthUser = django.contrib.auth.get_user_model()
|
AuthUser = django.contrib.auth.get_user_model()
|
||||||
|
|
||||||
|
@ -21,6 +23,8 @@ class SessionListView(rest_framework.views.APIView):
|
||||||
Views handling authenticated user sessions.
|
Views handling authenticated user sessions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
authentication_classes = [JwtAuthenticationAllowExpired]
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
def post(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""
|
"""
|
||||||
Handles signing in for existing users.
|
Handles signing in for existing users.
|
||||||
|
@ -69,6 +73,35 @@ class SessionListView(rest_framework.views.APIView):
|
||||||
|
|
||||||
return HttpResponse(status=401)
|
return HttpResponse(status=401)
|
||||||
|
|
||||||
|
def put(self, request: HttpRequest) -> HttpResponse:
|
||||||
|
"""
|
||||||
|
Refreshes a session using a refresh token (provided via body).
|
||||||
|
|
||||||
|
On success, returns a new authentication token via cookie and
|
||||||
|
a new refresh token via response body.
|
||||||
|
|
||||||
|
The previous auth+refresh token pair is invalidated and cannot be reused.
|
||||||
|
"""
|
||||||
|
|
||||||
|
request_body = json.loads(request.body.decode("utf-8"))
|
||||||
|
|
||||||
|
current_auth_token = request.COOKIES.get("jwt", None)
|
||||||
|
current_refresh_token = request_body.get("refresh_token", None)
|
||||||
|
if not current_auth_token or not current_refresh_token:
|
||||||
|
return HttpResponse(status=400)
|
||||||
|
|
||||||
|
new_token, _, new_refresh_token = renew_token(
|
||||||
|
current_auth_token, current_refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
|
response = JsonResponse({"refresh_token": new_refresh_token}, status=201)
|
||||||
|
|
||||||
|
response.set_cookie(
|
||||||
|
"jwt", value=new_token, secure=True, domain="localhost", httponly=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
def delete(self, request: HttpRequest) -> HttpResponse:
|
def delete(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""
|
"""
|
||||||
Logs out the requesting user.
|
Logs out the requesting user.
|
||||||
|
@ -79,6 +112,9 @@ class SessionListView(rest_framework.views.APIView):
|
||||||
|
|
||||||
current_token_id = request.session.get("token_id", None)
|
current_token_id = request.session.get("token_id", None)
|
||||||
|
|
||||||
|
if request.session.get("expired", False):
|
||||||
|
return HttpResponse(status=403)
|
||||||
|
|
||||||
if current_token_id is None:
|
if current_token_id is None:
|
||||||
return HttpResponse(status=400)
|
return HttpResponse(status=400)
|
||||||
|
|
||||||
|
@ -87,6 +123,26 @@ class SessionListView(rest_framework.views.APIView):
|
||||||
|
|
||||||
return HttpResponse(status=204)
|
return HttpResponse(status=204)
|
||||||
|
|
||||||
|
def get(self, request: HttpRequest) -> HttpResponse:
|
||||||
|
"""
|
||||||
|
Verifies if the current session is still valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
current_token_id = request.session.get("token_id", None)
|
||||||
|
|
||||||
|
token_record = AuthenticationToken.objects.get(id=current_token_id)
|
||||||
|
|
||||||
|
if (
|
||||||
|
token_record.expires_at + datetime.timedelta(minutes=5)
|
||||||
|
) < datetime.datetime.now(datetime.timezone.utc):
|
||||||
|
return HttpResponse(status=401)
|
||||||
|
|
||||||
|
should_refresh = datetime.datetime.now(
|
||||||
|
datetime.timezone.utc
|
||||||
|
) > token_record.expires_at - datetime.timedelta(minutes=8)
|
||||||
|
|
||||||
|
return JsonResponse({"should_refresh": should_refresh}, status=200)
|
||||||
|
|
||||||
|
|
||||||
class UserListView(rest_framework.views.APIView):
|
class UserListView(rest_framework.views.APIView):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import typing
|
import typing
|
||||||
|
import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import freezegun
|
||||||
|
|
||||||
import django.urls
|
import django.urls
|
||||||
import django.contrib.auth
|
import django.contrib.auth
|
||||||
|
@ -44,6 +46,29 @@ def fixture_logout_request(auth_client):
|
||||||
return _logout_request
|
return _logout_request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="refresh_session_request")
|
||||||
|
def fixture_refresh_session_request(auth_client):
|
||||||
|
def _refresh_session_request(refresh_token: typing.Optional[str] = None):
|
||||||
|
data = {"refresh_token": refresh_token} if refresh_token else {}
|
||||||
|
return auth_client.put(
|
||||||
|
django.urls.reverse("auth-session-list"),
|
||||||
|
json.dumps(data),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
return _refresh_session_request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="get_session_request")
|
||||||
|
def fixture_get_session_request(auth_client):
|
||||||
|
def _get_session_request():
|
||||||
|
return auth_client.get(
|
||||||
|
django.urls.reverse("auth-session-list"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return _get_session_request
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="get_current_user_request")
|
@pytest.fixture(name="get_current_user_request")
|
||||||
def fixture_get_current_user(auth_client):
|
def fixture_get_current_user(auth_client):
|
||||||
def _get_current_user_request(client: typing.Optional[Client] = None):
|
def _get_current_user_request(client: typing.Optional[Client] = None):
|
||||||
|
@ -130,3 +155,67 @@ def test_get_current_user_returns_403_if_unauthenticated(
|
||||||
response = get_current_user_request(no_auth_client)
|
response = get_current_user_request(no_auth_client)
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_returns_401_if_token_expired_past_threshold(
|
||||||
|
test_user_credentials, login_request, get_session_request
|
||||||
|
):
|
||||||
|
with freezegun.freeze_time("2012-01-01"):
|
||||||
|
login_request(
|
||||||
|
test_user_credentials["username"], test_user_credentials["password"]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = get_session_request()
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_returns_whether_token_should_be_refreshed(
|
||||||
|
login_request, get_session_request, test_user_credentials
|
||||||
|
):
|
||||||
|
login_request(test_user_credentials["username"], test_user_credentials["password"])
|
||||||
|
|
||||||
|
response = get_session_request()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"should_refresh": False}
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_refresh_rejects_with_400_if_no_refresh_token(
|
||||||
|
login_request,
|
||||||
|
refresh_session_request,
|
||||||
|
test_user_credentials,
|
||||||
|
):
|
||||||
|
login_request(test_user_credentials["username"], test_user_credentials["password"])
|
||||||
|
|
||||||
|
response = refresh_session_request()
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_refresh_rejects_with_400_if_not_authenticated(refresh_session_request):
|
||||||
|
response = refresh_session_request()
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_refresh_attaches_new_token_to_response(
|
||||||
|
login_request,
|
||||||
|
refresh_session_request,
|
||||||
|
test_user_credentials,
|
||||||
|
):
|
||||||
|
login_response = login_request(
|
||||||
|
test_user_credentials["username"], test_user_credentials["password"]
|
||||||
|
)
|
||||||
|
|
||||||
|
login_response_data = login_response.json()
|
||||||
|
|
||||||
|
refresh_response = refresh_session_request(login_response_data["refresh_token"])
|
||||||
|
|
||||||
|
refresh_response_data = refresh_response.json()
|
||||||
|
|
||||||
|
assert refresh_response.status_code == 201
|
||||||
|
assert (
|
||||||
|
refresh_response_data["refresh_token"] != login_response_data["refresh_token"]
|
||||||
|
)
|
||||||
|
assert refresh_response.cookies["jwt"].value != login_response.cookies["jwt"].value
|
||||||
|
|
|
@ -43,6 +43,6 @@ done;
|
||||||
sleep $HEALTHCHECK_SLEEP
|
sleep $HEALTHCHECK_SLEEP
|
||||||
|
|
||||||
#ROTINI_TEST=1 PYTHONPATH=rotini $VENV_PYTHON rotini/migrations/migrate.py up || fail "Migrations failed."
|
#ROTINI_TEST=1 PYTHONPATH=rotini $VENV_PYTHON rotini/migrations/migrate.py up || fail "Migrations failed."
|
||||||
$VENV_PYTEST . -vv -s || fail "Test run failed."
|
$VENV_PYTEST . -vv -s $@ || fail "Test run failed."
|
||||||
|
|
||||||
cleanup
|
cleanup
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"build": "vite build ./src --config ./vite.config.js",
|
"build": "vite build ./src --config ./vite.config.js",
|
||||||
"lint": "biome check src *.js --verbose && biome format src *.js --verbose",
|
"lint": "biome check src *.js --verbose && biome format src *.js --verbose",
|
||||||
"lint:fix": "biome check src ./*.js --apply --verbose && biome format src ./*.js --write --verbose",
|
"lint:fix": "biome check src ./*.js --apply --verbose && biome format src ./*.js --write --verbose",
|
||||||
"test": "yarn vitest run",
|
"test": "yarn vitest --watch=false",
|
||||||
"typecheck": "yarn tsc --noEmit"
|
"typecheck": "yarn tsc --noEmit"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import React from "react"
|
||||||
|
|
||||||
import { Box } from "@mui/material"
|
import { Box } from "@mui/material"
|
||||||
import {
|
import {
|
||||||
QueryClient,
|
QueryClient,
|
||||||
|
@ -11,6 +13,8 @@ import LocationContext from "@/contexts/LocationContext"
|
||||||
|
|
||||||
import { Router, Route } from "@/router"
|
import { Router, Route } from "@/router"
|
||||||
|
|
||||||
|
import setupAuthTokenAutoRefresh from "@/authRefresh"
|
||||||
|
|
||||||
import FileListView from "@/components/FileListView"
|
import FileListView from "@/components/FileListView"
|
||||||
import RegisterView from "@/components/RegisterView"
|
import RegisterView from "@/components/RegisterView"
|
||||||
import LoginView from "@/components/LoginView"
|
import LoginView from "@/components/LoginView"
|
||||||
|
@ -25,6 +29,14 @@ const routes = {
|
||||||
}
|
}
|
||||||
|
|
||||||
const App = () => {
|
const App = () => {
|
||||||
|
React.useEffect(() => {
|
||||||
|
const stopAutoRefresh = setupAuthTokenAutoRefresh()
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
stopAutoRefresh?.()
|
||||||
|
}
|
||||||
|
}, [])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box sx={{ display: "flex", flexDirection: "column", width: "100%" }}>
|
<Box sx={{ display: "flex", flexDirection: "column", width: "100%" }}>
|
||||||
<NavigationBar />
|
<NavigationBar />
|
||||||
|
|
84
frontend/src/authRefresh.test.ts
Normal file
84
frontend/src/authRefresh.test.ts
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import { describe, it, vi, expect, beforeEach, afterEach } from "vitest"
|
||||||
|
import { waitFor } from "@testing-library/react"
|
||||||
|
import { getAxiosMockAdapter } from "@/tests/helpers"
|
||||||
|
import setupAuthTokenAutoRefresh, {
|
||||||
|
REFRESH_INTERVAL,
|
||||||
|
REFRESH_TOKEN_KEY,
|
||||||
|
} from "./authRefresh"
|
||||||
|
|
||||||
|
async function flushPromises() {
|
||||||
|
return new Promise((r) => setTimeout(r))
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("setupAuthTokenAutoRefresh", () => {
|
||||||
|
let clearInterval: (() => void) | undefined = undefined
|
||||||
|
const mockLocalStorage = new Map<string, string>()
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.spyOn(Storage.prototype, "setItem").mockImplementation(
|
||||||
|
(key: string, value: string) => {
|
||||||
|
mockLocalStorage.set(key, value)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
vi.spyOn(Storage.prototype, "getItem").mockImplementation(
|
||||||
|
(key: string): string => mockLocalStorage.get(key) || "",
|
||||||
|
)
|
||||||
|
|
||||||
|
vi.useFakeTimers()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.restoreAllMocks()
|
||||||
|
globalThis.localStorage.clear()
|
||||||
|
|
||||||
|
if (clearInterval) clearInterval()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("does not do anything if loop already initialized", () => {
|
||||||
|
const clear = setupAuthTokenAutoRefresh()
|
||||||
|
|
||||||
|
expect(clear).not.toBeUndefined()
|
||||||
|
const secondClear = setupAuthTokenAutoRefresh()
|
||||||
|
|
||||||
|
expect(secondClear).toBeUndefined()
|
||||||
|
|
||||||
|
clear?.()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("checks the status of the auth token at each interval", async () => {
|
||||||
|
const axiosMock = getAxiosMockAdapter()
|
||||||
|
|
||||||
|
globalThis.localStorage.setItem(REFRESH_TOKEN_KEY, "refresh_token")
|
||||||
|
|
||||||
|
axiosMock.onGet("/auth/session/").reply(200, { should_refresh: false })
|
||||||
|
clearInterval = setupAuthTokenAutoRefresh()
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(REFRESH_INTERVAL)
|
||||||
|
|
||||||
|
expect(axiosMock.history.get.length).toEqual(1)
|
||||||
|
|
||||||
|
const apiCall = axiosMock.history.get[0]
|
||||||
|
|
||||||
|
expect(apiCall.url).toEqual("/auth/session/")
|
||||||
|
})
|
||||||
|
|
||||||
|
// FIXME: LocalStorage does not update, but the request is run as expected.
|
||||||
|
it.skip("attempts to refresh the token if the status checks says the token must be refreshed", async () => {
|
||||||
|
const axiosMock = getAxiosMockAdapter()
|
||||||
|
|
||||||
|
globalThis.localStorage.setItem(REFRESH_TOKEN_KEY, "refresh_token")
|
||||||
|
|
||||||
|
axiosMock
|
||||||
|
.onGet("/auth/session/")
|
||||||
|
.reply(200, { should_refresh: true })
|
||||||
|
.onPut("/auth/session/")
|
||||||
|
.reply(201, { refresh_token: "new_refresh_token" })
|
||||||
|
|
||||||
|
clearInterval = setupAuthTokenAutoRefresh()
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(REFRESH_INTERVAL)
|
||||||
|
|
||||||
|
const localStorageToken = globalThis.localStorage.getItem(REFRESH_TOKEN_KEY)
|
||||||
|
|
||||||
|
expect(localStorageToken).toEqual("new_refresh_token")
|
||||||
|
})
|
||||||
|
})
|
70
frontend/src/authRefresh.ts
Normal file
70
frontend/src/authRefresh.ts
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
import axiosWithDefaults from "@/axios"
|
||||||
|
|
||||||
|
const REFRESH_TOKEN_KEY = "jwt_refresh_token"
|
||||||
|
const REFRESH_INTERVAL = 30000
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Singleton tracking the token refresh loop.
|
||||||
|
*
|
||||||
|
* This ensures that token autorefresh is only initialized
|
||||||
|
* once.
|
||||||
|
*/
|
||||||
|
let tokenRefreshLoop: ReturnType<typeof setInterval> | null = null
|
||||||
|
|
||||||
|
function getRefreshToken() {
|
||||||
|
return globalThis.localStorage.getItem(REFRESH_TOKEN_KEY)
|
||||||
|
}
|
||||||
|
|
||||||
|
function setRefreshToken(token: string) {
|
||||||
|
globalThis.localStorage.setItem(REFRESH_TOKEN_KEY, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
function unsetRefreshToken() {
|
||||||
|
globalThis.localStorage.removeItem(REFRESH_TOKEN_KEY)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Periodically verifies the status of the authentication token
|
||||||
|
* expiration and initiates token refresh based on feedback
|
||||||
|
* from the server.
|
||||||
|
*
|
||||||
|
* If the token is in need of refresh, a request is made to
|
||||||
|
* generate a new token that will be attached to the response
|
||||||
|
* as a cookie, and the refresh token stored in localStorage
|
||||||
|
* gets refreshed.
|
||||||
|
*/
|
||||||
|
function setupAuthTokenAutoRefresh() {
|
||||||
|
if (tokenRefreshLoop) {
|
||||||
|
console.warn("Authentication token refresh loop already initialized.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenRefreshLoop = setInterval(async () => {
|
||||||
|
const fetchResponse = await axiosWithDefaults.get("/auth/session/")
|
||||||
|
|
||||||
|
const shouldRefresh = fetchResponse.data.should_refresh
|
||||||
|
|
||||||
|
if (!shouldRefresh) return
|
||||||
|
|
||||||
|
const refreshResponse = await axiosWithDefaults.put("/auth/session/", {
|
||||||
|
refresh_token: getRefreshToken(),
|
||||||
|
})
|
||||||
|
|
||||||
|
setRefreshToken(refreshResponse.data.refresh_token)
|
||||||
|
}, REFRESH_INTERVAL)
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
if (tokenRefreshLoop) clearInterval(tokenRefreshLoop)
|
||||||
|
tokenRefreshLoop = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export {
|
||||||
|
setRefreshToken,
|
||||||
|
getRefreshToken,
|
||||||
|
unsetRefreshToken,
|
||||||
|
REFRESH_INTERVAL,
|
||||||
|
REFRESH_TOKEN_KEY,
|
||||||
|
}
|
||||||
|
|
||||||
|
export default setupAuthTokenAutoRefresh
|
|
@ -129,7 +129,9 @@ describe("LoginView", () => {
|
||||||
}))
|
}))
|
||||||
const axiosMockAdapter = new AxiosMockAdapter(axios)
|
const axiosMockAdapter = new AxiosMockAdapter(axios)
|
||||||
|
|
||||||
axiosMockAdapter.onPost("/auth/session/").reply(201)
|
axiosMockAdapter
|
||||||
|
.onPost("/auth/session/")
|
||||||
|
.reply(201, { refresh_token: "notatoken" })
|
||||||
|
|
||||||
const { user } = renderComponent()
|
const { user } = renderComponent()
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
*/
|
*/
|
||||||
import { useQueryClient, useMutation } from "@tanstack/react-query"
|
import { useQueryClient, useMutation } from "@tanstack/react-query"
|
||||||
|
|
||||||
|
import { setRefreshToken, unsetRefreshToken } from "@/authRefresh"
|
||||||
import { useLocationContext } from "@/contexts/LocationContext"
|
import { useLocationContext } from "@/contexts/LocationContext"
|
||||||
import axiosWithDefaults from "@/axios"
|
import axiosWithDefaults from "@/axios"
|
||||||
|
|
||||||
|
@ -29,6 +30,7 @@ function useLogout() {
|
||||||
},
|
},
|
||||||
onSuccess: async () => {
|
onSuccess: async () => {
|
||||||
await queryClient.invalidateQueries({ queryKey: ["current-user"] })
|
await queryClient.invalidateQueries({ queryKey: ["current-user"] })
|
||||||
|
unsetRefreshToken()
|
||||||
navigate("/login")
|
navigate("/login")
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -61,7 +63,8 @@ function useLogin() {
|
||||||
password,
|
password,
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
onSuccess: async () => {
|
onSuccess: async (response) => {
|
||||||
|
setRefreshToken(response.data.refresh_token)
|
||||||
await queryClient.refetchQueries({ queryKey: ["current-user"] })
|
await queryClient.refetchQueries({ queryKey: ["current-user"] })
|
||||||
navigate("/")
|
navigate("/")
|
||||||
},
|
},
|
||||||
|
|
Reference in a new issue