root: check remote IP for proxy protocol same as HTTP/etc (#12094)
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -65,7 +65,7 @@ func (ls *LDAPServer) StartLDAPServer() error { | ||||
| 		ls.log.WithField("listen", listen).WithError(err).Warning("Failed to listen (SSL)") | ||||
| 		return err | ||||
| 	} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: ln} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()} | ||||
| 	defer proxyListener.Close() | ||||
|  | ||||
| 	ls.log.WithField("listen", listen).Info("Starting LDAP server") | ||||
|  | ||||
| @ -48,7 +48,7 @@ func (ls *LDAPServer) StartLDAPTLSServer() error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	proxyListener := &proxyproto.Listener{Listener: ln} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()} | ||||
| 	defer proxyListener.Close() | ||||
|  | ||||
| 	tln := tls.NewListener(proxyListener, tlsConfig) | ||||
|  | ||||
| @ -129,7 +129,7 @@ func (ps *ProxyServer) ServeHTTP() { | ||||
| 		ps.log.WithField("listen", listenAddress).WithError(err).Warning("Failed to listen") | ||||
| 		return | ||||
| 	} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: listener} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: listener, ConnPolicy: utils.GetProxyConnectionPolicy()} | ||||
| 	defer proxyListener.Close() | ||||
|  | ||||
| 	ps.log.WithField("listen", listenAddress).Info("Starting HTTP server") | ||||
| @ -148,7 +148,7 @@ func (ps *ProxyServer) ServeHTTPS() { | ||||
| 		ps.log.WithError(err).Warning("Failed to listen (TLS)") | ||||
| 		return | ||||
| 	} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}, ConnPolicy: utils.GetProxyConnectionPolicy()} | ||||
| 	defer proxyListener.Close() | ||||
|  | ||||
| 	tlsListener := tls.NewListener(proxyListener, tlsConfig) | ||||
|  | ||||
							
								
								
									
										34
									
								
								internal/utils/proxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								internal/utils/proxy.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,34 @@ | ||||
| package utils | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
|  | ||||
| 	"github.com/pires/go-proxyproto" | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| 	"goauthentik.io/internal/config" | ||||
| ) | ||||
|  | ||||
| func GetProxyConnectionPolicy() proxyproto.ConnPolicyFunc { | ||||
| 	nets := []*net.IPNet{} | ||||
| 	for _, rn := range config.Get().Listen.TrustedProxyCIDRs { | ||||
| 		_, cidr, err := net.ParseCIDR(rn) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		nets = append(nets, cidr) | ||||
| 	} | ||||
| 	return func(connPolicyOptions proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { | ||||
| 		host, _, err := net.SplitHostPort(connPolicyOptions.Upstream.String()) | ||||
| 		if err == nil { | ||||
| 			// remoteAddr will be nil if the IP cannot be parsed | ||||
| 			remoteAddr := net.ParseIP(host) | ||||
| 			for _, allowedCidr := range nets { | ||||
| 				if remoteAddr != nil && allowedCidr.Contains(remoteAddr) { | ||||
| 					log.WithField("remoteAddr", remoteAddr).WithField("cidr", allowedCidr.String()).Trace("Using remote IP from proxy protocol") | ||||
| 					return proxyproto.USE, nil | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return proxyproto.SKIP, nil | ||||
| 	} | ||||
| } | ||||
| @ -19,6 +19,7 @@ import ( | ||||
| 	"goauthentik.io/internal/config" | ||||
| 	"goauthentik.io/internal/gounicorn" | ||||
| 	"goauthentik.io/internal/outpost/proxyv2" | ||||
| 	"goauthentik.io/internal/utils" | ||||
| 	"goauthentik.io/internal/utils/web" | ||||
| 	"goauthentik.io/internal/web/brand_tls" | ||||
| ) | ||||
| @ -149,7 +150,7 @@ func (ws *WebServer) listenPlain() { | ||||
| 		ws.log.WithError(err).Warning("failed to listen") | ||||
| 		return | ||||
| 	} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: ln} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()} | ||||
| 	defer proxyListener.Close() | ||||
|  | ||||
| 	ws.log.WithField("listen", config.Get().Listen.HTTP).Info("Starting HTTP server") | ||||
|  | ||||
| @ -45,7 +45,7 @@ func (ws *WebServer) listenTLS() { | ||||
| 		ws.log.WithError(err).Warning("failed to listen (TLS)") | ||||
| 		return | ||||
| 	} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}} | ||||
| 	proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}, ConnPolicy: utils.GetProxyConnectionPolicy()} | ||||
| 	defer proxyListener.Close() | ||||
| 
 | ||||
| 	tlsListener := tls.NewListener(proxyListener, tlsConfig) | ||||
		Reference in New Issue
	
	Block a user
	 Jens L.
					Jens L.