Files
authentik/authentik/enterprise/providers/apple_psso/views/token.py
Jens Langhammer 1cca629464 fix session
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-30 21:31:16 +02:00

141 lines
5.4 KiB
Python

from datetime import timedelta

from django.http import Http404, HttpRequest, HttpResponse
from django.utils.decorators import method_decorator
from django.utils.timezone import now
from django.views import View
from django.views.decorators.csrf import csrf_exempt
from jwt import PyJWT, PyJWTError, decode
from rest_framework.exceptions import ValidationError
from structlog.stdlib import get_logger

from authentik.core.models import AuthenticatedSession, Session, User
from authentik.core.sessions import SessionStore
from authentik.enterprise.providers.apple_psso.http import JWEResponse
from authentik.enterprise.providers.apple_psso.models import (
    AppleDevice,
    AppleDeviceUser,
    AppleNonce,
    ApplePlatformSSOProvider,
)
from authentik.events.models import Event, EventAction
from authentik.events.signals import SESSION_LOGIN_EVENT
from authentik.providers.oauth2.constants import TOKEN_TYPE
from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import RefreshToken
from authentik.root.middleware import SessionMiddleware
# Module-level structlog logger shared by the views in this file.
LOGGER = get_logger()
@method_decorator(csrf_exempt, name="dispatch")
class TokenView(View):
    """Token endpoint for the Apple Platform SSO provider.

    Accepts a signed JWT assertion via POST, identifies the sending device from
    the assertion's ``kid`` header, verifies the signature with that device's
    signing key and dispatches to a ``handle_v<version>_<typ>`` method.

    CSRF protection is disabled because requests are authenticated by the
    signed assertion rather than by a browser session.
    """

    # Device resolved from the outer assertion's `kid` header (set in `post`).
    device: AppleDevice
    # Provider the resolved device belongs to (set in `post`).
    provider: ApplePlatformSSOProvider

    def post(self, request: HttpRequest) -> HttpResponse:
        """Validate the outer assertion and hand it to the matching handler.

        Returns a 400 response when the assertion is missing, malformed, fails
        signature validation, carries an unknown/missing nonce, or no handler
        exists for the requested version/type.

        Raises:
            Http404: if no device is registered for the assertion's ``kid``.
        """
        version = request.POST.get("platform_sso_version")
        assertion = request.POST.get("assertion", request.POST.get("request"))
        if not assertion:
            return HttpResponse(status=400)
        # First pass is unverified: we only need the header to find out which
        # device key the assertion claims to be signed with.
        try:
            decode_unvalidated = PyJWT().decode_complete(
                assertion, options={"verify_signature": False}
            )
        except PyJWTError as exc:
            LOGGER.debug("Failed to parse assertion", exc=exc)
            return HttpResponse(status=400)
        header = decode_unvalidated["header"]
        LOGGER.debug(header)
        expected_kid = header.get("kid")
        if not expected_kid:
            return HttpResponse(status=400)
        self.device = AppleDevice.objects.filter(sign_key_id=expected_kid).first()
        if not self.device:
            raise Http404
        self.provider = self.device.provider

        # Properly decode the JWT with the key from the device
        try:
            decoded = decode(
                assertion,
                self.device.signing_key,
                algorithms=["ES256"],
                options={"verify_aud": False},
            )
        except PyJWTError as exc:
            LOGGER.warning("Failed to validate assertion", exc=exc)
            return HttpResponse(status=400)
        LOGGER.debug(decoded)
        LOGGER.debug("got device", device=self.device)

        # Check that the nonce hasn't been used before; deleting it here makes
        # every nonce single-use.
        request_nonce = decoded.get("request_nonce")
        if not request_nonce:
            return HttpResponse(status=400)
        nonce = AppleNonce.objects.filter(nonce=request_nonce).first()
        if not nonce:
            return HttpResponse(status=400)
        nonce.delete()

        # Handler names are derived from client-supplied values; the fixed
        # `handle_v` prefix restricts dispatch to the handler methods below.
        typ = header.get("typ")
        handler_func = (
            f"handle_v{version}_{typ}".replace("-", "_").replace("+", "_").replace(".", "_")
        )
        handler = getattr(self, handler_func, None)
        if not handler:
            LOGGER.debug("Handler not found", handler=handler_func)
            return HttpResponse(status=400)
        LOGGER.debug("sending to handler", handler=handler_func)
        return handler(decoded)

    def validate_device_user_response(self, assertion: str) -> tuple[AppleDeviceUser, dict] | None:
        """Decode an embedded assertion and validate it by looking up the matching device user

        Returns the device user together with the verified claims, or ``None``
        when the assertion has no ``kid`` header or no user of the current
        device matches it. Signature/claim validation errors from PyJWT are
        propagated to the caller.
        """
        decode_unvalidated = PyJWT().decode_complete(assertion, options={"verify_signature": False})
        expected_kid = decode_unvalidated["header"].get("kid")
        if not expected_kid:
            # Never query with a missing key id; filtering on None would turn
            # into an IS NULL lookup instead of "no match".
            return None
        device_user = AppleDeviceUser.objects.filter(
            device=self.device, enclave_key_id=expected_kid
        ).first()
        if not device_user:
            return None
        return device_user, decode(
            assertion,
            device_user.secure_enclave_key,
            audience="apple-platform-sso",
            algorithms=["ES256"],
        )

    def create_auth_session(self, user: User):
        """Create a new authenticated session for ``user``.

        Records a login event, stores it in a fresh session, links that session
        to the user and returns the value produced by
        ``SessionMiddleware.encode_session`` for the new session key.
        """
        event = Event.new(EventAction.LOGIN).from_http(self.request, user=user)
        store = SessionStore()
        store[SESSION_LOGIN_EVENT] = event
        store.save()
        db_session = Session.objects.filter(session_key=store.session_key).first()
        AuthenticatedSession.objects.create(session=db_session, user=user)
        return SessionMiddleware.encode_session(store.session_key, user)

    def handle_v1_0_platformsso_login_request_jwt(self, decoded: dict):
        """Handle a version 1.0 ``platformsso-login-request+jwt`` assertion.

        Only the ``urn:ietf:params:oauth:grant-type:jwt-bearer`` grant type is
        supported. The embedded assertion is validated against the device
        user's key, then a refresh token, ID token and session key are issued
        and returned in a ``JWEResponse``.

        Returns a 400 response for an unsupported grant type, a missing or
        invalid embedded assertion, or an unknown device user.

        Raises:
            ValidationError: if the embedded assertion's nonce does not match
                the outer assertion's nonce.
        """
        grant_type = decoded.get("grant_type")
        if grant_type != "urn:ietf:params:oauth:grant-type:jwt-bearer":
            LOGGER.warning("Unsupported grant type", grant_type=grant_type)
            return HttpResponse(status=400)
        embedded_assertion = decoded.get("assertion")
        if not embedded_assertion:
            LOGGER.warning("Missing embedded assertion")
            return HttpResponse(status=400)
        # Decode and validate inner assertion
        try:
            validated = self.validate_device_user_response(embedded_assertion)
        except PyJWTError as exc:
            LOGGER.warning("Failed to validate embedded assertion", exc=exc)
            return HttpResponse(status=400)
        if validated is None:
            LOGGER.warning("No matching device user for embedded assertion")
            return HttpResponse(status=400)
        device_user, inner = validated
        if inner["nonce"] != decoded["nonce"]:
            LOGGER.warning("Mis-matched nonce to outer assertion")
            raise ValidationError("Invalid request")

        issued_at = now()
        refresh_token = RefreshToken(
            user=device_user.user,
            scope=decoded["scope"],
            expires=issued_at + timedelta(hours=8),
            provider=self.provider,
            auth_time=issued_at,
            session=None,
        )
        id_token = IDToken.new(
            self.provider,
            refresh_token,
            self.request,
        )
        id_token.nonce = decoded["nonce"]
        refresh_token.id_token = id_token
        refresh_token.save()
        return JWEResponse(
            {
                "refresh_token": refresh_token.token,
                "refresh_token_expires_in": int((refresh_token.expires - now()).total_seconds()),
                "id_token": refresh_token.id_token.to_jwt(self.provider),
                "token_type": TOKEN_TYPE,
                "session_key": self.create_auth_session(device_user.user),
            },
            device=self.device,
            apv=decoded["jwe_crypto"]["apv"],
        )