538 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			538 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """OAuth Provider Models"""
 | |
| import base64
 | |
| import binascii
 | |
| import json
 | |
| import time
 | |
| from dataclasses import asdict, dataclass, field
 | |
| from hashlib import sha256
 | |
| from typing import Any, Dict, List, Optional, Type
 | |
| from urllib.parse import urlparse
 | |
| from uuid import uuid4
 | |
| 
 | |
| from dacite import from_dict
 | |
| from django.conf import settings
 | |
| from django.db import models
 | |
| from django.forms import ModelForm
 | |
| from django.http import HttpRequest
 | |
| from django.shortcuts import reverse
 | |
| from django.utils import dateformat, timezone
 | |
| from django.utils.translation import gettext_lazy as _
 | |
| from jwkest.jwk import Key, RSAKey, SYMKey, import_rsa_key
 | |
| from jwkest.jws import JWS
 | |
| from rest_framework.serializers import Serializer
 | |
| 
 | |
| from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
 | |
| from authentik.crypto.models import CertificateKeyPair
 | |
| from authentik.events.models import Event, EventAction
 | |
| from authentik.events.utils import get_user
 | |
| from authentik.lib.utils.template import render_to_string
 | |
| from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
 | |
| from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config
 | |
| from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
 | |
| from authentik.providers.oauth2.generators import (
 | |
|     generate_client_id,
 | |
|     generate_client_secret,
 | |
| )
 | |
| 
 | |
| 
 | |
class ClientTypes(models.TextChoices):
    """Confidential clients are capable of maintaining the confidentiality
    of their credentials. Public clients are incapable."""

    # NOTE: the docstring above is runtime text: OAuth2Provider.client_type uses
    # it as its help_text, so it must not be reworded casually.
    # Each member is a (stored value, translated label) pair.
    CONFIDENTIAL = ("confidential", _("Confidential"))
    PUBLIC = ("public", _("Public"))
 | |
| 
 | |
| 
 | |
class GrantTypes(models.TextChoices):
    """OAuth2 Grant types we support"""

    # Members carry only a stored value; no explicit label is given here,
    # so TextChoices derives the display label from the member name.
    AUTHORIZATION_CODE = "authorization_code"  # classic code exchange
    IMPLICIT = "implicit"  # tokens returned directly from the authorize step
    HYBRID = "hybrid"  # mix of code and tokens from the authorize step
 | |
| 
 | |
| 
 | |
class SubModes(models.TextChoices):
    """Mode by which the 'sub' attribute is generated, for compatibility reasons"""

    # Each member is a (stored value, translated label) pair. The labels are
    # shown to administrators, so their text is kept unchanged.
    HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID")
    USER_USERNAME = "user_username", _("Based on the username")
    USER_EMAIL = (
        "user_email",
        _("Based on the User's Email. This is recommended over the UPN method."),
    )
    USER_UPN = (
        "user_upn",
        _(
            "Based on the User's UPN, only works if user has a 'upn' attribute set. "
            "Use this method only if you have different UPN and Mail domains."
        ),
    )
 | |
| 
 | |
| 
 | |
class IssuerMode(models.TextChoices):
    """Configure how the `iss` field is created."""

    # A single issuer value shared by every provider.
    GLOBAL = ("global", _("Same identifier is used for all providers"))
    # A separate issuer per provider, derived from its application's slug.
    PER_PROVIDER = (
        "per_provider",
        _("Each provider has a different issuer, based on the application slug."),
    )
 | |
| 
 | |
| 
 | |
class ResponseTypes(models.TextChoices):
    """Response Type required by the client."""

    # Authorization Code Flow
    CODE = ("code", _("code (Authorization Code Flow)"))

    # Implicit Flow variants
    ID_TOKEN = ("id_token", _("id_token (Implicit Flow)"))
    ID_TOKEN_TOKEN = ("id_token token", _("id_token token (Implicit Flow)"))

    # Hybrid Flow variants
    CODE_TOKEN = ("code token", _("code token (Hybrid Flow)"))
    CODE_ID_TOKEN = ("code id_token", _("code id_token (Hybrid Flow)"))
    CODE_ID_TOKEN_TOKEN = (
        "code id_token token",
        _("code id_token token (Hybrid Flow)"),
    )
 | |
| 
 | |
| 
 | |
class JWTAlgorithms(models.TextChoices):
    """Algorithm used to sign the JWT Token"""

    # NOTE: the docstring above is runtime text: OAuth2Provider.jwt_alg uses it
    # as its help_text. Member labels are choice labels and are kept verbatim.
    HS256 = ("HS256", _("HS256 (Symmetric Encryption)"))  # keyed with client_secret
    RS256 = ("RS256", _("RS256 (Asymmetric Encryption)"))  # keyed with rsa_key
 | |
| 
 | |
| 
 | |
class ScopeMapping(PropertyMapping):
    """Property mapping that maps an OAuth scope to properties of the user."""

    scope_name = models.TextField(help_text=_("Scope used by the client"))
    description = models.TextField(
        blank=True,
        help_text=_(
            "Description shown to the user when consenting. "
            "If left empty, the user won't be informed."
        ),
    )

    class Meta:

        verbose_name = _("Scope Mapping")
        verbose_name_plural = _("Scope Mappings")

    def __str__(self):
        return "Scope Mapping {} ({})".format(self.name, self.scope_name)

    @property
    def form(self) -> Type[ModelForm]:
        """ModelForm class used to edit this mapping."""
        # Imported lazily, presumably to avoid a circular import with the forms module.
        from authentik.providers.oauth2.forms import ScopeMappingForm

        form_class = ScopeMappingForm
        return form_class
 | |
| 
 | |
| 
 | |
class OAuth2Provider(Provider):
    """OAuth2 Provider for generic OAuth and OpenID Connect Applications."""

    client_type = models.CharField(
        max_length=30,
        choices=ClientTypes.choices,
        default=ClientTypes.CONFIDENTIAL,
        verbose_name=_("Client Type"),
        help_text=_(ClientTypes.__doc__),
    )
    client_id = models.CharField(
        max_length=255,
        unique=True,
        verbose_name=_("Client ID"),
        default=generate_client_id,
    )
    client_secret = models.CharField(
        max_length=255,
        blank=True,
        verbose_name=_("Client Secret"),
        default=generate_client_secret,
    )
    jwt_alg = models.CharField(
        max_length=10,
        choices=JWTAlgorithms.choices,
        default=JWTAlgorithms.RS256,
        verbose_name=_("JWT Algorithm"),
        help_text=_(JWTAlgorithms.__doc__),
    )
    redirect_uris = models.TextField(
        default="",
        verbose_name=_("Redirect URIs"),
        help_text=_("Enter each URI on a new line."),
    )

    include_claims_in_id_token = models.BooleanField(
        default=True,
        verbose_name=_("Include claims in id_token"),
        help_text=_(
            (
                "Include User claims from scopes in the id_token, for applications "
                "that don't access the userinfo endpoint."
            )
        ),
    )

    token_validity = models.TextField(
        default="minutes=10",
        validators=[timedelta_string_validator],
        help_text=_(
            (
                "Tokens not valid on or after current time + this value "
                "(Format: hours=1;minutes=2;seconds=3)."
            )
        ),
    )

    sub_mode = models.TextField(
        choices=SubModes.choices,
        default=SubModes.HASHED_USER_ID,
        help_text=_(
            (
                "Configure what data should be used as unique User Identifier. For most cases, "
                "the default should be fine."
            )
        ),
    )
    issuer_mode = models.TextField(
        choices=IssuerMode.choices,
        default=IssuerMode.PER_PROVIDER,
        help_text=_(
            ("Configure how the issuer field of the ID Token should be filled.")
        ),
    )

    rsa_key = models.ForeignKey(
        CertificateKeyPair,
        verbose_name=_("RSA Key"),
        on_delete=models.CASCADE,
        blank=True,
        null=True,
        help_text=_(
            "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
        ),
    )

    def create_refresh_token(
        self, user: User, scope: List[str], request: HttpRequest
    ) -> "RefreshToken":
        """Create and populate a RefreshToken object.

        The token gets a random refresh_token value and expires after
        `token_validity`. Its access token is generated immediately. The token
        is not saved here.
        """
        token = RefreshToken(
            user=user,
            provider=self,
            refresh_token=uuid4().hex,
            expires=timezone.now() + timedelta_from_string(self.token_validity),
            scope=scope,
        )
        token.access_token = token.create_access_token(user, request)
        return token

    def get_jwt_keys(self) -> List[Key]:
        """Return the keys used to sign tokens issued by this provider.

        RS256 uses the selected CertificateKeyPair. If RS256 is configured but
        no key pair is selected, a CONFIGURATION_ERROR event is recorded and the
        provider is switched to HS256 (in memory and saved). HS256 uses the
        client secret as a symmetric key.

        Raises:
            Exception: if `jwt_alg` is not a supported algorithm.
        """
        if self.jwt_alg == JWTAlgorithms.RS256:
            if self.rsa_key:
                # Because the JWT library uses python cryptodome, we can't
                # directly pass the RSAPublicKey object, but have to load
                # the key data ourselves.
                key = import_rsa_key(self.rsa_key.key_data)
                return [RSAKey(key=key, kid=self.rsa_key.kid)]
            # RS256 was selected without a CertificateKeyPair: record the
            # misconfiguration and fall back to HS256 (handled below).
            Event.new(
                EventAction.CONFIGURATION_ERROR,
                provider=self,
                message="Provider was configured for RS256, but no key was selected.",
            ).save()
            self.jwt_alg = JWTAlgorithms.HS256
            self.save()

        if self.jwt_alg == JWTAlgorithms.HS256:
            return [SYMKey(key=self.client_secret, alg=self.jwt_alg)]

        raise Exception("Unsupported key algorithm.")

    def get_issuer(self, request: HttpRequest) -> Optional[str]:
        """Get issuer, based on request.

        In GLOBAL mode this is the absolute URL of "/". Otherwise it is the
        OAuth2 app mountpoint followed by the application slug. Returns None when
        the provider has no application.
        """
        if self.issuer_mode == IssuerMode.GLOBAL:
            return request.build_absolute_uri("/")
        try:
            mountpoint = AuthentikProviderOAuth2Config.mountpoints[
                "authentik.providers.oauth2.urls"
            ]
            # pylint: disable=no-member
            return request.build_absolute_uri(f"/{mountpoint}{self.application.slug}/")
        except Provider.application.RelatedObjectDoesNotExist:
            return None

    @property
    def launch_url(self) -> Optional[str]:
        """Guess launch_url from the first redirect URI by dropping its path.

        Returns None when no redirect URIs are configured.
        """
        if self.redirect_uris == "":
            return None
        # splitlines() also handles "\r\n" separators (e.g. from a textarea);
        # strip() removes whitespace around the URI.
        main_url = self.redirect_uris.splitlines()[0].strip()
        # Rebuild the URL with an empty path rather than using
        # str.replace(path, ""). That would also delete identical text elsewhere,
        # e.g. the host in "https://app/app".
        return urlparse(main_url)._replace(path="").geturl()

    @property
    def serializer(self) -> Type[Serializer]:
        """REST API serializer class for this provider."""
        from authentik.providers.oauth2.api import OAuth2ProviderSerializer

        return OAuth2ProviderSerializer

    @property
    def form(self) -> Type[ModelForm]:
        """ModelForm class used to edit this provider."""
        from authentik.providers.oauth2.forms import OAuth2ProviderForm

        return OAuth2ProviderForm

    def __str__(self):
        return f"OAuth2 Provider {self.name}"

    def encode(self, payload: Dict[str, Any]) -> str:
        """Represent the ID Token as a JSON Web Token (JWT).

        The payload is signed, in compact serialization, with the keys returned
        by `get_jwt_keys` and the provider's `jwt_alg`.
        """
        # get_jwt_keys may switch an RS256 provider without a key pair to HS256.
        # It updates self.jwt_alg in memory (and saves it), so the algorithm read
        # below always matches `keys`. Reloading the row from the database here
        # would add a query and could pick up a concurrent change that no longer
        # matches the keys just returned.
        keys = self.get_jwt_keys()
        jws = JWS(payload, alg=self.jwt_alg)
        return jws.sign_compact(keys)

    def html_setup_urls(self, request: HttpRequest) -> Optional[str]:
        """return template and context modal with URLs for authorize, token, openid-config, etc

        Returns None when the provider has no application.
        """
        try:
            # pylint: disable=no-member
            return render_to_string(
                "providers/oauth2/setup_url_modal.html",
                {
                    "provider": self,
                    "issuer": self.get_issuer(request),
                    "authorize": request.build_absolute_uri(
                        reverse(
                            "authentik_providers_oauth2:authorize",
                        )
                    ),
                    "token": request.build_absolute_uri(
                        reverse(
                            "authentik_providers_oauth2:token",
                        )
                    ),
                    "userinfo": request.build_absolute_uri(
                        reverse(
                            "authentik_providers_oauth2:userinfo",
                        )
                    ),
                    "provider_info": request.build_absolute_uri(
                        reverse(
                            "authentik_providers_oauth2:provider-info",
                            kwargs={"application_slug": self.application.slug},
                        )
                    ),
                },
            )
        except Provider.application.RelatedObjectDoesNotExist:
            return None

    class Meta:

        verbose_name = _("OAuth2/OpenID Provider")
        verbose_name_plural = _("OAuth2/OpenID Providers")
 | |
| 
 | |
| 
 | |
class BaseGrantModel(models.Model):
    """Abstract base holding what every grant shares: the issuing provider,
    the user the grant belongs to, and the granted scopes."""

    provider = models.ForeignKey(OAuth2Provider, on_delete=models.CASCADE)
    user = models.ForeignKey(User, verbose_name=_("User"), on_delete=models.CASCADE)
    # Scopes persisted as one space-separated string; use the `scope` property.
    _scope = models.TextField(default="", verbose_name=_("Scopes"))

    class Meta:
        abstract = True

    @property
    def scope(self) -> List[str]:
        """Scopes as a list of strings (the stored string split on whitespace)."""
        raw_scopes = self._scope
        return raw_scopes.split()

    @scope.setter
    def scope(self, value):
        """Store an iterable of scope strings, joined by single spaces."""
        self._scope = " ".join(item for item in value)
 | |
| 
 | |
| 
 | |
class AuthorizationCode(ExpiringModel, BaseGrantModel):
    """OAuth2 Authorization Code"""

    code = models.CharField(max_length=255, unique=True, verbose_name=_("Code"))
    nonce = models.CharField(
        max_length=255,
        blank=True,
        default="",
        verbose_name=_("Nonce"),
    )
    is_open_id = models.BooleanField(
        default=False,
        verbose_name=_("Is Authentication?"),
    )
    # PKCE parameters; nullable because not every client sends them.
    code_challenge = models.CharField(
        max_length=255,
        null=True,
        verbose_name=_("Code Challenge"),
    )
    code_challenge_method = models.CharField(
        max_length=255,
        null=True,
        verbose_name=_("Code Challenge Method"),
    )

    class Meta:
        verbose_name = _("Authorization Code")
        verbose_name_plural = _("Authorization Codes")

    def __str__(self):
        return "Authorization code for {} for user {}".format(self.provider, self.user)

    @property
    def c_hash(self):
        """Code hash value (`c_hash`) for the ID Token.

        Unpadded base64url of the left-most half of the SHA-256 digest of the
        ASCII code.
        https://openid.net/specs/openid-connect-core-1_0.html#IDToken
        """
        digest = sha256(self.code.encode("ascii")).digest()
        left_half = digest[: len(digest) // 2]
        encoded = base64.urlsafe_b64encode(left_half)
        return encoded.rstrip(b"=").decode("ascii")
 | |
| 
 | |
| 
 | |
@dataclass
class IDToken:
    """The primary extension that OpenID Connect makes to OAuth 2.0 to enable End-Users to be
    Authenticated is the ID Token data structure. The ID Token is a security token that contains
    Claims about the Authentication of an End-User by an Authorization Server when using a Client,
    and potentially other requested Claims. The ID Token is represented as a
    JSON Web Token (JWT) [JWT].

    https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""

    # Every field needs to be optional so an empty IDToken can be stored for
    # non-OpenID flows.
    iss: Optional[str] = None  # issuer
    sub: Optional[str] = None  # subject identifier
    aud: Optional[str] = None  # audience (client id)
    exp: Optional[int] = None  # expiry, epoch seconds
    iat: Optional[int] = None  # issued-at, epoch seconds
    auth_time: Optional[int] = None  # time of user authentication, epoch seconds
    acr: Optional[str] = ACR_AUTHENTIK_DEFAULT

    c_hash: Optional[str] = None

    nonce: Optional[str] = None
    at_hash: Optional[str] = None

    # Additional claims; flattened into the top level by to_dict().
    claims: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields except `claims` as a dict, with `claims` merged on top.

        A claim whose name matches a field overrides that field's value.
        """
        fields_only = {
            name: value for name, value in asdict(self).items() if name != "claims"
        }
        return {**fields_only, **self.claims}
 | |
| 
 | |
| 
 | |
class RefreshToken(ExpiringModel, BaseGrantModel):
    """OAuth2 Refresh Token"""

    access_token = models.TextField(verbose_name=_("Access Token"))
    refresh_token = models.CharField(
        max_length=255, unique=True, verbose_name=_("Refresh Token")
    )
    # JSON-serialized IDToken; read and write it through the `id_token` property.
    _id_token = models.TextField(verbose_name=_("ID Token"))

    class Meta:
        verbose_name = _("OAuth2 Token")
        verbose_name_plural = _("OAuth2 Tokens")

    @property
    def id_token(self) -> IDToken:
        """Load ID Token from json; an empty field yields a default IDToken."""
        if self._id_token:
            raw_token = json.loads(self._id_token)
            return from_dict(IDToken, raw_token)
        return IDToken()

    @id_token.setter
    def id_token(self, value: IDToken):
        """Serialize `value` to JSON and store it in `_id_token`."""
        self._id_token = json.dumps(asdict(value))

    def __str__(self):
        return f"Refresh Token for {self.provider} for user {self.user}"

    @property
    def at_hash(self):
        """Access Token hash value (`at_hash`) for the ID Token.

        Unpadded base64url of the left-most half of the SHA-256 digest of the
        ASCII access token (OpenID Connect Core 1.0, `at_hash`).
        """
        digest = sha256(self.access_token.encode("ascii")).digest()
        left_half = digest[: len(digest) // 2]
        return base64.urlsafe_b64encode(left_half).rstrip(b"=").decode("ascii")

    def create_access_token(self, user: User, request: HttpRequest) -> str:
        """Create access token with a similar format as Okta, Keycloak, ADFS.

        The token consists of the ID token's claims plus `cid` (the provider's
        client ID) and a random `uid`, signed via the provider's `encode`.
        """
        token = self.create_id_token(user, request).to_dict()
        token["cid"] = self.provider.client_id
        token["uid"] = uuid4().hex
        return self.provider.encode(token)

    def create_id_token(self, user: User, request: HttpRequest) -> IDToken:
        """Creates the id_token.
        See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken

        Raises:
            ValueError: if the provider has an unknown `sub_mode`.
            Event.DoesNotExist: if no LOGIN event exists for the user
                (raised by `QuerySet.latest`).
        """
        sub = ""
        if self.provider.sub_mode == SubModes.HASHED_USER_ID:
            # UTF-8 instead of ASCII: identical bytes for ASCII input, and
            # no crash when SECRET_KEY contains non-ASCII characters.
            sub = sha256(f"{user.id}-{settings.SECRET_KEY}".encode("utf-8")).hexdigest()
        elif self.provider.sub_mode == SubModes.USER_EMAIL:
            sub = user.email
        elif self.provider.sub_mode == SubModes.USER_USERNAME:
            sub = user.username
        elif self.provider.sub_mode == SubModes.USER_UPN:
            # Fails if the user has no "upn" attribute (see SubModes.USER_UPN).
            sub = user.attributes["upn"]
        else:
            raise ValueError(
                (
                    f"Provider {self.provider} has invalid sub_mode "
                    f"selected: {self.provider.sub_mode}"
                )
            )

        # Timestamps are integer seconds since the epoch.
        now = int(time.time())
        iat_time = now
        validity = timedelta_from_string(self.provider.token_validity)
        # total_seconds(), not .seconds: .seconds is only the seconds component
        # (0-86399), so whole days (e.g. "days=1") would otherwise be dropped and
        # the token would expire immediately.
        exp_time = now + int(validity.total_seconds())
        # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
        auth_event = Event.objects.filter(
            action=EventAction.LOGIN, user=get_user(user)
        ).latest("created")
        auth_time = int(dateformat.format(auth_event.created, "U"))

        token = IDToken(
            iss=self.provider.get_issuer(request),
            sub=sub,
            aud=self.provider.client_id,
            exp=exp_time,
            iat=iat_time,
            auth_time=auth_time,
        )

        # Include (or not) user standard claims in the id_token.
        if self.provider.include_claims_in_id_token:
            from authentik.providers.oauth2.views.userinfo import UserInfoView

            user_info = UserInfoView()
            user_info.request = request
            claims = user_info.get_claims(self)
            token.claims = claims

        return token
 | 
