Merge pull request #113 from mcataford/feat/auth-token-refresh
Feat/auth token refresh
This commit is contained in:
commit
2fb7fa3a59
15 changed files with 553 additions and 22 deletions
|
@ -32,7 +32,7 @@ tasks:
|
|||
test:
|
||||
desc: "Run the test suites."
|
||||
deps: [bootstrap]
|
||||
cmd: $SHELL script/test
|
||||
cmd: $SHELL script/test {{ .CLI_ARGS }}
|
||||
dotenv:
|
||||
- ../backend-test.env
|
||||
lock-deps:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import datetime
|
||||
|
||||
import django.contrib.auth
|
||||
from rest_framework import authentication
|
||||
|
@ -15,15 +16,13 @@ class RevokedTokenException(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class JwtAuthentication(authentication.BaseAuthentication):
|
||||
"""
|
||||
Authentication class handling JWTs attached to requests via cookies.
|
||||
class JwtAuthenticationBase(authentication.BaseAuthentication):
|
||||
@property
|
||||
def allow_expired(self):
|
||||
return False
|
||||
|
||||
A JWT is only accepted if it's not expired (i.e. can be decoded) and if
|
||||
it has not been revoked (as per AuthenticationToken records). A revoked
|
||||
token is declined even if the token itself has not expired yet and would
|
||||
otherwise be valid.
|
||||
"""
|
||||
def _get_decoded_token(self, token: str):
|
||||
return identity.jwt.decode_token(token, allow_expired=self.allow_expired)
|
||||
|
||||
def authenticate(self, request):
|
||||
jwt_cookie = request.COOKIES.get("jwt")
|
||||
|
@ -33,7 +32,7 @@ class JwtAuthentication(authentication.BaseAuthentication):
|
|||
return None
|
||||
|
||||
try:
|
||||
decoded_token = identity.jwt.decode_token(jwt_cookie)
|
||||
decoded_token = self._get_decoded_token(jwt_cookie)
|
||||
|
||||
logger.info("Token: %s\nDecoded token: %s", jwt_cookie, decoded_token)
|
||||
|
||||
|
@ -45,9 +44,42 @@ class JwtAuthentication(authentication.BaseAuthentication):
|
|||
raise RevokedTokenException("Revoked tokens cannot be used")
|
||||
|
||||
request.session["token_id"] = decoded_token["token_id"]
|
||||
|
||||
request.session["expired"] = (
|
||||
decoded_token["exp"]
|
||||
< datetime.datetime.now(datetime.timezone.utc).timestamp()
|
||||
)
|
||||
return user, None
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.exception(e, extra={"authorization_provided": jwt_cookie})
|
||||
return None, None
|
||||
|
||||
|
||||
class JwtAuthenticationAllowExpired(JwtAuthenticationBase):
|
||||
"""
|
||||
Authentication class handling JWTs attached to requests via cookies.
|
||||
|
||||
A JWT is only accepted if it's not expired (i.e. can be decoded) and if
|
||||
it has not been revoked (as per AuthenticationToken records). A revoked
|
||||
token is declined even if the token itself has not expired yet and would
|
||||
otherwise be valid.
|
||||
"""
|
||||
|
||||
@property
|
||||
def allow_expired(self):
|
||||
return True
|
||||
|
||||
|
||||
class JwtAuthentication(JwtAuthenticationBase):
|
||||
"""
|
||||
Authentication class handling JWTs attached to requests via cookies.
|
||||
|
||||
A JWT is only accepted if it's not expired (i.e. can be decoded) and if
|
||||
it has not been revoked (as per AuthenticationToken records). A revoked
|
||||
token is declined even if the token itself has not expired yet and would
|
||||
otherwise be valid.
|
||||
"""
|
||||
|
||||
@property
|
||||
def allow_expired(self):
|
||||
return False
|
||||
|
|
|
@ -26,7 +26,7 @@ def generate_token_for_user(user_id: int) -> tuple[str, TokenData]:
|
|||
|
||||
token_data = {
|
||||
"exp": (
|
||||
datetime.datetime.now()
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=django.conf.settings.JWT_EXPIRATION)
|
||||
).timestamp(),
|
||||
"user_id": user_id,
|
||||
|
@ -41,16 +41,25 @@ def generate_token_for_user(user_id: int) -> tuple[str, TokenData]:
|
|||
)
|
||||
|
||||
|
||||
def decode_token(
|
||||
token: str,
|
||||
):
|
||||
def decode_token(token: str, allow_expired: bool = False):
|
||||
"""
|
||||
Decodes the given token.
|
||||
|
||||
This may raise if the token is expired or invalid.
|
||||
|
||||
If the `allow_expired` flag is truthy, the token is decoded even
|
||||
if the expiration claim is in the past.
|
||||
"""
|
||||
options = {}
|
||||
|
||||
if allow_expired:
|
||||
options["verify_exp"] = False
|
||||
|
||||
token_data = jwt.decode(
|
||||
token, django.conf.settings.JWT_SIGNING_SECRET, algorithms=["HS256"]
|
||||
token,
|
||||
django.conf.settings.JWT_SIGNING_SECRET,
|
||||
algorithms=["HS256"],
|
||||
options=options,
|
||||
)
|
||||
|
||||
return token_data
|
||||
|
|
|
@ -27,3 +27,17 @@ def test_token_decode_fails_if_expired():
|
|||
|
||||
with pytest.raises(jwt.ExpiredSignatureError):
|
||||
identity.jwt.decode_token(token)
|
||||
|
||||
|
||||
def test_token_decode_succeeds_if_expired_and_allow_expired_truthy():
|
||||
MOCK_USER_ID = 1
|
||||
|
||||
with freezegun.freeze_time("2012-01-01"):
|
||||
token, token_data = identity.jwt.generate_token_for_user(MOCK_USER_ID)
|
||||
|
||||
assert token is not None
|
||||
|
||||
decoded_token = identity.jwt.decode_token(token, allow_expired=True)
|
||||
|
||||
assert decoded_token is not None
|
||||
assert decoded_token == token_data
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import datetime
|
||||
|
||||
from identity.models import AuthenticationToken
|
||||
from identity.jwt import decode_token, generate_token_for_user, TokenData
|
||||
|
||||
|
||||
class UnregisteredTokenException(Exception):
|
||||
|
@ -9,6 +12,14 @@ class TokenAlreadyRevokedException(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class TokenCannotBeRefreshedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidRefreshTokenException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def revoke_token_by_id(token_id: str):
|
||||
"""
|
||||
Revokes a token given its identifier.
|
||||
|
@ -22,10 +33,48 @@ def revoke_token_by_id(token_id: str):
|
|||
try:
|
||||
token_record = AuthenticationToken.objects.get(id=token_id)
|
||||
except AuthenticationToken.DoesNotExist as e:
|
||||
raise UnregisteredTokenException("Token {token_id} not registered.") from e
|
||||
raise UnregisteredTokenException(f"Token {token_id} not registered.") from e
|
||||
|
||||
if token_record.revoked:
|
||||
raise TokenAlreadyRevokedException(f"Token {token_id} already revoked.")
|
||||
|
||||
token_record.revoked = True
|
||||
token_record.save()
|
||||
|
||||
|
||||
def renew_token(token: str, refresh_token: str) -> tuple[str, TokenData, str]:
|
||||
"""
|
||||
Given a token (expired or not) and its refresh token, creates a new
|
||||
token for the same user. The old token is invalidated.
|
||||
"""
|
||||
|
||||
token_data = decode_token(token, allow_expired=True)
|
||||
|
||||
token_id = token_data["token_id"]
|
||||
|
||||
try:
|
||||
token_record = AuthenticationToken.objects.get(id=token_id)
|
||||
except AuthenticationToken.DoesNotExist as e:
|
||||
raise UnregisteredTokenException(f"Token {token_id} not registered.") from e
|
||||
|
||||
if token_record.revoked:
|
||||
raise TokenCannotBeRefreshedException(
|
||||
f"Token {token_id} is revoked and cannot be refreshed."
|
||||
)
|
||||
|
||||
if refresh_token != str(token_record.refresh_token):
|
||||
raise InvalidRefreshTokenException(
|
||||
f"Refresh token {refresh_token} does not match records for {token_id}"
|
||||
)
|
||||
|
||||
token_record.revoked = True
|
||||
token_record.save()
|
||||
|
||||
new_token, token_data = generate_token_for_user(user_id=token_record.user_id)
|
||||
new_token_record = AuthenticationToken.objects.create(
|
||||
id=token_data["token_id"],
|
||||
user_id=token_data["user_id"],
|
||||
expires_at=datetime.datetime.fromtimestamp(token_data["exp"]),
|
||||
)
|
||||
|
||||
return new_token, token_data, str(new_token_record.refresh_token)
|
||||
|
|
111
backend/rotini/identity/token_management_test.py
Normal file
111
backend/rotini/identity/token_management_test.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
import uuid
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from identity.token_management import (
|
||||
revoke_token_by_id,
|
||||
renew_token,
|
||||
UnregisteredTokenException,
|
||||
TokenAlreadyRevokedException,
|
||||
TokenCannotBeRefreshedException,
|
||||
InvalidRefreshTokenException,
|
||||
)
|
||||
from identity.models import AuthenticationToken
|
||||
from identity.jwt import generate_token_for_user
|
||||
|
||||
pytestmark = pytest.mark.djangodb
|
||||
|
||||
|
||||
def test_revoke_token_by_id_fails_if_token_not_on_record():
|
||||
with pytest.raises(UnregisteredTokenException):
|
||||
revoke_token_by_id(str(uuid.uuid4()))
|
||||
|
||||
|
||||
def test_revoke_token_by_id_fails_if_token_already_revoked(test_user):
|
||||
token_record = AuthenticationToken.objects.create(
|
||||
user_id=test_user.id,
|
||||
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||
revoked=True,
|
||||
)
|
||||
with pytest.raises(TokenAlreadyRevokedException):
|
||||
revoke_token_by_id(str(token_record.id))
|
||||
|
||||
|
||||
def test_revoke_token_by_id_sets_token_as_revoked(test_user):
|
||||
token_record = AuthenticationToken.objects.create(
|
||||
user_id=test_user.id, expires_at=datetime.datetime.now(datetime.timezone.utc)
|
||||
)
|
||||
revoke_token_by_id(str(token_record.id))
|
||||
|
||||
token_record.refresh_from_db()
|
||||
|
||||
assert token_record.revoked
|
||||
|
||||
|
||||
def test_renew_token_fails_if_not_on_record(test_user):
|
||||
token, _ = generate_token_for_user(user_id=test_user.id)
|
||||
with pytest.raises(UnregisteredTokenException):
|
||||
renew_token(token, str(uuid.uuid4()))
|
||||
|
||||
|
||||
def test_renew_token_fails_if_revoked(test_user):
|
||||
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||
|
||||
token_record = AuthenticationToken.objects.create(
|
||||
id=token_data["token_id"],
|
||||
user_id=test_user.id,
|
||||
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||
revoked=True,
|
||||
)
|
||||
|
||||
with pytest.raises(TokenCannotBeRefreshedException):
|
||||
renew_token(token, str(token_record.refresh_token))
|
||||
|
||||
|
||||
def test_renew_token_fails_if_refresh_token_wrong(test_user):
|
||||
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||
AuthenticationToken.objects.create(
|
||||
id=token_data["token_id"],
|
||||
user_id=test_user.id,
|
||||
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidRefreshTokenException):
|
||||
renew_token(token, str(uuid.uuid4()))
|
||||
|
||||
|
||||
def test_renew_token_generates_a_new_token_for_the_same_user(test_user):
|
||||
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||
|
||||
token_record = AuthenticationToken.objects.create(
|
||||
id=token_data["token_id"],
|
||||
user_id=test_user.id,
|
||||
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||
)
|
||||
|
||||
_, token_data, refresh_token = renew_token(token, str(token_record.refresh_token))
|
||||
|
||||
new_token_record = AuthenticationToken.objects.get(id=token_data["token_id"])
|
||||
|
||||
assert new_token_record.user_id == token_record.user_id
|
||||
assert str(new_token_record.refresh_token) == refresh_token
|
||||
|
||||
|
||||
def test_renew_token_revokes_previous_token(test_user):
|
||||
token, token_data = generate_token_for_user(user_id=test_user.id)
|
||||
|
||||
token_record = AuthenticationToken.objects.create(
|
||||
id=token_data["token_id"],
|
||||
user_id=test_user.id,
|
||||
expires_at=datetime.datetime.now(datetime.timezone.utc),
|
||||
)
|
||||
|
||||
_, token_data, _ = renew_token(token, str(token_record.refresh_token))
|
||||
|
||||
new_token_record = AuthenticationToken.objects.get(id=token_data["token_id"])
|
||||
|
||||
token_record.refresh_from_db()
|
||||
|
||||
assert token_record.revoked
|
||||
assert new_token_record.revoked is False
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from django.http import HttpResponse, JsonResponse, HttpRequest
|
||||
import django.contrib.auth
|
||||
|
@ -8,8 +9,9 @@ import rest_framework.status
|
|||
|
||||
import identity.jwt
|
||||
from identity.models import AuthenticationToken
|
||||
from identity.token_management import revoke_token_by_id
|
||||
from identity.token_management import revoke_token_by_id, renew_token
|
||||
from identity.serializers import UserSerializer
|
||||
from identity.authentication_classes import JwtAuthenticationAllowExpired
|
||||
|
||||
AuthUser = django.contrib.auth.get_user_model()
|
||||
|
||||
|
@ -21,6 +23,8 @@ class SessionListView(rest_framework.views.APIView):
|
|||
Views handling authenticated user sessions.
|
||||
"""
|
||||
|
||||
authentication_classes = [JwtAuthenticationAllowExpired]
|
||||
|
||||
def post(self, request: HttpRequest) -> HttpResponse:
|
||||
"""
|
||||
Handles signing in for existing users.
|
||||
|
@ -69,6 +73,35 @@ class SessionListView(rest_framework.views.APIView):
|
|||
|
||||
return HttpResponse(status=401)
|
||||
|
||||
def put(self, request: HttpRequest) -> HttpResponse:
|
||||
"""
|
||||
Refreshes a session using a refresh token (provided via body).
|
||||
|
||||
On success, returns a new authentication token via cookie and
|
||||
a new refresh token via response body.
|
||||
|
||||
The previous auth+refresh token pair is invalidated and cannot be reused.
|
||||
"""
|
||||
|
||||
request_body = json.loads(request.body.decode("utf-8"))
|
||||
|
||||
current_auth_token = request.COOKIES.get("jwt", None)
|
||||
current_refresh_token = request_body.get("refresh_token", None)
|
||||
if not current_auth_token or not current_refresh_token:
|
||||
return HttpResponse(status=400)
|
||||
|
||||
new_token, _, new_refresh_token = renew_token(
|
||||
current_auth_token, current_refresh_token
|
||||
)
|
||||
|
||||
response = JsonResponse({"refresh_token": new_refresh_token}, status=201)
|
||||
|
||||
response.set_cookie(
|
||||
"jwt", value=new_token, secure=True, domain="localhost", httponly=False
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def delete(self, request: HttpRequest) -> HttpResponse:
|
||||
"""
|
||||
Logs out the requesting user.
|
||||
|
@ -79,6 +112,9 @@ class SessionListView(rest_framework.views.APIView):
|
|||
|
||||
current_token_id = request.session.get("token_id", None)
|
||||
|
||||
if request.session.get("expired", False):
|
||||
return HttpResponse(status=403)
|
||||
|
||||
if current_token_id is None:
|
||||
return HttpResponse(status=400)
|
||||
|
||||
|
@ -87,6 +123,26 @@ class SessionListView(rest_framework.views.APIView):
|
|||
|
||||
return HttpResponse(status=204)
|
||||
|
||||
def get(self, request: HttpRequest) -> HttpResponse:
|
||||
"""
|
||||
Verifies if the current session is still valid.
|
||||
"""
|
||||
|
||||
current_token_id = request.session.get("token_id", None)
|
||||
|
||||
token_record = AuthenticationToken.objects.get(id=current_token_id)
|
||||
|
||||
if (
|
||||
token_record.expires_at + datetime.timedelta(minutes=5)
|
||||
) < datetime.datetime.now(datetime.timezone.utc):
|
||||
return HttpResponse(status=401)
|
||||
|
||||
should_refresh = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
) > token_record.expires_at - datetime.timedelta(minutes=8)
|
||||
|
||||
return JsonResponse({"should_refresh": should_refresh}, status=200)
|
||||
|
||||
|
||||
class UserListView(rest_framework.views.APIView):
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import typing
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import freezegun
|
||||
|
||||
import django.urls
|
||||
import django.contrib.auth
|
||||
|
@ -44,6 +46,29 @@ def fixture_logout_request(auth_client):
|
|||
return _logout_request
|
||||
|
||||
|
||||
@pytest.fixture(name="refresh_session_request")
|
||||
def fixture_refresh_session_request(auth_client):
|
||||
def _refresh_session_request(refresh_token: typing.Optional[str] = None):
|
||||
data = {"refresh_token": refresh_token} if refresh_token else {}
|
||||
return auth_client.put(
|
||||
django.urls.reverse("auth-session-list"),
|
||||
json.dumps(data),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
return _refresh_session_request
|
||||
|
||||
|
||||
@pytest.fixture(name="get_session_request")
|
||||
def fixture_get_session_request(auth_client):
|
||||
def _get_session_request():
|
||||
return auth_client.get(
|
||||
django.urls.reverse("auth-session-list"),
|
||||
)
|
||||
|
||||
return _get_session_request
|
||||
|
||||
|
||||
@pytest.fixture(name="get_current_user_request")
|
||||
def fixture_get_current_user(auth_client):
|
||||
def _get_current_user_request(client: typing.Optional[Client] = None):
|
||||
|
@ -130,3 +155,67 @@ def test_get_current_user_returns_403_if_unauthenticated(
|
|||
response = get_current_user_request(no_auth_client)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_get_session_returns_401_if_token_expired_past_threshold(
|
||||
test_user_credentials, login_request, get_session_request
|
||||
):
|
||||
with freezegun.freeze_time("2012-01-01"):
|
||||
login_request(
|
||||
test_user_credentials["username"], test_user_credentials["password"]
|
||||
)
|
||||
|
||||
response = get_session_request()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_get_session_returns_whether_token_should_be_refreshed(
|
||||
login_request, get_session_request, test_user_credentials
|
||||
):
|
||||
login_request(test_user_credentials["username"], test_user_credentials["password"])
|
||||
|
||||
response = get_session_request()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"should_refresh": False}
|
||||
|
||||
|
||||
def test_session_refresh_rejects_with_400_if_no_refresh_token(
|
||||
login_request,
|
||||
refresh_session_request,
|
||||
test_user_credentials,
|
||||
):
|
||||
login_request(test_user_credentials["username"], test_user_credentials["password"])
|
||||
|
||||
response = refresh_session_request()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_session_refresh_rejects_with_400_if_not_authenticated(refresh_session_request):
|
||||
response = refresh_session_request()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_session_refresh_attaches_new_token_to_response(
|
||||
login_request,
|
||||
refresh_session_request,
|
||||
test_user_credentials,
|
||||
):
|
||||
login_response = login_request(
|
||||
test_user_credentials["username"], test_user_credentials["password"]
|
||||
)
|
||||
|
||||
login_response_data = login_response.json()
|
||||
|
||||
refresh_response = refresh_session_request(login_response_data["refresh_token"])
|
||||
|
||||
refresh_response_data = refresh_response.json()
|
||||
|
||||
assert refresh_response.status_code == 201
|
||||
assert (
|
||||
refresh_response_data["refresh_token"] != login_response_data["refresh_token"]
|
||||
)
|
||||
assert refresh_response.cookies["jwt"].value != login_response.cookies["jwt"].value
|
||||
|
|
|
@ -43,6 +43,6 @@ done;
|
|||
sleep $HEALTHCHECK_SLEEP
|
||||
|
||||
#ROTINI_TEST=1 PYTHONPATH=rotini $VENV_PYTHON rotini/migrations/migrate.py up || fail "Migrations failed."
|
||||
$VENV_PYTEST . -vv -s || fail "Test run failed."
|
||||
$VENV_PYTEST . -vv -s $@ || fail "Test run failed."
|
||||
|
||||
cleanup
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"build": "vite build ./src --config ./vite.config.js",
|
||||
"lint": "biome check src *.js --verbose && biome format src *.js --verbose",
|
||||
"lint:fix": "biome check src ./*.js --apply --verbose && biome format src ./*.js --write --verbose",
|
||||
"test": "yarn vitest run",
|
||||
"test": "yarn vitest --watch=false",
|
||||
"typecheck": "yarn tsc --noEmit"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import React from "react"
|
||||
|
||||
import { Box } from "@mui/material"
|
||||
import {
|
||||
QueryClient,
|
||||
|
@ -11,6 +13,8 @@ import LocationContext from "@/contexts/LocationContext"
|
|||
|
||||
import { Router, Route } from "@/router"
|
||||
|
||||
import setupAuthTokenAutoRefresh from "@/authRefresh"
|
||||
|
||||
import FileListView from "@/components/FileListView"
|
||||
import RegisterView from "@/components/RegisterView"
|
||||
import LoginView from "@/components/LoginView"
|
||||
|
@ -25,6 +29,14 @@ const routes = {
|
|||
}
|
||||
|
||||
const App = () => {
|
||||
React.useEffect(() => {
|
||||
const stopAutoRefresh = setupAuthTokenAutoRefresh()
|
||||
|
||||
return () => {
|
||||
stopAutoRefresh?.()
|
||||
}
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<Box sx={{ display: "flex", flexDirection: "column", width: "100%" }}>
|
||||
<NavigationBar />
|
||||
|
|
84
frontend/src/authRefresh.test.ts
Normal file
84
frontend/src/authRefresh.test.ts
Normal file
|
@ -0,0 +1,84 @@
|
|||
import { describe, it, vi, expect, beforeEach, afterEach } from "vitest"
|
||||
import { waitFor } from "@testing-library/react"
|
||||
import { getAxiosMockAdapter } from "@/tests/helpers"
|
||||
import setupAuthTokenAutoRefresh, {
|
||||
REFRESH_INTERVAL,
|
||||
REFRESH_TOKEN_KEY,
|
||||
} from "./authRefresh"
|
||||
|
||||
async function flushPromises() {
|
||||
return new Promise((r) => setTimeout(r))
|
||||
}
|
||||
|
||||
describe("setupAuthTokenAutoRefresh", () => {
|
||||
let clearInterval: (() => void) | undefined = undefined
|
||||
const mockLocalStorage = new Map<string, string>()
|
||||
beforeEach(() => {
|
||||
vi.spyOn(Storage.prototype, "setItem").mockImplementation(
|
||||
(key: string, value: string) => {
|
||||
mockLocalStorage.set(key, value)
|
||||
},
|
||||
)
|
||||
vi.spyOn(Storage.prototype, "getItem").mockImplementation(
|
||||
(key: string): string => mockLocalStorage.get(key) || "",
|
||||
)
|
||||
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
globalThis.localStorage.clear()
|
||||
|
||||
if (clearInterval) clearInterval()
|
||||
})
|
||||
|
||||
it("does not do anything if loop already initialized", () => {
|
||||
const clear = setupAuthTokenAutoRefresh()
|
||||
|
||||
expect(clear).not.toBeUndefined()
|
||||
const secondClear = setupAuthTokenAutoRefresh()
|
||||
|
||||
expect(secondClear).toBeUndefined()
|
||||
|
||||
clear?.()
|
||||
})
|
||||
|
||||
it("checks the status of the auth token at each interval", async () => {
|
||||
const axiosMock = getAxiosMockAdapter()
|
||||
|
||||
globalThis.localStorage.setItem(REFRESH_TOKEN_KEY, "refresh_token")
|
||||
|
||||
axiosMock.onGet("/auth/session/").reply(200, { should_refresh: false })
|
||||
clearInterval = setupAuthTokenAutoRefresh()
|
||||
|
||||
vi.advanceTimersByTime(REFRESH_INTERVAL)
|
||||
|
||||
expect(axiosMock.history.get.length).toEqual(1)
|
||||
|
||||
const apiCall = axiosMock.history.get[0]
|
||||
|
||||
expect(apiCall.url).toEqual("/auth/session/")
|
||||
})
|
||||
|
||||
// FIXME: LocalStorage does not update, but the request is run as expected.
|
||||
it.skip("attempts to refresh the token if the status checks says the token must be refreshed", async () => {
|
||||
const axiosMock = getAxiosMockAdapter()
|
||||
|
||||
globalThis.localStorage.setItem(REFRESH_TOKEN_KEY, "refresh_token")
|
||||
|
||||
axiosMock
|
||||
.onGet("/auth/session/")
|
||||
.reply(200, { should_refresh: true })
|
||||
.onPut("/auth/session/")
|
||||
.reply(201, { refresh_token: "new_refresh_token" })
|
||||
|
||||
clearInterval = setupAuthTokenAutoRefresh()
|
||||
|
||||
vi.advanceTimersByTime(REFRESH_INTERVAL)
|
||||
|
||||
const localStorageToken = globalThis.localStorage.getItem(REFRESH_TOKEN_KEY)
|
||||
|
||||
expect(localStorageToken).toEqual("new_refresh_token")
|
||||
})
|
||||
})
|
70
frontend/src/authRefresh.ts
Normal file
70
frontend/src/authRefresh.ts
Normal file
|
@ -0,0 +1,70 @@
|
|||
import axiosWithDefaults from "@/axios"
|
||||
|
||||
const REFRESH_TOKEN_KEY = "jwt_refresh_token"
|
||||
const REFRESH_INTERVAL = 30000
|
||||
|
||||
/*
|
||||
* Singleton tracking the token refresh loop.
|
||||
*
|
||||
* This ensures that token autorefresh is only initialized
|
||||
* once.
|
||||
*/
|
||||
let tokenRefreshLoop: ReturnType<typeof setInterval> | null = null
|
||||
|
||||
function getRefreshToken() {
|
||||
return globalThis.localStorage.getItem(REFRESH_TOKEN_KEY)
|
||||
}
|
||||
|
||||
function setRefreshToken(token: string) {
|
||||
globalThis.localStorage.setItem(REFRESH_TOKEN_KEY, token)
|
||||
}
|
||||
|
||||
function unsetRefreshToken() {
|
||||
globalThis.localStorage.removeItem(REFRESH_TOKEN_KEY)
|
||||
}
|
||||
|
||||
/*
|
||||
* Periodically verifies the status of the authentication token
|
||||
* expiration and initiates token refresh based on feedback
|
||||
* from the server.
|
||||
*
|
||||
* If the token is in need of refresh, a request is made to
|
||||
* generate a new token that will be attached to the response
|
||||
* as a cookie, and the refresh token stored in localStorage
|
||||
* gets refreshed.
|
||||
*/
|
||||
function setupAuthTokenAutoRefresh() {
|
||||
if (tokenRefreshLoop) {
|
||||
console.warn("Authentication token refresh loop already initialized.")
|
||||
return
|
||||
}
|
||||
|
||||
tokenRefreshLoop = setInterval(async () => {
|
||||
const fetchResponse = await axiosWithDefaults.get("/auth/session/")
|
||||
|
||||
const shouldRefresh = fetchResponse.data.should_refresh
|
||||
|
||||
if (!shouldRefresh) return
|
||||
|
||||
const refreshResponse = await axiosWithDefaults.put("/auth/session/", {
|
||||
refresh_token: getRefreshToken(),
|
||||
})
|
||||
|
||||
setRefreshToken(refreshResponse.data.refresh_token)
|
||||
}, REFRESH_INTERVAL)
|
||||
|
||||
return () => {
|
||||
if (tokenRefreshLoop) clearInterval(tokenRefreshLoop)
|
||||
tokenRefreshLoop = null
|
||||
}
|
||||
}
|
||||
|
||||
export {
|
||||
setRefreshToken,
|
||||
getRefreshToken,
|
||||
unsetRefreshToken,
|
||||
REFRESH_INTERVAL,
|
||||
REFRESH_TOKEN_KEY,
|
||||
}
|
||||
|
||||
export default setupAuthTokenAutoRefresh
|
|
@ -129,7 +129,9 @@ describe("LoginView", () => {
|
|||
}))
|
||||
const axiosMockAdapter = new AxiosMockAdapter(axios)
|
||||
|
||||
axiosMockAdapter.onPost("/auth/session/").reply(201)
|
||||
axiosMockAdapter
|
||||
.onPost("/auth/session/")
|
||||
.reply(201, { refresh_token: "notatoken" })
|
||||
|
||||
const { user } = renderComponent()
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
*/
|
||||
import { useQueryClient, useMutation } from "@tanstack/react-query"
|
||||
|
||||
import { setRefreshToken, unsetRefreshToken } from "@/authRefresh"
|
||||
import { useLocationContext } from "@/contexts/LocationContext"
|
||||
import axiosWithDefaults from "@/axios"
|
||||
|
||||
|
@ -29,6 +30,7 @@ function useLogout() {
|
|||
},
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({ queryKey: ["current-user"] })
|
||||
unsetRefreshToken()
|
||||
navigate("/login")
|
||||
},
|
||||
})
|
||||
|
@ -61,7 +63,8 @@ function useLogin() {
|
|||
password,
|
||||
})
|
||||
},
|
||||
onSuccess: async () => {
|
||||
onSuccess: async (response) => {
|
||||
setRefreshToken(response.data.refresh_token)
|
||||
await queryClient.refetchQueries({ queryKey: ["current-user"] })
|
||||
navigate("/")
|
||||
},
|
||||
|
|
Reference in a new issue