304 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			304 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Core OAuth Views"""
 | |
| from typing import Any, Callable, Dict, Optional
 | |
| 
 | |
| from django.conf import settings
 | |
| from django.contrib import messages
 | |
| from django.contrib.auth.mixins import LoginRequiredMixin
 | |
| from django.http import Http404, HttpRequest, HttpResponse
 | |
| from django.shortcuts import get_object_or_404, redirect, render
 | |
| from django.urls import reverse
 | |
| from django.utils.translation import ugettext as _
 | |
| from django.views.generic import RedirectView, View
 | |
| from structlog import get_logger
 | |
| 
 | |
| from passbook.audit.models import Event, EventAction
 | |
| from passbook.core.models import User
 | |
| from passbook.flows.models import Flow
 | |
| from passbook.flows.planner import (
 | |
|     PLAN_CONTEXT_PENDING_USER,
 | |
|     PLAN_CONTEXT_SSO,
 | |
|     FlowPlanner,
 | |
| )
 | |
| from passbook.flows.views import SESSION_KEY_PLAN
 | |
| from passbook.lib.utils.urls import redirect_with_qs
 | |
| from passbook.policies.utils import delete_none_keys
 | |
| from passbook.sources.oauth.auth import AuthorizedServiceBackend
 | |
| from passbook.sources.oauth.clients import BaseOAuthClient, get_client
 | |
| from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
 | |
| from passbook.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
 | |
| from passbook.stages.prompt.stage import PLAN_CONTEXT_PROMPT
 | |
| 
 | |
# Module-wide structlog logger; OAuthCallback below uses it for debug and
# warning output.
LOGGER = get_logger()
 | |
| 
 | |
| 
 | |
# pylint: disable=too-few-public-methods
class OAuthClientMixin:
    """Mixin that resolves the OAuth client used to talk to a source.

    A subclass can pin a specific client by setting ``client_class``; when it
    is left as ``None`` the client is looked up with the module-level
    ``get_client`` helper instead.
    """

    # Optional factory; when set it is called with the source to build the client.
    client_class: Optional[Callable] = None

    def get_client(self, source: OAuthSource) -> BaseOAuthClient:
        """Return an OAuth client instance for ``source``.

        Args:
            source: The OAuth source the client should talk to.

        Returns:
            ``self.client_class(source)`` if a client class is configured,
            otherwise the result of ``get_client(source)``.
        """
        factory = self.client_class
        if factory is None:
            # No explicit client class: fall back to the generic lookup.
            return get_client(source)
        # pylint: disable=not-callable
        return factory(source)
 | |
| 
 | |
| 
 | |
class OAuthRedirect(OAuthClientMixin, RedirectView):
    """Send the user to the OAuth source so they can grant access."""

    permanent = False
    # Optional extra parameters passed to the client's redirect URL builder.
    params = None

    # pylint: disable=unused-argument
    def get_additional_parameters(self, source: OAuthSource) -> Dict[str, Any]:
        """Return the extra redirect parameters for ``source``.

        Returns ``self.params`` when it is truthy, otherwise a new empty dict.
        """
        if self.params:
            return self.params
        return {}

    def get_callback_url(self, source: OAuthSource) -> str:
        """Return the URL the source should send the user back to."""
        route_kwargs = {"source_slug": source.slug}
        return reverse(
            "passbook_sources_oauth:oauth-client-callback", kwargs=route_kwargs
        )

    def get_redirect_url(self, **kwargs) -> str:
        """Build the redirect URL for the source named by ``source_slug``.

        Raises:
            Http404: if no source has that slug, or the source is disabled.
        """
        slug = kwargs.get("source_slug", "")
        try:
            source = OAuthSource.objects.get(slug=slug)
        except OAuthSource.DoesNotExist:
            raise Http404(f"Unknown OAuth source '{slug}'.")
        # Guard clause: disabled sources are treated like missing ones.
        if not source.enabled:
            raise Http404(f"source {slug} is not enabled.")
        client = self.get_client(source)
        callback_url = self.get_callback_url(source)
        extra_params = self.get_additional_parameters(source)
        return client.get_redirect_url(
            self.request, callback=callback_url, parameters=extra_params
        )
 | |
| 
 | |
| 
 | |
class OAuthCallback(OAuthClientMixin, View):
    """Base OAuth callback view.

    Handles the user's return from an OAuth source:

    1. Exchanges the response for an access token.
    2. Fetches the profile and stores the token on a
       ``UserOAuthSourceConnection``.
    3. Either logs in the user already tied to the connection, links the
       connection to the currently logged-in user, or starts the source's
       enrollment flow.

    Subclasses must implement ``get_user_enroll_context``.
    """

    source_id = None
    source = None

    def get(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
        """Handle the redirect back from the source named by ``source_slug``.

        Raises:
            Http404: if no source has that slug, or the source is disabled.
        """
        slug = kwargs.get("source_slug", "")
        try:
            self.source = OAuthSource.objects.get(slug=slug)
        except OAuthSource.DoesNotExist as exc:
            raise Http404(f"Unknown OAuth source '{slug}'.") from exc
        if not self.source.enabled:
            raise Http404(f"source {slug} is not enabled.")
        client = self.get_client(self.source)
        callback = self.get_callback_url(self.source)
        # Fetch access token
        token = client.get_access_token(self.request, callback=callback)
        if token is None:
            return self.handle_login_failure(self.source, "Could not retrieve token.")
        if "error" in token:
            return self.handle_login_failure(self.source, token["error"])
        # Fetch profile info
        info = client.get_profile_info(token)
        if info is None:
            return self.handle_login_failure(
                self.source, "Could not retrieve profile."
            )
        identifier = self.get_user_id(self.source, info)
        if identifier is None:
            return self.handle_login_failure(self.source, "Could not determine id.")
        connection = self._get_or_build_connection(
            self.source, identifier, token.get("access_token")
        )
        user = AuthorizedServiceBackend().authenticate(
            source=self.source, identifier=identifier, request=request
        )
        if user is None:
            LOGGER.debug("Handling new connection", source=self.source)
            return self.handle_new_connection(self.source, connection, info)
        LOGGER.debug("Handling existing user", source=self.source)
        return self.handle_existing_user(self.source, user, connection, info)

    @staticmethod
    def _get_or_build_connection(
        source: OAuthSource, identifier: str, access_token: Optional[str]
    ) -> UserOAuthSourceConnection:
        """Return the stored connection with its token refreshed, or a new one.

        A new connection is returned unsaved.
        NOTE(review): this view only saves it in ``handle_new_connection`` when
        linking to a logged-in user; confirm the enrollment flow persists it.
        """
        connection = UserOAuthSourceConnection.objects.filter(
            source=source, identifier=identifier
        ).first()
        if connection is None:
            return UserOAuthSourceConnection(
                source=source, identifier=identifier, access_token=access_token
            )
        connection.access_token = access_token
        # Queryset update() writes only the token column instead of a full save().
        UserOAuthSourceConnection.objects.filter(pk=connection.pk).update(
            access_token=access_token
        )
        return connection

    # pylint: disable=unused-argument
    def get_callback_url(self, source: OAuthSource) -> str:
        """Return the callback URL to pass to the client.

        The base implementation returns ``""``; override it if the callback URL
        differs from the current URL.
        """
        return ""

    # pylint: disable=unused-argument
    def get_error_redirect(self, source: OAuthSource, reason: str) -> str:
        """Return the URL to redirect to on login failure (``settings.LOGIN_URL``)."""
        return settings.LOGIN_URL

    def get_user_enroll_context(
        self,
        source: OAuthSource,
        access: UserOAuthSourceConnection,
        info: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Create a dict of user data for the enrollment prompt context.

        Raises:
            NotImplementedError: always; subclasses must implement this.
        """
        raise NotImplementedError()

    # pylint: disable=unused-argument
    def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> Optional[str]:
        """Return the unique identifier (``info["id"]``) from the profile, or None."""
        return info.get("id")

    def handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse:
        """Plan ``flow`` with ``kwargs`` as context and redirect to its executor.

        The plan is stored in the session under ``SESSION_KEY_PLAN``.
        """
        kwargs.update(
            {
                # Since we authenticate the user by their token, they have no backend set
                PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
                PLAN_CONTEXT_SSO: True,
            }
        )
        # We run the Flow planner here so we can pass the Pending user in the context
        planner = FlowPlanner(flow)
        plan = planner.plan(self.request, kwargs)
        self.request.session[SESSION_KEY_PLAN] = plan
        return redirect_with_qs(
            "passbook_flows:flow-executor-shell", self.request.GET, flow_slug=flow.slug,
        )

    # pylint: disable=unused-argument
    def handle_existing_user(
        self,
        source: OAuthSource,
        user: User,
        access: UserOAuthSourceConnection,
        info: Dict[str, Any],
    ) -> HttpResponse:
        """Show a success message and run ``source.authentication_flow`` for ``user``."""
        # Translate the template first, then interpolate, so the catalog lookup matches.
        messages.success(
            self.request,
            _("Successfully authenticated with %(source)s!") % {"source": source.name},
        )
        flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user}
        return self.handle_login_flow(source.authentication_flow, **flow_kwargs)

    def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse:
        """Log ``reason``, show a generic error message and redirect to the error URL."""
        LOGGER.warning("Authentication Failure", reason=reason)
        messages.error(self.request, _("Authentication Failed."))
        return redirect(self.get_error_redirect(source, reason))

    def handle_new_connection(
        self,
        source: OAuthSource,
        access: UserOAuthSourceConnection,
        info: Dict[str, Any],
    ) -> HttpResponse:
        """Link ``access`` to the logged-in user, or start enrollment if anonymous."""
        if self.request.user.is_authenticated:
            # there's already a user logged in, just link them up
            user = self.request.user
            access.user = user
            # save() persists the user link (and inserts the row if it is new).
            access.save()
            Event.new(
                EventAction.CUSTOM, message="Linked OAuth Source", source=source
            ).from_http(self.request)
            messages.success(
                self.request,
                _("Successfully linked %(source)s!") % {"source": source.name},
            )
            return redirect(
                reverse(
                    "passbook_sources_oauth:oauth-client-user",
                    kwargs={"source_slug": source.slug},
                )
            )
        # User was not authenticated, new user will be created
        # NOTE(review): `access` is not saved on this path; if it was newly built in
        # get(), it is discarded unless a later flow stage persists it — confirm.
        messages.success(
            self.request,
            _("Successfully authenticated with %(source)s!") % {"source": source.name},
        )
        # Trim out all keys that have a value of None,
        # so we use `"key" in ` checks in policies
        context = {
            PLAN_CONTEXT_PROMPT: delete_none_keys(
                self.get_user_enroll_context(source, access, info)
            )
        }
        return self.handle_login_flow(source.enrollment_flow, **context)
 | |
| 
 | |
| 
 | |
class DisconnectView(LoginRequiredMixin, View):
    """Delete the current user's connection to an OAuth source."""

    # Resolved in dispatch(): the OAuth source and the user's connection to it.
    source = None
    aas = None

    def dispatch(self, request, source_slug):
        """Resolve the source and the user's connection, then dispatch.

        Anonymous users are rejected via ``handle_no_permission()`` before any
        lookup. This method runs before ``LoginRequiredMixin.dispatch``, and the
        connection lookup filters on ``request.user``.

        Raises:
            Http404: if the source or the user's connection does not exist.
        """
        if not request.user.is_authenticated:
            return self.handle_no_permission()
        self.source = get_object_or_404(OAuthSource, slug=source_slug)
        self.aas = get_object_or_404(
            UserOAuthSourceConnection, source=self.source, user=request.user
        )
        return super().dispatch(request, source_slug)

    def post(self, request, source_slug):
        """Delete the connection if confirmed, otherwise re-show the form."""
        if "confirmdelete" in request.POST:
            # User confirmed deletion
            self.aas.delete()
            messages.success(request, _("Connection successfully deleted"))
            return redirect(
                reverse(
                    "passbook_sources_oauth:oauth-client-user",
                    kwargs={"source_slug": self.source.slug},
                )
            )
        return self.get(request, source_slug)

    # pylint: disable=unused-argument
    def get(self, request, source_slug):
        """Render the generic delete-confirmation form for the source."""
        return render(
            request,
            "generic/delete.html",
            {
                "object": self.source,
                "delete_url": reverse(
                    "passbook_sources_oauth:oauth-client-disconnect",
                    kwargs={"source_slug": self.source.slug},
                ),
            },
        )
 | 
