api: make 401 messages clearer
closes #755 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
837d2f6fab
commit
464a1c0536
|
@ -4,6 +4,7 @@ from binascii import Error
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from rest_framework.authentication import BaseAuthentication, get_authorization_header
|
from rest_framework.authentication import BaseAuthentication, get_authorization_header
|
||||||
|
from rest_framework.exceptions import AuthenticationFailed
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
@ -14,7 +15,7 @@ LOGGER = get_logger()
|
||||||
|
|
||||||
# pylint: disable=too-many-return-statements
|
# pylint: disable=too-many-return-statements
|
||||||
def token_from_header(raw_header: bytes) -> Optional[Token]:
|
def token_from_header(raw_header: bytes) -> Optional[Token]:
|
||||||
"""raw_header in the Format of `Basic dGVzdDp0ZXN0`"""
|
"""raw_header in the Format of `Bearer dGVzdDp0ZXN0`"""
|
||||||
auth_credentials = raw_header.decode()
|
auth_credentials = raw_header.decode()
|
||||||
if auth_credentials == "":
|
if auth_credentials == "":
|
||||||
return None
|
return None
|
||||||
|
@ -25,28 +26,27 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
|
||||||
auth_type, body = plain.split()
|
auth_type, body = plain.split()
|
||||||
auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}"
|
auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}"
|
||||||
except (UnicodeDecodeError, Error):
|
except (UnicodeDecodeError, Error):
|
||||||
return None
|
raise AuthenticationFailed("Malformed header")
|
||||||
auth_type, auth_credentials = auth_credentials.split()
|
auth_type, auth_credentials = auth_credentials.split()
|
||||||
if auth_type.lower() not in ["basic", "bearer"]:
|
if auth_type.lower() not in ["basic", "bearer"]:
|
||||||
LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower())
|
LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower())
|
||||||
return None
|
raise AuthenticationFailed("Unsupported authentication type")
|
||||||
password = auth_credentials
|
password = auth_credentials
|
||||||
if auth_type.lower() == "basic":
|
if auth_type.lower() == "basic":
|
||||||
try:
|
try:
|
||||||
auth_credentials = b64decode(auth_credentials.encode()).decode()
|
auth_credentials = b64decode(auth_credentials.encode()).decode()
|
||||||
except (UnicodeDecodeError, Error):
|
except (UnicodeDecodeError, Error):
|
||||||
return None
|
raise AuthenticationFailed("Malformed header")
|
||||||
# Accept credentials with username and without
|
# Accept credentials with username and without
|
||||||
if ":" in auth_credentials:
|
if ":" in auth_credentials:
|
||||||
_, password = auth_credentials.split(":")
|
_, password = auth_credentials.split(":")
|
||||||
else:
|
else:
|
||||||
password = auth_credentials
|
password = auth_credentials
|
||||||
if password == "": # nosec
|
if password == "": # nosec
|
||||||
return None
|
raise AuthenticationFailed("Malformed header")
|
||||||
tokens = Token.filter_not_expired(key=password, intent=TokenIntents.INTENT_API)
|
tokens = Token.filter_not_expired(key=password, intent=TokenIntents.INTENT_API)
|
||||||
if not tokens.exists():
|
if not tokens.exists():
|
||||||
LOGGER.debug("Token not found")
|
raise AuthenticationFailed("Token invalid/expired")
|
||||||
return None
|
|
||||||
return tokens.first()
|
return tokens.first()
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,6 +58,7 @@ class AuthentikTokenAuthentication(BaseAuthentication):
|
||||||
auth = get_authorization_header(request)
|
auth = get_authorization_header(request)
|
||||||
|
|
||||||
token = token_from_header(auth)
|
token = token_from_header(auth)
|
||||||
|
# None is only returned when the header isn't set.
|
||||||
if not token:
|
if not token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from base64 import b64encode
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
|
from rest_framework.exceptions import AuthenticationFailed
|
||||||
|
|
||||||
from authentik.api.auth import token_from_header
|
from authentik.api.auth import token_from_header
|
||||||
from authentik.core.models import Token, TokenIntents
|
from authentik.core.models import Token, TokenIntents
|
||||||
|
@ -28,17 +29,21 @@ class TestAPIAuth(TestCase):
|
||||||
|
|
||||||
def test_invalid_type(self):
|
def test_invalid_type(self):
|
||||||
"""Test invalid type"""
|
"""Test invalid type"""
|
||||||
self.assertIsNone(token_from_header("foo bar".encode()))
|
with self.assertRaises(AuthenticationFailed):
|
||||||
|
token_from_header("foo bar".encode())
|
||||||
|
|
||||||
def test_invalid_decode(self):
|
def test_invalid_decode(self):
|
||||||
"""Test invalid bas64"""
|
"""Test invalid bas64"""
|
||||||
self.assertIsNone(token_from_header("Basic bar".encode()))
|
with self.assertRaises(AuthenticationFailed):
|
||||||
|
token_from_header("Basic bar".encode())
|
||||||
|
|
||||||
def test_invalid_empty_password(self):
|
def test_invalid_empty_password(self):
|
||||||
"""Test invalid with empty password"""
|
"""Test invalid with empty password"""
|
||||||
self.assertIsNone(token_from_header("Basic :".encode()))
|
with self.assertRaises(AuthenticationFailed):
|
||||||
|
token_from_header("Basic :".encode())
|
||||||
|
|
||||||
def test_invalid_no_token(self):
|
def test_invalid_no_token(self):
|
||||||
"""Test invalid with no token"""
|
"""Test invalid with no token"""
|
||||||
|
with self.assertRaises(AuthenticationFailed):
|
||||||
auth = b64encode(":abc".encode()).decode()
|
auth = b64encode(":abc".encode()).decode()
|
||||||
self.assertIsNone(token_from_header(f"Basic :{auth}".encode()))
|
self.assertIsNone(token_from_header(f"Basic :{auth}".encode()))
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Channels base classes"""
|
"""Channels base classes"""
|
||||||
from channels.exceptions import DenyConnection
|
from channels.exceptions import DenyConnection
|
||||||
from channels.generic.websocket import JsonWebsocketConsumer
|
from channels.generic.websocket import JsonWebsocketConsumer
|
||||||
|
from rest_framework.exceptions import AuthenticationFailed
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.api.auth import token_from_header
|
from authentik.api.auth import token_from_header
|
||||||
|
@ -22,9 +23,13 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
|
||||||
|
|
||||||
raw_header = headers[b"authorization"]
|
raw_header = headers[b"authorization"]
|
||||||
|
|
||||||
|
try:
|
||||||
token = token_from_header(raw_header)
|
token = token_from_header(raw_header)
|
||||||
|
# token is only None when no header was given, in which case we deny too
|
||||||
if not token:
|
if not token:
|
||||||
LOGGER.warning("Failed to authenticate")
|
raise DenyConnection()
|
||||||
|
except AuthenticationFailed as exc:
|
||||||
|
LOGGER.warning("Failed to authenticate", exc=exc)
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
|
|
||||||
self.user = token.user
|
self.user = token.user
|
||||||
|
|
Reference in New Issue