501 lines
16 KiB
Python
501 lines
16 KiB
Python
"""OAuth Provider Models"""
|
|
import base64
|
|
import binascii
|
|
import json
|
|
import time
|
|
from dataclasses import asdict, dataclass, field
|
|
from hashlib import sha256
|
|
from typing import Any, Dict, List, Optional, Type
|
|
from urllib.parse import urlparse
|
|
from uuid import uuid4
|
|
|
|
from django.conf import settings
|
|
from django.db import models
|
|
from django.forms import ModelForm
|
|
from django.http import HttpRequest
|
|
from django.shortcuts import reverse
|
|
from django.utils import dateformat, timezone
|
|
from django.utils.translation import gettext_lazy as _
|
|
from jwkest.jwk import Key, RSAKey, SYMKey, import_rsa_key
|
|
from jwkest.jws import JWS
|
|
|
|
from passbook.core.models import ExpiringModel, PropertyMapping, Provider, User
|
|
from passbook.crypto.models import CertificateKeyPair
|
|
from passbook.lib.utils.template import render_to_string
|
|
from passbook.lib.utils.time import timedelta_from_string, timedelta_string_validator
|
|
from passbook.providers.oauth2.apps import PassbookProviderOAuth2Config
|
|
from passbook.providers.oauth2.generators import (
|
|
generate_client_id,
|
|
generate_client_secret,
|
|
)
|
|
|
|
|
|
class ClientTypes(models.TextChoices):
    """Confidential clients are capable of maintaining the confidentiality
    of their credentials. Public clients are incapable."""

    # NOTE: the docstring above is not only documentation: it is passed through
    # gettext as the help_text of OAuth2Provider.client_type, so editing it
    # changes user-facing text.
    CONFIDENTIAL = "confidential", _("Confidential")
    PUBLIC = "public", _("Public")
|
|
|
|
|
|
class GrantTypes(models.TextChoices):
    """OAuth2 Grant types we support"""

    # Members declare only a value, no explicit label; Django's TextChoices
    # derives the display label from the member name in that case.
    AUTHORIZATION_CODE = "authorization_code"
    IMPLICIT = "implicit"
    HYBRID = "hybrid"
|
|
|
|
|
|
class SubModes(models.TextChoices):
    """Mode by which the 'sub' attribute is generated, for compatibility reasons"""

    # Chosen per provider through OAuth2Provider.sub_mode; the matching branch in
    # RefreshToken.create_id_token computes the actual 'sub' value.

    # Hex SHA-256 of f"{user.id}-{settings.SECRET_KEY}".
    HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID")
    # user.username
    USER_USERNAME = "user_username", _("Based on the username")
    # user.email
    USER_EMAIL = (
        "user_email",
        _("Based on the User's Email. This is recommended over the UPN method."),
    )
    # user.attributes["upn"], indexed directly, so a user without a 'upn'
    # attribute makes create_id_token raise (presumably KeyError).
    USER_UPN = (
        "user_upn",
        _(
            (
                "Based on the User's UPN, only works if user has a 'upn' attribute set. "
                "Use this method only if you have different UPN and Mail domains."
            )
        ),
    )
|
|
|
|
|
|
class ResponseTypes(models.TextChoices):
    """Response Type required by the client."""

    # NOTE: the docstring above is also passed through gettext as the help_text
    # of OAuth2Provider.response_type; editing it changes user-facing text.
    # Values are the literal response_type strings (combined types are
    # space-separated); labels name the flow each one corresponds to.
    CODE = "code", _("code (Authorization Code Flow)")
    # Non-standard value; per its label, ADFS-compatible behaviour that sends
    # the id_token as the access_token.
    CODE_ADFS = (
        "code_adfs",
        _("code (ADFS Compatibility Mode, sends id_token as access_token)"),
    )
    ID_TOKEN = "id_token", _("id_token (Implicit Flow)")
    ID_TOKEN_TOKEN = "id_token token", _("id_token token (Implicit Flow)")
    CODE_TOKEN = "code token", _("code token (Hybrid Flow)")
    CODE_ID_TOKEN = "code id_token", _("code id_token (Hybrid Flow)")
    CODE_ID_TOKEN_TOKEN = "code id_token token", _("code id_token token (Hybrid Flow)")
|
|
|
|
|
|
class JWTAlgorithms(models.TextChoices):
    """Algorithm used to sign the JWT Token"""

    # NOTE: the docstring above doubles as the help_text of OAuth2Provider.jwt_alg
    # (passed through gettext), so editing it changes user-facing text.
    # NOTE(review): both are signing algorithms, not encryption; the labels are
    # kept as-is because changing them alters the field's declared choices.

    # Symmetric: OAuth2Provider.get_jwt_keys signs with the client secret.
    HS256 = "HS256", _("HS256 (Symmetric Encryption)")
    # Asymmetric: signs with the provider's CertificateKeyPair (rsa_key); falls
    # back to HS256 when no key pair is assigned.
    RS256 = "RS256", _("RS256 (Asymmetric Encryption)")
|
|
|
|
|
|
class ScopeMapping(PropertyMapping):
    """Property mapping that binds an OAuth scope name to user properties."""

    scope_name = models.TextField(help_text=_("Scope used by the client"))
    description = models.TextField(
        blank=True,
        help_text=_(
            "Description shown to the user when consenting. "
            "If left empty, the user won't be informed."
        ),
    )

    def form(self) -> Type[ModelForm]:
        """Return the ModelForm class used to edit this mapping."""
        # Imported inside the method, presumably to avoid a circular import
        # between the forms module and this models module.
        from passbook.providers.oauth2.forms import ScopeMappingForm as form_class

        return form_class

    def __str__(self):
        return "Scope Mapping %s (%s)" % (self.name, self.scope_name)

    class Meta:

        verbose_name = _("Scope Mapping")
        verbose_name_plural = _("Scope Mappings")
|
|
|
|
|
|
class OAuth2Provider(Provider):
    """OAuth2 Provider for generic OAuth and OpenID Connect Applications."""

    name = models.TextField()

    client_type = models.CharField(
        max_length=30,
        choices=ClientTypes.choices,
        default=ClientTypes.CONFIDENTIAL,
        verbose_name=_("Client Type"),
        help_text=_(ClientTypes.__doc__),
    )
    client_id = models.CharField(
        max_length=255,
        unique=True,
        verbose_name=_("Client ID"),
        default=generate_client_id,
    )
    client_secret = models.CharField(
        max_length=255,
        blank=True,
        verbose_name=_("Client Secret"),
        default=generate_client_secret,
    )
    response_type = models.TextField(
        choices=ResponseTypes.choices,
        default=ResponseTypes.CODE,
        help_text=_(ResponseTypes.__doc__),
    )
    jwt_alg = models.CharField(
        max_length=10,
        choices=JWTAlgorithms.choices,
        default=JWTAlgorithms.RS256,
        verbose_name=_("JWT Algorithm"),
        help_text=_(JWTAlgorithms.__doc__),
    )
    # Newline-separated list; the first entry is used to guess launch_url.
    redirect_uris = models.TextField(
        default="",
        verbose_name=_("Redirect URIs"),
        help_text=_("Enter each URI on a new line."),
    )

    include_claims_in_id_token = models.BooleanField(
        default=True,
        verbose_name=_("Include claims in id_token"),
        help_text=_(
            "Include User claims from scopes in the id_token, for applications "
            "that don't access the userinfo endpoint."
        ),
    )

    # Parsed with timedelta_from_string; drives token expiry.
    token_validity = models.TextField(
        default="minutes=10",
        validators=[timedelta_string_validator],
        help_text=_(
            "Tokens not valid on or after current time + this value "
            "(Format: hours=1;minutes=2;seconds=3)."
        ),
    )

    sub_mode = models.TextField(
        choices=SubModes.choices,
        default=SubModes.HASHED_USER_ID,
        help_text=_(
            "Configure what data should be used as unique User Identifier. For most cases, "
            "the default should be fine."
        ),
    )

    rsa_key = models.ForeignKey(
        CertificateKeyPair,
        verbose_name=_("RSA Key"),
        on_delete=models.CASCADE,
        blank=True,
        null=True,
        help_text=_(
            "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
        ),
    )

    @property
    def scope_names(self) -> List[str]:
        """Return the scope names of all property mappings assigned to this provider.

        Assumes every assigned mapping exposes ``scope_name`` (i.e. is a
        ScopeMapping).
        """
        return [pm.scope_name for pm in self.property_mappings.all()]

    def create_refresh_token(
        self, user: User, scope: List[str], id_token: Optional["IDToken"] = None
    ) -> "RefreshToken":
        """Build (but do not save) a RefreshToken for ``user``.

        Access and refresh token strings are random UUID4 hex values; the token
        expires after ``token_validity``. ``id_token`` is attached only if given.
        """
        token = RefreshToken(
            user=user,
            provider=self,
            access_token=uuid4().hex,
            refresh_token=uuid4().hex,
            expires=timezone.now() + timedelta_from_string(self.token_validity),
            scope=scope,
        )
        if id_token:
            token.id_token = id_token
        return token

    def get_jwt_keys(self) -> List[Key]:
        """Return the keys used to sign JWTs issued by this provider.

        RS256 uses the assigned CertificateKeyPair; HS256 uses the client secret
        as a symmetric key. If RS256 is selected but no key pair is assigned,
        the provider is switched to HS256 and saved before the symmetric key is
        returned.

        Raises:
            Exception: if ``jwt_alg`` is not a supported algorithm.
        """
        if self.jwt_alg == JWTAlgorithms.RS256 and not self.rsa_key:
            # RS256 was selected without a CertificateKeyPair: fall back to HS256
            # and persist the change.
            self.jwt_alg = JWTAlgorithms.HS256
            self.save()

        if self.jwt_alg == JWTAlgorithms.RS256:
            # The JWT library uses pycryptodome, so we can't pass an
            # RSAPublicKey object directly; load the key ourselves.
            key = import_rsa_key(self.rsa_key.key_data)
            return [RSAKey(key=key, kid=self.rsa_key.kid)]

        if self.jwt_alg == JWTAlgorithms.HS256:
            return [SYMKey(key=self.client_secret, alg=self.jwt_alg)]

        raise Exception("Unsupported key algorithm.")

    def get_issuer(self, request: HttpRequest) -> Optional[str]:
        """Build the issuer URL from the request, the OAuth2 URL mountpoint and
        the application slug; None if the provider has no application."""
        try:
            mountpoint = PassbookProviderOAuth2Config.mountpoints[
                "passbook.providers.oauth2.urls"
            ]
            # pylint: disable=no-member
            return request.build_absolute_uri(f"/{mountpoint}{self.application.slug}/")
        except Provider.application.RelatedObjectDoesNotExist:
            return None

    @property
    def launch_url(self) -> Optional[str]:
        """Guess the launch URL from the first redirect URI.

        Returns the first redirect URI reduced to its scheme and network
        location (``https://app.example.com/callback`` ->
        ``https://app.example.com``), or None if no redirect URIs are set.
        """
        if self.redirect_uris == "":
            return None
        # splitlines() also copes with CRLF line endings from form submissions.
        main_url = self.redirect_uris.splitlines()[0].strip()
        # Drop path/params/query/fragment structurally. A plain
        # str.replace(path, "") also removes matching substrings elsewhere in
        # the URL (a "/" path would strip every slash).
        parsed = urlparse(main_url)
        return parsed._replace(path="", params="", query="", fragment="").geturl()

    def form(self) -> Type[ModelForm]:
        """Return the ModelForm class used to edit this provider."""
        from passbook.providers.oauth2.forms import OAuth2ProviderForm

        return OAuth2ProviderForm

    def __str__(self):
        return f"OAuth2 Provider {self.name}"

    def encode(self, payload: Dict[str, Any]) -> str:
        """Sign ``payload`` as a compact JWS (JWT) with this provider's keys."""
        keys = self.get_jwt_keys()
        # get_jwt_keys may have switched jwt_alg to HS256 (no RSA key assigned)
        # and saved; reload so the JWS header uses the persisted algorithm.
        self.refresh_from_db()
        jws = JWS(payload, alg=self.jwt_alg)
        return jws.sign_compact(keys)

    def html_setup_urls(self, request: HttpRequest) -> Optional[str]:
        """Render the setup modal listing issuer, authorize, token, userinfo and
        provider-info URLs; None if the provider has no application."""
        try:
            # pylint: disable=no-member
            return render_to_string(
                "providers/oauth2/setup_url_modal.html",
                {
                    "provider": self,
                    "issuer": self.get_issuer(request),
                    "authorize": request.build_absolute_uri(
                        reverse("passbook_providers_oauth2:authorize")
                    ),
                    "token": request.build_absolute_uri(
                        reverse("passbook_providers_oauth2:token")
                    ),
                    "userinfo": request.build_absolute_uri(
                        reverse("passbook_providers_oauth2:userinfo")
                    ),
                    "provider_info": request.build_absolute_uri(
                        reverse(
                            "passbook_providers_oauth2:provider-info",
                            kwargs={"application_slug": self.application.slug},
                        )
                    ),
                },
            )
        except Provider.application.RelatedObjectDoesNotExist:
            return None

    class Meta:

        verbose_name = _("OAuth2/OpenID Provider")
        verbose_name_plural = _("OAuth2/OpenID Providers")
|
|
|
|
|
|
class BaseGrantModel(models.Model):
    """Abstract base for grants: ties a provider and a user to a set of scopes."""

    provider = models.ForeignKey(OAuth2Provider, on_delete=models.CASCADE)
    user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name=_("User"))
    # Scopes are persisted as a single whitespace-separated string.
    _scope = models.TextField(verbose_name=_("Scopes"), default="")

    @property
    def scope(self) -> List[str]:
        """Scope names parsed out of the stored string."""
        stored = self._scope
        return stored.split()

    @scope.setter
    def scope(self, value):
        """Store an iterable of scope names as one space-separated string."""
        separator = " "
        self._scope = separator.join(value)

    class Meta:
        abstract = True
|
|
|
|
|
|
# pylint: disable=too-many-instance-attributes
class AuthorizationCode(ExpiringModel, BaseGrantModel):
    """OAuth2 authorization code with its nonce, OpenID flag and code-challenge data."""

    code = models.CharField(verbose_name=_("Code"), max_length=255, unique=True)
    nonce = models.CharField(
        verbose_name=_("Nonce"), max_length=255, blank=True, default=""
    )
    # Labelled "Is Authentication?"; presumably set for OpenID Connect requests.
    is_open_id = models.BooleanField(
        verbose_name=_("Is Authentication?"), default=False
    )
    # Optional code challenge and its method; both columns are nullable.
    code_challenge = models.CharField(
        verbose_name=_("Code Challenge"), max_length=255, null=True
    )
    code_challenge_method = models.CharField(
        verbose_name=_("Code Challenge Method"), max_length=255, null=True
    )

    def __str__(self):
        return f"{self.provider} - {self.code}"

    class Meta:

        verbose_name = _("Authorization Code")
        verbose_name_plural = _("Authorization Codes")
|
|
|
|
|
|
@dataclass
# pylint: disable=too-many-instance-attributes
class IDToken:
    """OpenID Connect ID Token data: the standard claims plus extra user ``claims``.

    An ID Token is a security token carrying claims about the authentication of
    an end-user by the authorization server, and is represented as a JWT.

    https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""

    # Every field defaults to None so an empty IDToken can be stored for
    # non-OpenID flows.
    iss: Optional[str] = None
    sub: Optional[str] = None
    aud: Optional[str] = None
    exp: Optional[int] = None
    iat: Optional[int] = None
    auth_time: Optional[int] = None

    nonce: Optional[str] = None
    at_hash: Optional[str] = None

    # Additional claims, merged into the top level by to_dict().
    claims: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "IDToken":
        """Rebuild an IDToken by setting each key of ``data`` as an attribute.

        Keys that are not dataclass fields are set as plain attributes too.
        """
        instance = IDToken()
        for attr_name in data:
            setattr(instance, attr_name, data[attr_name])
        return instance

    def to_dict(self) -> Dict[str, Any]:
        """Flatten the token: dataclass fields (except ``claims``) first, then
        the entries of ``claims``, which override fields on key collisions."""
        flat = {
            name: value for name, value in asdict(self).items() if name != "claims"
        }
        flat.update(self.claims)
        return flat
|
|
|
|
|
|
class RefreshToken(ExpiringModel, BaseGrantModel):
    """OAuth2 Refresh Token"""

    access_token = models.CharField(
        max_length=255, unique=True, verbose_name=_("Access Token")
    )
    refresh_token = models.CharField(
        max_length=255, unique=True, verbose_name=_("Refresh Token")
    )
    # JSON-serialized IDToken, read and written through the `id_token` property.
    _id_token = models.TextField(verbose_name=_("ID Token"))

    class Meta:

        verbose_name = _("Token")
        verbose_name_plural = _("Tokens")

    @property
    def id_token(self) -> IDToken:
        """Load the stored ID Token from JSON; an empty IDToken if none is stored."""
        if self._id_token:
            raw_token = json.loads(self._id_token)
            return IDToken.from_dict(raw_token)
        return IDToken()

    @id_token.setter
    def id_token(self, value: IDToken):
        # asdict keeps `claims` nested, which is the shape from_dict reloads.
        self._id_token = json.dumps(asdict(value))

    def __str__(self):
        return f"{self.provider} - {self.access_token}"

    @property
    def at_hash(self):
        """Access-token hash (``at_hash``) as defined by OpenID Connect Core.

        The left-most half of the SHA-256 digest of the ASCII access token,
        base64url-encoded without padding. SHA-256 matches both supported
        signing algorithms (HS256 and RS256).
        """
        digest = sha256(self.access_token.encode("ascii")).digest()
        return (
            base64.urlsafe_b64encode(digest[: len(digest) // 2])
            .rstrip(b"=")
            .decode("ascii")
        )

    def create_id_token(self, user: User, request: HttpRequest) -> IDToken:
        """Creates the id_token for ``user``.

        ``sub`` follows the provider's ``sub_mode``; ``exp`` is now plus the
        provider's ``token_validity``. If ``include_claims_in_id_token`` is set,
        the userinfo claims for this token are embedded too.

        Raises:
            ValueError: if the provider has an unknown ``sub_mode``.

        See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken"""
        sub = ""
        if self.provider.sub_mode == SubModes.HASHED_USER_ID:
            sub = sha256(f"{user.id}-{settings.SECRET_KEY}".encode("ascii")).hexdigest()
        elif self.provider.sub_mode == SubModes.USER_EMAIL:
            sub = user.email
        elif self.provider.sub_mode == SubModes.USER_USERNAME:
            sub = user.username
        elif self.provider.sub_mode == SubModes.USER_UPN:
            # Indexed directly: users without a 'upn' attribute cause an error.
            sub = user.attributes["upn"]
        else:
            raise ValueError(
                (
                    f"Provider {self.provider} has invalid sub_mode "
                    f"selected: {self.provider.sub_mode}"
                )
            )

        # Convert datetimes into timestamps.
        now = int(time.time())
        iat_time = now
        validity = timedelta_from_string(self.provider.token_validity)
        # total_seconds() includes the days component; timedelta.seconds alone
        # would drop it (e.g. "hours=25" would expire after one hour).
        exp_time = int(now + validity.total_seconds())
        user_auth_time = user.last_login or user.date_joined
        auth_time = int(dateformat.format(user_auth_time, "U"))

        token = IDToken(
            iss=self.provider.get_issuer(request),
            sub=sub,
            aud=self.provider.client_id,
            exp=exp_time,
            iat=iat_time,
            auth_time=auth_time,
        )

        # Include (or not) user standard claims in the id_token.
        if self.provider.include_claims_in_id_token:
            from passbook.providers.oauth2.views.userinfo import UserInfoView

            user_info = UserInfoView()
            user_info.request = request
            claims = user_info.get_claims(self)
            token.claims = claims

        return token
|