84 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""SAML Request Parse/builder"""
 | 
						|
from typing import TYPE_CHECKING, Optional
 | 
						|
 | 
						|
from defusedxml import ElementTree
 | 
						|
from signxml import XMLVerifier
 | 
						|
 | 
						|
from passbook.crypto.models import CertificateKeyPair
 | 
						|
from passbook.providers.samlv2.saml.constants import (
 | 
						|
    NS_SAML_ASSERTION,
 | 
						|
    NS_SAML_PROTOCOL,
 | 
						|
    SAML_ATTRIB_ACS_URL,
 | 
						|
    SAML_ATTRIB_DESTINATION,
 | 
						|
    SAML_ATTRIB_ID,
 | 
						|
    SAML_ATTRIB_ISSUE_INSTANT,
 | 
						|
    SAML_ATTRIB_PROTOCOL_BINDING,
 | 
						|
)
 | 
						|
from passbook.providers.samlv2.saml.utils import decode_base64_and_inflate
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from xml.etree.ElementTree import Element  # nosec
 | 
						|
 | 
						|
 | 
						|
# pylint: disable=too-many-instance-attributes
 | 
						|
class SAMLRequest:
    """SAML Request data class.

    Parses raw base64-encoded request data, exposes the attributes of the
    root ``AuthnRequest`` element and verifies the request's signature.
    """

    # Parsed XML root element; only set by `parse`.
    _root: "Element"

    # Values copied from the attributes of the root element.
    acs_url: str
    destination: str
    id: str
    issue_instant: str
    protocol_binding: str

    # Text content of the request's Issuer child element.
    issuer: str

    # True when a detached signature was supplied to `parse` or the request
    # carries an enveloped XML signature.
    is_signed: bool
    # Detached signature handed to `parse`; empty string when there is none.
    _detached_signature: str

    def __init__(self):
        self.acs_url = ""
        self.destination = ""
        # pylint: disable=invalid-name
        self.id = ""
        self.issue_instant = ""
        self.protocol_binding = ""
        self.issuer = ""
        # These two must always exist: `verify_signature` reads
        # `_detached_signature` even when `parse` was given no signature.
        self.is_signed = False
        self._detached_signature = ""

    @staticmethod
    def parse(raw: str, detached_signature: Optional[str] = None) -> "SAMLRequest":
        """Parse SAML request from raw string, which can be base64 encoded and deflated.

        Optionally accepts a detached_signature, as from a REDIRECT request.

        Raises `ValueError` if the root element is not an AuthnRequest or the
        request has no Issuer element, and `KeyError` if one of the expected
        attributes is missing from the root element."""
        decoded_xml = decode_base64_and_inflate(raw)
        root = ElementTree.fromstring(decoded_xml)
        req = SAMLRequest()
        req._root = root  # pylint: disable=protected-access
        # Verify the root element's tag
        _expected_tag = f"{{{NS_SAML_PROTOCOL}}}AuthnRequest"
        if root.tag != _expected_tag:
            raise ValueError(
                f"Invalid root tag, got '{root.tag}', expected '{_expected_tag}'."
            )
        req.acs_url = root.attrib[SAML_ATTRIB_ACS_URL]
        req.destination = root.attrib[SAML_ATTRIB_DESTINATION]
        req.id = root.attrib[SAML_ATTRIB_ID]
        req.issue_instant = root.attrib[SAML_ATTRIB_ISSUE_INSTANT]
        req.protocol_binding = root.attrib[SAML_ATTRIB_PROTOCOL_BINDING]
        # `find` returns None for a missing element; fail with a clear error
        # instead of an AttributeError on `.text`.
        issuer_element = root.find(f"{{{NS_SAML_ASSERTION}}}Issuer")
        if issuer_element is None:
            raise ValueError("SAML request does not contain an Issuer element.")
        req.issuer = issuer_element.text
        # Check if this Request is signed, either through a detached signature
        # or through an enveloped XML-DSig Signature child of the root element.
        enveloped_signature = root.find(
            "{http://www.w3.org/2000/09/xmldsig#}Signature"
        )
        req.is_signed = bool(detached_signature) or enveloped_signature is not None
        if detached_signature:
            # pylint: disable=protected-access
            req._detached_signature = detached_signature
        return req

    def verify_signature(self, keypair: CertificateKeyPair):
        """Verify signature of SAML Request.

        Uses the detached signature when one was supplied to `parse`,
        otherwise the parsed request element itself.

        Raises `cryptography.exceptions.InvalidSignature` on validation failure."""
        verifier = XMLVerifier()
        if self._detached_signature:
            verifier.verify(
                self._detached_signature, x509_cert=keypair.certificate_data
            )
        else:
            verifier.verify(self._root, x509_cert=keypair.certificate_data)
 |