providers/oauth2: remember session_id from initial token (#7976) * providers/oauth2: remember session_id original token was created with for future access/refresh tokens * providers/proxy: use hashed session as `sid` --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens L <jens@goauthentik.io>
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							4776d2bcc5
						
					
				
				
					commit
					9c8fec21cf
				
			@ -0,0 +1,27 @@
 | 
			
		||||
# Generated by Django 5.0 on 2023-12-22 23:20
 | 
			
		||||
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ("authentik_providers_oauth2", "0016_alter_refreshtoken_token"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="accesstoken",
 | 
			
		||||
            name="session_id",
 | 
			
		||||
            field=models.CharField(blank=True, default=""),
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="authorizationcode",
 | 
			
		||||
            name="session_id",
 | 
			
		||||
            field=models.CharField(blank=True, default=""),
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="refreshtoken",
 | 
			
		||||
            name="session_id",
 | 
			
		||||
            field=models.CharField(blank=True, default=""),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -296,6 +296,7 @@ class BaseGrantModel(models.Model):
 | 
			
		||||
    revoked = models.BooleanField(default=False)
 | 
			
		||||
    _scope = models.TextField(default="", verbose_name=_("Scopes"))
 | 
			
		||||
    auth_time = models.DateTimeField(verbose_name="Authentication time")
 | 
			
		||||
    session_id = models.CharField(default="", blank=True)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def scope(self) -> list[str]:
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
"""authentik OAuth2 Authorization views"""
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
from hashlib import sha256
 | 
			
		||||
from json import dumps
 | 
			
		||||
from re import error as RegexError
 | 
			
		||||
from re import fullmatch
 | 
			
		||||
@ -282,6 +283,7 @@ class OAuthAuthorizationParams:
 | 
			
		||||
            expires=now + timedelta_from_string(self.provider.access_code_validity),
 | 
			
		||||
            scope=self.scope,
 | 
			
		||||
            nonce=self.nonce,
 | 
			
		||||
            session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if self.code_challenge and self.code_challenge_method:
 | 
			
		||||
@ -569,6 +571,7 @@ class OAuthFulfillmentStage(StageView):
 | 
			
		||||
            expires=access_token_expiry,
 | 
			
		||||
            provider=self.provider,
 | 
			
		||||
            auth_time=auth_event.created if auth_event else now,
 | 
			
		||||
            session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        id_token = IDToken.new(self.provider, token, self.request)
 | 
			
		||||
 | 
			
		||||
@ -487,6 +487,7 @@ class TokenView(View):
 | 
			
		||||
            # Keep same scopes as previous token
 | 
			
		||||
            scope=self.params.authorization_code.scope,
 | 
			
		||||
            auth_time=self.params.authorization_code.auth_time,
 | 
			
		||||
            session_id=self.params.authorization_code.session_id,
 | 
			
		||||
        )
 | 
			
		||||
        access_token.id_token = IDToken.new(
 | 
			
		||||
            self.provider,
 | 
			
		||||
@ -502,6 +503,7 @@ class TokenView(View):
 | 
			
		||||
            expires=refresh_token_expiry,
 | 
			
		||||
            provider=self.provider,
 | 
			
		||||
            auth_time=self.params.authorization_code.auth_time,
 | 
			
		||||
            session_id=self.params.authorization_code.session_id,
 | 
			
		||||
        )
 | 
			
		||||
        id_token = IDToken.new(
 | 
			
		||||
            self.provider,
 | 
			
		||||
@ -539,6 +541,7 @@ class TokenView(View):
 | 
			
		||||
            # Keep same scopes as previous token
 | 
			
		||||
            scope=self.params.refresh_token.scope,
 | 
			
		||||
            auth_time=self.params.refresh_token.auth_time,
 | 
			
		||||
            session_id=self.params.refresh_token.session_id,
 | 
			
		||||
        )
 | 
			
		||||
        access_token.id_token = IDToken.new(
 | 
			
		||||
            self.provider,
 | 
			
		||||
@ -554,6 +557,7 @@ class TokenView(View):
 | 
			
		||||
            expires=refresh_token_expiry,
 | 
			
		||||
            provider=self.provider,
 | 
			
		||||
            auth_time=self.params.refresh_token.auth_time,
 | 
			
		||||
            session_id=self.params.refresh_token.session_id,
 | 
			
		||||
        )
 | 
			
		||||
        id_token = IDToken.new(
 | 
			
		||||
            self.provider,
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,6 @@
 | 
			
		||||
"""proxy provider tasks"""
 | 
			
		||||
from hashlib import sha256
 | 
			
		||||
 | 
			
		||||
from asgiref.sync import async_to_sync
 | 
			
		||||
from channels.layers import get_channel_layer
 | 
			
		||||
from django.db import DatabaseError, InternalError, ProgrammingError
 | 
			
		||||
@ -23,6 +25,7 @@ def proxy_set_defaults():
 | 
			
		||||
def proxy_on_logout(session_id: str):
 | 
			
		||||
    """Update outpost instances connected to a single outpost"""
 | 
			
		||||
    layer = get_channel_layer()
 | 
			
		||||
    hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()
 | 
			
		||||
    for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
 | 
			
		||||
        group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
 | 
			
		||||
        async_to_sync(layer.group_send)(
 | 
			
		||||
@ -30,6 +33,6 @@ def proxy_on_logout(session_id: str):
 | 
			
		||||
            {
 | 
			
		||||
                "type": "event.provider.specific",
 | 
			
		||||
                "sub_type": "logout",
 | 
			
		||||
                "session_id": session_id,
 | 
			
		||||
                "session_id": hashed_session_id,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@ entries:
 | 
			
		||||
        # This mapping is used by the authentik proxy. It passes extra user attributes,
 | 
			
		||||
        # which are used for example for the HTTP-Basic Authentication mapping.
 | 
			
		||||
        return {
 | 
			
		||||
            "sid": request.http_request.session.session_key,
 | 
			
		||||
            "sid": token.session_id,
 | 
			
		||||
            "ak_proxy": {
 | 
			
		||||
                "user_attributes": request.user.group_attributes(request),
 | 
			
		||||
                "is_superuser": request.user.is_superuser,
 | 
			
		||||
 | 
			
		||||
@ -36,6 +36,7 @@ func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]inte
 | 
			
		||||
	switch msg.SubType {
 | 
			
		||||
	case WSProviderSubTypeLogout:
 | 
			
		||||
		for _, p := range ps.apps {
 | 
			
		||||
			ps.log.WithField("provider", p.Host).Debug("Logging out")
 | 
			
		||||
			err := p.Logout(ctx, func(c application.Claims) bool {
 | 
			
		||||
				return c.Sid == msg.SessionID
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user