118 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			118 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """SAML AuthNRequest Parser and dataclass"""
 | |
| from base64 import b64decode
 | |
| from dataclasses import dataclass
 | |
| from typing import Optional
 | |
| from urllib.parse import quote_plus
 | |
| 
 | |
| from cryptography.exceptions import InvalidSignature
 | |
| from cryptography.hazmat.primitives import hashes
 | |
| from cryptography.hazmat.primitives.asymmetric import padding
 | |
| from defusedxml import ElementTree
 | |
| from signxml import XMLVerifier
 | |
| from structlog import get_logger
 | |
| 
 | |
| from passbook.providers.saml.exceptions import CannotHandleAssertion
 | |
| from passbook.providers.saml.models import SAMLProvider
 | |
| from passbook.providers.saml.utils.encoding import decode_base64_and_inflate
 | |
| from passbook.sources.saml.processors.constants import (
 | |
|     NS_SAML_PROTOCOL,
 | |
|     SAML_NAME_ID_FORMAT_EMAIL,
 | |
| )
 | |
| 
 | |
# Module-level structlog logger; used by the parser below to record rejected requests.
LOGGER = get_logger()
| 
 | |
| 
 | |
@dataclass
class AuthNRequest:
    """Result of parsing a SAML AuthnRequest.

    All fields have defaults, so ``AuthNRequest()`` yields an empty request
    (used when there is no SP-sent request to parse).
    """

    # pylint: disable=invalid-name
    # Value of the request's ``ID`` attribute, or None when there is no request.
    id: Optional[str] = None
    # RelayState value supplied alongside the request, if any.
    relay_state: Optional[str] = None
    # Requested NameID format; SAML_NAME_ID_FORMAT_EMAIL unless the request asks otherwise.
    name_id_policy: str = SAML_NAME_ID_FORMAT_EMAIL
| 
 | |
| 
 | |
class AuthNRequestParser:
    """Parse (and, when the provider has a signing keypair, verify) SAML AuthnRequests."""

    # XML-DSig SigAlg URIs accepted on the HTTP-Redirect binding, mapped to the
    # hash used with them. All of these are RSASSA-PKCS1-v1_5 algorithms
    # (RFC 3275 / RFC 6931), matching the RSA padding used for verification below.
    _SIG_ALG_HASHES = {
        "http://www.w3.org/2000/09/xmldsig#rsa-sha1": hashes.SHA1,  # nosec
        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256": hashes.SHA256,
        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384": hashes.SHA384,
        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512": hashes.SHA512,
    }

    provider: SAMLProvider

    def __init__(self, provider: SAMLProvider):
        self.provider = provider

    def _parse_xml(self, decoded_xml: str, relay_state: Optional[str]) -> AuthNRequest:
        """Build an AuthNRequest from the decoded request XML.

        Args:
            decoded_xml: The inflated, base64-decoded AuthnRequest document.
            relay_state: RelayState received with the request, stored as-is.

        Returns:
            The parsed AuthNRequest.

        Raises:
            CannotHandleAssertion: if the request has no ``ID`` attribute, or its
                AssertionConsumerServiceURL is present and differs
                (case-insensitively) from the provider's ACS URL.
        """
        root = ElementTree.fromstring(decoded_xml)

        # AssertionConsumerServiceURL is optional in an AuthnRequest; when it is
        # omitted the SP's registered endpoint applies, so only reject a URL that
        # is present and does not match the configured one.
        request_acs_url = root.attrib.get("AssertionConsumerServiceURL")
        if (
            request_acs_url is not None
            and self.provider.acs_url.lower() != request_acs_url.lower()
        ):
            msg = (
                f"ACS URL of {request_acs_url} doesn't match Provider "
                f"ACS URL of {self.provider.acs_url}."
            )
            LOGGER.info(msg)
            raise CannotHandleAssertion(msg)

        request_id = root.attrib.get("ID")
        if request_id is None:
            raise CannotHandleAssertion("AuthnRequest has no ID attribute.")

        auth_n_request = AuthNRequest(id=request_id, relay_state=relay_state)

        # ElementTree expects Clark notation "{namespace}tag"; a colon between the
        # namespace and the local name would never match any element.
        name_id_policy = root.find(f"{{{NS_SAML_PROTOCOL}}}NameIDPolicy")
        # Compare with None explicitly: an Element without children is falsy.
        if name_id_policy is not None:
            # Format is optional on NameIDPolicy; keep the default when absent.
            name_id_format = name_id_policy.attrib.get("Format")
            if name_id_format:
                auth_n_request.name_id_policy = name_id_format

        return auth_n_request

    def parse(self, saml_request: str, relay_state: Optional[str]) -> AuthNRequest:
        """Validate and parse a raw request with an enveloped signature.

        The signature is only checked when the provider has a signing keypair.

        Raises:
            CannotHandleAssertion: if signature verification fails or the request
                is otherwise rejected by ``_parse_xml``.
        """
        decoded_xml = decode_base64_and_inflate(saml_request)

        if self.provider.signing_kp:
            try:
                XMLVerifier().verify(
                    decoded_xml, x509_cert=self.provider.signing_kp.certificate_data
                )
            except InvalidSignature as exc:
                raise CannotHandleAssertion("Failed to verify signature") from exc
            # NOTE(review): the unverified document is parsed below, not the element
            # signxml reports as signed; signxml recommends using verify()'s return
            # value to avoid XML signature-wrapping — confirm whether that matters here.

        return self._parse_xml(decoded_xml, relay_state)

    def parse_detached(
        self,
        saml_request: str,
        relay_state: Optional[str],
        signature: Optional[str] = None,
        sig_alg: Optional[str] = None,
    ) -> AuthNRequest:
        """Validate and parse a raw request with a detached (query-string) signature.

        The signature is verified only when both ``signature`` and ``sig_alg`` are
        given and the provider has a signing keypair (mirroring ``parse``, which
        also skips verification without a keypair).

        Args:
            saml_request: The SAMLRequest parameter value (deflated, base64).
            relay_state: The RelayState parameter value, if any.
            signature: Base64-encoded signature parameter value.
            sig_alg: The SigAlg parameter value (an XML-DSig algorithm URI).

        Raises:
            CannotHandleAssertion: if the algorithm is unsupported, the signature
                cannot be decoded or does not verify, or ``_parse_xml`` rejects it.
        """
        decoded_xml = decode_base64_and_inflate(saml_request)

        if signature and sig_alg and self.provider.signing_kp:
            self._verify_detached_signature(saml_request, relay_state, signature, sig_alg)
        return self._parse_xml(decoded_xml, relay_state)

    def _verify_detached_signature(
        self,
        saml_request: str,
        relay_state: Optional[str],
        signature: str,
        sig_alg: str,
    ) -> None:
        """Check an HTTP-Redirect binding signature over the rebuilt query string."""
        hash_cls = self._SIG_ALG_HASHES.get(sig_alg)
        if hash_cls is None:
            raise CannotHandleAssertion(f"Unsupported signature algorithm {sig_alg}")

        # The signed octet string is "SAMLRequest=..&RelayState=..&SigAlg=.." with
        # every value URL-encoded. NOTE(review): this re-encodes with quote_plus,
        # which must match the SP's encoding byte-for-byte — confirm for your SPs.
        querystring = f"SAMLRequest={quote_plus(saml_request)}&"
        if relay_state is not None:
            querystring += f"RelayState={quote_plus(relay_state)}&"
        querystring += f"SigAlg={quote_plus(sig_alg)}"

        try:
            raw_signature = b64decode(signature)
        except ValueError as exc:  # binascii.Error / non-ASCII input
            raise CannotHandleAssertion("Failed to decode signature") from exc

        public_key = self.provider.signing_kp.private_key.public_key()
        try:
            # XML-DSig rsa-sha* algorithms are RSASSA-PKCS1-v1_5, not PSS.
            public_key.verify(
                raw_signature,
                querystring.encode(),
                padding.PKCS1v15(),
                hash_cls(),
            )
        except InvalidSignature as exc:
            raise CannotHandleAssertion("Failed to verify signature") from exc

    def idp_initiated(self) -> AuthNRequest:
        """Create an IdP-initiated AuthNRequest (no SP request; all defaults)."""
        return AuthNRequest()
