root: check remote IP for proxy protocol same as HTTP/etc (#12094) Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							f535a23c03
						
					
				
				
					commit
					05f4e738a1
				
			@ -65,7 +65,7 @@ func (ls *LDAPServer) StartLDAPServer() error {
 | 
			
		||||
		ls.log.WithField("listen", listen).WithError(err).Warning("Failed to listen (SSL)")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: ln}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()}
 | 
			
		||||
	defer proxyListener.Close()
 | 
			
		||||
 | 
			
		||||
	ls.log.WithField("listen", listen).Info("Starting LDAP server")
 | 
			
		||||
 | 
			
		||||
@ -48,7 +48,7 @@ func (ls *LDAPServer) StartLDAPTLSServer() error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: ln}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()}
 | 
			
		||||
	defer proxyListener.Close()
 | 
			
		||||
 | 
			
		||||
	tln := tls.NewListener(proxyListener, tlsConfig)
 | 
			
		||||
 | 
			
		||||
@ -129,7 +129,7 @@ func (ps *ProxyServer) ServeHTTP() {
 | 
			
		||||
		ps.log.WithField("listen", listenAddress).WithError(err).Warning("Failed to listen")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: listener}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: listener, ConnPolicy: utils.GetProxyConnectionPolicy()}
 | 
			
		||||
	defer proxyListener.Close()
 | 
			
		||||
 | 
			
		||||
	ps.log.WithField("listen", listenAddress).Info("Starting HTTP server")
 | 
			
		||||
@ -148,7 +148,7 @@ func (ps *ProxyServer) ServeHTTPS() {
 | 
			
		||||
		ps.log.WithError(err).Warning("Failed to listen (TLS)")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}, ConnPolicy: utils.GetProxyConnectionPolicy()}
 | 
			
		||||
	defer proxyListener.Close()
 | 
			
		||||
 | 
			
		||||
	tlsListener := tls.NewListener(proxyListener, tlsConfig)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										34
									
								
								internal/utils/proxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								internal/utils/proxy.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,34 @@
 | 
			
		||||
package utils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
 | 
			
		||||
	"github.com/pires/go-proxyproto"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"goauthentik.io/internal/config"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetProxyConnectionPolicy() proxyproto.ConnPolicyFunc {
 | 
			
		||||
	nets := []*net.IPNet{}
 | 
			
		||||
	for _, rn := range config.Get().Listen.TrustedProxyCIDRs {
 | 
			
		||||
		_, cidr, err := net.ParseCIDR(rn)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		nets = append(nets, cidr)
 | 
			
		||||
	}
 | 
			
		||||
	return func(connPolicyOptions proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
 | 
			
		||||
		host, _, err := net.SplitHostPort(connPolicyOptions.Upstream.String())
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			// remoteAddr will be nil if the IP cannot be parsed
 | 
			
		||||
			remoteAddr := net.ParseIP(host)
 | 
			
		||||
			for _, allowedCidr := range nets {
 | 
			
		||||
				if remoteAddr != nil && allowedCidr.Contains(remoteAddr) {
 | 
			
		||||
					log.WithField("remoteAddr", remoteAddr).WithField("cidr", allowedCidr.String()).Trace("Using remote IP from proxy protocol")
 | 
			
		||||
					return proxyproto.USE, nil
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return proxyproto.SKIP, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -19,6 +19,7 @@ import (
 | 
			
		||||
	"goauthentik.io/internal/config"
 | 
			
		||||
	"goauthentik.io/internal/gounicorn"
 | 
			
		||||
	"goauthentik.io/internal/outpost/proxyv2"
 | 
			
		||||
	"goauthentik.io/internal/utils"
 | 
			
		||||
	"goauthentik.io/internal/utils/web"
 | 
			
		||||
	"goauthentik.io/internal/web/brand_tls"
 | 
			
		||||
)
 | 
			
		||||
@ -149,7 +150,7 @@ func (ws *WebServer) listenPlain() {
 | 
			
		||||
		ws.log.WithError(err).Warning("failed to listen")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: ln}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()}
 | 
			
		||||
	defer proxyListener.Close()
 | 
			
		||||
 | 
			
		||||
	ws.log.WithField("listen", config.Get().Listen.HTTP).Info("Starting HTTP server")
 | 
			
		||||
 | 
			
		||||
@ -45,7 +45,7 @@ func (ws *WebServer) listenTLS() {
 | 
			
		||||
		ws.log.WithError(err).Warning("failed to listen (TLS)")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}}
 | 
			
		||||
	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}, ConnPolicy: utils.GetProxyConnectionPolicy()}
 | 
			
		||||
	defer proxyListener.Close()
 | 
			
		||||
 | 
			
		||||
	tlsListener := tls.NewListener(proxyListener, tlsConfig)
 | 
			
		||||
		Reference in New Issue
	
	Block a user