sources/oauth: cleanup clients, add type annotations

This commit is contained in:
Jens Langhammer
2020-09-26 00:34:57 +02:00
parent 6e4ce8dbaa
commit d9c2b32cba
8 changed files with 94 additions and 85 deletions

View File

@ -1,6 +1,8 @@
"""OAuth Base views"""
from typing import Optional, Type
from django.http.request import HttpRequest
from passbook.sources.oauth.clients.base import BaseOAuthClient
from passbook.sources.oauth.clients.oauth1 import OAuthClient
from passbook.sources.oauth.clients.oauth2 import OAuth2Client
@ -11,13 +13,15 @@ from passbook.sources.oauth.models import OAuthSource
class OAuthClientMixin:
"Mixin for getting OAuth client for a source."
request: HttpRequest # Set by View class
client_class: Optional[Type[BaseOAuthClient]] = None
def get_client(self, source: OAuthSource) -> BaseOAuthClient:
def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient:
"Get instance of the OAuth client for this source."
if self.client_class is not None:
# pylint: disable=not-callable
return self.client_class(source)
return self.client_class(source, self.request, **kwargs)
if source.request_token_url:
return OAuthClient(source)
return OAuth2Client(source)
return OAuthClient(source, self.request, **kwargs)
return OAuth2Client(source, self.request, **kwargs)

View File

@ -54,7 +54,7 @@ class OAuthCallback(OAuthClientMixin, View):
client = self.get_client(self.source)
callback = self.get_callback_url(self.source)
# Fetch access token
token = client.get_access_token(self.request, callback=callback)
token = client.get_access_token(callback=callback)
if token is None:
return self.handle_login_failure(self.source, "Could not retrieve token.")
if "error" in token:

View File

@ -40,9 +40,6 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
else:
if not source.enabled:
raise Http404(f"source {slug} is not enabled.")
client = self.get_client(source)
callback = self.get_callback_url(source)
client = self.get_client(source, callback=self.get_callback_url(source))
params = self.get_additional_parameters(source)
return client.get_redirect_url(
self.request, callback=callback, parameters=params
)
return client.get_redirect_url(params)