e2e: add test for OAuth Enrollment -> OAuth Authentication

This commit is contained in:
Jens Langhammer
2020-07-10 00:14:48 +02:00
parent 4caa4be476
commit 7ac4242a38
2 changed files with 97 additions and 59 deletions

View File

@ -48,59 +48,55 @@ class OAuthCallback(OAuthClientMixin, View):
self.source = OAuthSource.objects.get(slug=slug)
except OAuthSource.DoesNotExist:
raise Http404(f"Unknown OAuth source '{slug}'.")
else:
if not self.source.enabled:
raise Http404(f"Source {slug} is not enabled.")
client = self.get_client(self.source)
callback = self.get_callback_url(self.source)
# Fetch access token
token = client.get_access_token(self.request, callback=callback)
if token is None:
return self.handle_login_failure(
self.source, "Could not retrieve token."
)
if "error" in token:
return self.handle_login_failure(self.source, token["error"])
# Fetch profile info
info = client.get_profile_info(token)
if info is None:
return self.handle_login_failure(
self.source, "Could not retrieve profile."
)
identifier = self.get_user_id(self.source, info)
if identifier is None:
return self.handle_login_failure(self.source, "Could not determine id.")
# Get or create access record
defaults = {
"access_token": token.get("access_token"),
}
existing = UserOAuthSourceConnection.objects.filter(
source=self.source, identifier=identifier
)
if existing.exists():
connection = existing.first()
connection.access_token = token.get("access_token")
UserOAuthSourceConnection.objects.filter(pk=connection.pk).update(
**defaults
)
else:
connection = UserOAuthSourceConnection(
source=self.source,
identifier=identifier,
access_token=token.get("access_token"),
)
user = AuthorizedServiceBackend().authenticate(
source=self.source, identifier=identifier, request=request
if not self.source.enabled:
raise Http404(f"Source {slug} is not enabled.")
client = self.get_client(self.source)
callback = self.get_callback_url(self.source)
# Fetch access token
token = client.get_access_token(self.request, callback=callback)
if token is None:
return self.handle_login_failure(self.source, "Could not retrieve token.")
if "error" in token:
return self.handle_login_failure(self.source, token["error"])
# Fetch profile info
info = client.get_profile_info(token)
if info is None:
return self.handle_login_failure(self.source, "Could not retrieve profile.")
identifier = self.get_user_id(self.source, info)
if identifier is None:
return self.handle_login_failure(self.source, "Could not determine id.")
# Get or create access record
defaults = {
"access_token": token.get("access_token"),
}
existing = UserOAuthSourceConnection.objects.filter(
source=self.source, identifier=identifier
)
if existing.exists():
connection = existing.first()
connection.access_token = token.get("access_token")
UserOAuthSourceConnection.objects.filter(pk=connection.pk).update(
**defaults
)
if user is None:
if self.request.user.is_authenticated:
LOGGER.debug("Linking existing user", source=self.source)
return self.handle_existing_user_link(self.source, connection, info)
LOGGER.debug("Handling enrollment of new user", source=self.source)
return self.handle_enroll(self.source, connection, info)
LOGGER.debug("Handling existing user", source=self.source)
return self.handle_existing_user(self.source, user, connection, info)
else:
connection = UserOAuthSourceConnection(
source=self.source,
identifier=identifier,
access_token=token.get("access_token"),
)
user = AuthorizedServiceBackend().authenticate(
source=self.source, identifier=identifier, request=request
)
if user is None:
if self.request.user.is_authenticated:
LOGGER.debug("Linking existing user", source=self.source)
return self.handle_existing_user_link(self.source, connection, info)
LOGGER.debug("Handling enrollment of new user", source=self.source)
return self.handle_enroll(self.source, connection, info)
LOGGER.debug("Handling existing user", source=self.source)
return self.handle_existing_user(self.source, user, connection, info)
# pylint: disable=unused-argument
def get_callback_url(self, source: OAuthSource) -> str: