Move can_handle calls to the binding endpoints (/login/ and /login/initiate/), so that /login/authorize/ works either way and can clean up the session and audit.
		
			
				
	
	
		
			229 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			229 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Basic SAML Processor"""
 | 
						|
from typing import TYPE_CHECKING, Dict, List, Union
 | 
						|
 | 
						|
from defusedxml import ElementTree
 | 
						|
from django.http import HttpRequest
 | 
						|
from structlog import get_logger
 | 
						|
 | 
						|
from passbook.core.exceptions import PropertyMappingExpressionException
 | 
						|
from passbook.providers.saml.exceptions import CannotHandleAssertion
 | 
						|
from passbook.providers.saml.processors.types import SAMLResponseParams
 | 
						|
from passbook.providers.saml.utils import get_random_id
 | 
						|
from passbook.providers.saml.utils.encoding import decode_base64_and_inflate, nice64
 | 
						|
from passbook.providers.saml.utils.time import get_time_string, timedelta_from_string
 | 
						|
from passbook.providers.saml.utils.xml_render import get_assertion_xml, get_response_xml
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from passbook.providers.saml.models import SAMLProvider
 | 
						|
 | 
						|
# pylint: disable=too-many-instance-attributes
 | 
						|
class Processor:
    """Base SAML 2.0 AuthnRequest to Response Processor.

    Sub-classes should provide Service Provider-specific functionality.

    Two entry points prepare the processor before `generate_response()`:

    * `can_handle(request)` for SP-initiated logins; it reads the AuthnRequest
      from the session, parses it and validates it.
    * `init_deep_link(request)` for IdP-initiated logins, where no
      AuthnRequest exists and the request parameters are synthesised.
    """

    # True when no AuthnRequest is available (IdP-initiated flow). When False,
    # `generate_response` re-runs `can_handle` to parse the request.
    is_idp_initiated = False

    # Provider configuration this processor answers for.
    _remote: "SAMLProvider"
    # HTTP request currently being processed (set by the entry points).
    _http_request: "HttpRequest"

    # Rendered assertion XML, rendered response XML and the encoded response.
    _assertion_xml: str
    _response_xml: str
    _saml_response: str

    # Raw values read from the session by `_extract_saml_request`.
    _relay_state: str
    _saml_request: str

    # Template contexts for the assertion, the parsed request and the response.
    _assertion_params: Dict[str, Union[str, List[Dict[str, str]]]]
    _request_params: Dict[str, str]
    _response_params: Dict[str, str]

    @property
    def subject_format(self) -> str:
        """Get subject Format"""
        return "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"

    def __init__(self, remote: "SAMLProvider"):
        """Bind the processor to `remote` and create a logger."""
        self.name = remote.name
        self._remote = remote
        self._logger = get_logger()

    def _build_assertion(self):
        """Builds _assertion_params.

        Reads provider settings from `_remote`, the session key and user
        e-mail from `_http_request`, then overlays `_request_params` so values
        taken from the request win over the defaults built here.
        """
        self._assertion_params = {
            "ASSERTION_ID": get_random_id(),
            "ASSERTION_SIGNATURE": "",  # it's unsigned
            "AUDIENCE": self._remote.audience,
            "AUTH_INSTANT": get_time_string(),
            "ISSUE_INSTANT": get_time_string(),
            "NOT_BEFORE": get_time_string(
                timedelta_from_string(self._remote.assertion_valid_not_before)
            ),
            "NOT_ON_OR_AFTER": get_time_string(
                timedelta_from_string(self._remote.assertion_valid_not_on_or_after)
            ),
            "SESSION_INDEX": self._http_request.session.session_key,
            "SESSION_NOT_ON_OR_AFTER": get_time_string(
                timedelta_from_string(self._remote.session_valid_not_on_or_after)
            ),
            "SP_NAME_QUALIFIER": self._remote.audience,
            "SUBJECT": self._http_request.user.email,
            "SUBJECT_FORMAT": self.subject_format,
            "ISSUER": self._remote.issuer,
        }
        self._assertion_params.update(self._request_params)

    def _build_response(self):
        """Builds _response_params.

        Requires `_assertion_xml` (see `_format_assertion`); `_request_params`
        is overlaid on top of the values built here.
        """
        self._response_params = {
            "ASSERTION": self._assertion_xml,
            "ISSUE_INSTANT": get_time_string(),
            "RESPONSE_ID": get_random_id(),
            "RESPONSE_SIGNATURE": "",  # initially unsigned
            "ISSUER": self._remote.issuer,
        }
        self._response_params.update(self._request_params)

    def _encode_response(self):
        """Encodes _response_xml with `nice64` and stores it in _saml_response."""
        self._saml_response = nice64(str.encode(self._response_xml))

    def _extract_saml_request(self):
        """Retrieves the _saml_request AuthnRequest from the _http_request.

        Both `SAMLRequest` and `RelayState` are read from the session.

        Raises:
            KeyError: if either key is missing from the session.
        """
        self._saml_request = self._http_request.session["SAMLRequest"]
        self._relay_state = self._http_request.session["RelayState"]

    def _format_assertion(self):
        """Formats _assertion_params as _assertion_xml.

        Evaluates every `SAMLPropertyMapping` attached to the provider and
        adds the results as `ATTRIBUTES`. A mapping whose expression fails is
        logged and skipped so that it cannot break the whole assertion.
        """
        # https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions
        # Imported here rather than at module level (the module only imports
        # from the models module under TYPE_CHECKING).
        from passbook.providers.saml.models import SAMLPropertyMapping

        attributes = []
        for mapping in self._remote.property_mappings.all().select_subclasses():
            if not isinstance(mapping, SAMLPropertyMapping):
                continue
            # Only the evaluation is guarded; building the payload below
            # cannot raise a mapping expression error.
            try:
                value = mapping.evaluate(
                    user=self._http_request.user,
                    request=self._http_request,
                    provider=self._remote,
                )
            except PropertyMappingExpressionException as exc:
                self._logger.warning(exc)
                continue
            mapping_payload = {
                "Name": mapping.saml_name,
                "FriendlyName": mapping.friendly_name,
            }
            # Normal values and arrays need different dict keys as they are
            # handled differently in the template
            if isinstance(value, list):
                mapping_payload["ValueArray"] = value
            else:
                mapping_payload["Value"] = value
            attributes.append(mapping_payload)

        self._assertion_params["ATTRIBUTES"] = attributes
        self._assertion_xml = get_assertion_xml(
            "saml/xml/assertions/generic.xml", self._assertion_params, signed=True
        )

    def _format_response(self):
        """Formats _response_params as _response_xml."""
        assertion_id = self._assertion_params["ASSERTION_ID"]
        self._response_xml = get_response_xml(
            self._response_params, saml_provider=self._remote, assertion_id=assertion_id
        )

    def _get_saml_response_params(self) -> "SAMLResponseParams":
        """Returns the parameters for the response template."""
        return SAMLResponseParams(
            acs_url=self._request_params["ACS_URL"],
            saml_response=self._saml_response,
            relay_state=self._relay_state,
        )

    def _decode_and_parse_request(self):
        """Parses various parameters from _saml_request into _request_params.

        `AssertionConsumerServiceURL` falls back to the provider's configured
        ACS URL, `Destination` and `ProviderName` fall back to an empty
        string, and `ID` is mandatory (a missing `ID` raises `KeyError`).
        """
        decoded_xml = decode_base64_and_inflate(self._saml_request)

        root = ElementTree.fromstring(decoded_xml)

        self._request_params = {
            "ACS_URL": root.attrib.get(
                "AssertionConsumerServiceURL", self._remote.acs_url
            ),
            "REQUEST_ID": root.attrib["ID"],
            "DESTINATION": root.attrib.get("Destination", ""),
            "PROVIDER_NAME": root.attrib.get("ProviderName", ""),
        }

    def _validate_request(self):
        """
        Validates the SAML request against the SP configuration of this
        processor. Sub-classes should override this and raise a
        `CannotHandleAssertion` exception if the validation fails.

        Raises:
            CannotHandleAssertion: if the ACS URL specified in the SAML request
                doesn't match the one specified in the processor config.
        """
        request_acs_url = self._request_params["ACS_URL"]

        if self._remote.acs_url != request_acs_url:
            msg = (
                f"ACS URL of {request_acs_url} doesn't match Provider "
                f"ACS URL of {self._remote.acs_url}."
            )
            self._logger.info(msg)
            raise CannotHandleAssertion(msg)

    def can_handle(self, request: "HttpRequest") -> bool:
        """Returns true if this processor can handle this request.

        Raises:
            CannotHandleAssertion: if the session holds no SAML request, the
                request cannot be decoded/parsed, or validation fails.
        """
        self._http_request = request
        # Read the request.
        try:
            self._extract_saml_request()
        except KeyError as exc:
            raise CannotHandleAssertion(
                "Couldn't find SAML request in user session"
            ) from exc

        # Any decoding or XML error means the request is unusable, so the
        # broad catch is deliberate; the original error is kept as the cause.
        try:
            self._decode_and_parse_request()
        except Exception as exc:
            raise CannotHandleAssertion(f"Couldn't parse SAML request: {exc}") from exc

        self._validate_request()
        return True

    def generate_response(self) -> "SAMLResponseParams":
        """Processes request and returns template variables suitable for a response."""
        # Build the assertion and response.
        # Only call can_handle if SP initiated Request, otherwise we have no Request
        if not self.is_idp_initiated:
            self.can_handle(self._http_request)

        self._build_assertion()
        self._format_assertion()
        self._build_response()
        self._format_response()
        self._encode_response()

        # Return proper template params.
        return self._get_saml_response_params()

    def init_deep_link(self, request: "HttpRequest"):
        """Initialize this Processor to make an IdP-initiated call to the SP's
        deep-linked URL."""
        self._http_request = request
        # There is no AuthnRequest in this flow, so mark the instance as
        # IdP-initiated; otherwise generate_response() would call can_handle()
        # and fail while looking for a SAML request in the session.
        self.is_idp_initiated = True
        acs_url = self._remote.acs_url
        # NOTE: The following request params are made up. Some are blank,
        # because they come over in the AuthnRequest, but we don't have an
        # AuthnRequest in this case:
        # - Destination: Should be this IdP's SSO endpoint URL. Not used in the response?
        # - ProviderName: According to the spec, this is optional.
        self._request_params = {
            "ACS_URL": acs_url,
            "DESTINATION": "",
            "PROVIDER_NAME": "",
        }
        self._relay_state = ""
 |