From 441916703d947753de0b0b978a473e5439d4f665 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Thu, 24 Oct 2024 00:56:10 +0200 Subject: [PATCH] implement adapter using outposts Signed-off-by: Jens Langhammer --- authentik/outposts/consumer.py | 7 +- authentik/outposts/http.py | 84 +++++++++++ authentik/outposts/models.py | 1 + authentik/providers/scim/clients/base.py | 13 +- cmd/scim/main.go | 174 +++++++++++++++++++++++ internal/outpost/ak/api.go | 2 +- internal/outpost/ak/api_ws.go | 16 ++- internal/outpost/ak/api_ws_msg.go | 12 +- 8 files changed, 293 insertions(+), 16 deletions(-) create mode 100644 authentik/outposts/http.py create mode 100644 cmd/scim/main.go diff --git a/authentik/outposts/consumer.py b/authentik/outposts/consumer.py index 80b64999d5..73c8852005 100644 --- a/authentik/outposts/consumer.py +++ b/authentik/outposts/consumer.py @@ -37,7 +37,6 @@ class WebsocketMessageInstruction(IntEnum): # Provider specific message PROVIDER_SPECIFIC = 3 - @dataclass(slots=True) class WebsocketMessage: """Complete Websocket Message that is being sent""" @@ -128,6 +127,12 @@ class OutpostConsumer(JsonWebsocketConsumer): state.args.update(msg.args) elif msg.instruction == WebsocketMessageInstruction.ACK: return + elif msg.instruction == WebsocketMessageInstruction.PROVIDER_SPECIFIC: + if "response_channel" not in msg.args: + return + self.logger.debug("Posted response to channel", msg=msg) + async_to_sync(self.channel_layer.send)(msg.args.get("response_channel"), content) + return GAUGE_OUTPOSTS_LAST_UPDATE.labels( tenant=connection.schema_name, outpost=self.outpost.name, diff --git a/authentik/outposts/http.py b/authentik/outposts/http.py new file mode 100644 index 0000000000..22273abc7e --- /dev/null +++ b/authentik/outposts/http.py @@ -0,0 +1,84 @@ +from base64 import b64decode +from dataclasses import asdict, dataclass +from random import choice +from typing import Any +from uuid import uuid4 + +from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer +from channels_redis.pubsub import RedisPubSubChannelLayer +from requests.adapters import BaseAdapter +from requests.models import PreparedRequest, Response +from structlog.stdlib import get_logger +from requests.utils import CaseInsensitiveDict +from authentik.outposts.models import Outpost + + +@dataclass +class OutpostPreparedRequest: + uid: str + method: str + url: str + headers: dict[str, str] + body: Any + ssl_verify: bool + timeout: int + + @staticmethod + def from_requests(req: PreparedRequest) -> "OutpostPreparedRequest": + return OutpostPreparedRequest( + uid=str(uuid4()), + method=req.method, + url=req.url, + headers=req.headers._store, + body=req.body, + ssl_verify=True, + timeout=0, + ) + + @property + def response_channel(self) -> str: + return f"authentik_outpost_http_response_{self.uid}" + +class OutpostHTTPAdapter(BaseAdapter): + """Requests Adapter that sends HTTP requests via a specified Outpost""" + + def __init__(self, outpost: Outpost, default_timeout=10): + super().__init__() + self.__outpost = outpost + self.__logger = get_logger().bind() + self.__layer: RedisPubSubChannelLayer = get_channel_layer() + self.default_timeout = default_timeout + + def parse_response(self, raw_response: dict, req: PreparedRequest) -> Response: + res = Response() + res.request = req + res.status_code = raw_response.get("status") + res.url = raw_response.get("final_url") + res.headers = CaseInsensitiveDict(raw_response.get("headers")) + res._content = b64decode(raw_response.get("body")) + return res + + def send(self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None): + # Convert request so we can send it to the outpost + converted = OutpostPreparedRequest.from_requests(request) + converted.ssl_verify = verify + converted.timeout = timeout if timeout else self.default_timeout + # Pick one of the outpost instances + state = choice(self.__outpost.state) + self.__logger.debug("sending HTTP request to outpost", uid=converted.uid) + async_to_sync(self.__layer.send)( + state.uid, + { + "type": "event.provider.specific", + "sub_type": "http_request", + "response_channel": converted.response_channel, + "request": asdict(converted), + }, + ) + self.__logger.debug("receiving HTTP response from outpost",uid=converted.uid) + raw_response = async_to_sync(self.__layer.receive)( + converted.response_channel, + ) + self.__logger.debug("received HTTP response from outpost",uid=converted.uid) + return self.parse_response(raw_response.get("args", {}).get("response", {}), request) diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index 4032892fe8..69cbe0321b 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -98,6 +98,7 @@ class OutpostType(models.TextChoices): LDAP = "ldap" RADIUS = "radius" RAC = "rac" + SCIM = "scim" def default_outpost_config(host: str | None = None): diff --git a/authentik/providers/scim/clients/base.py b/authentik/providers/scim/clients/base.py index 246520114c..d9a78e1d66 100644 --- a/authentik/providers/scim/clients/base.py +++ b/authentik/providers/scim/clients/base.py @@ -19,6 +19,7 @@ from authentik.lib.sync.outgoing.exceptions import ( TransientSyncException, ) from authentik.lib.utils.http import get_http_session +from authentik.outposts.http import OutpostHTTPAdapter from authentik.providers.scim.clients.exceptions import SCIMRequestException from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.models import SCIMProvider @@ -41,8 +42,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"]( def __init__(self, provider: SCIMProvider): super().__init__(provider) - self._session = get_http_session() - self._session.verify = provider.verify_certificates + self._session = self.get_session(provider) self.provider = provider # Remove trailing slashes as we assume the URL doesn't have any base_url = provider.url @@ -52,6 +52,15 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"]( self.token = provider.token self._config = self.get_service_provider_config() + def get_session(self, provider: SCIMProvider): + session = get_http_session() + if self.provider.outpost_set.exists(): + adapter = OutpostHTTPAdapter() + session.mount("https://", adapter) + session.mount("http://", adapter) + session.verify = provider.verify_certificates + return session + def _request(self, method: str, path: str, **kwargs) -> dict: """Wrapper to send a request to the full URL""" try: diff --git a/cmd/scim/main.go b/cmd/scim/main.go new file mode 100644 index 0000000000..e5b7a79cfc --- /dev/null +++ b/cmd/scim/main.go @@ -0,0 +1,174 @@ +package main + +import ( + "context" + "crypto/tls" + "encoding/base64" + "fmt" + "io" + "net/http" + "net/url" + "os" + + "github.com/mitchellh/mapstructure" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "goauthentik.io/internal/common" + "goauthentik.io/internal/debug" + "goauthentik.io/internal/outpost/ak" + "goauthentik.io/internal/outpost/ak/healthcheck" +) + +const helpMessage = `authentik SCIM + +Required environment variables: +- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company") +- AUTHENTIK_TOKEN: Token to authenticate with +- AUTHENTIK_INSECURE: Skip SSL Certificate verification` + +var rootCmd = &cobra.Command{ + Long: helpMessage, + PersistentPreRun: func(cmd *cobra.Command, args []string) { + log.SetLevel(log.DebugLevel) + log.SetFormatter(&log.JSONFormatter{ + FieldMap: log.FieldMap{ + log.FieldKeyMsg: "event", + log.FieldKeyTime: "timestamp", + }, + DisableHTMLEscape: true, + }) + }, + Run: func(cmd *cobra.Command, args []string) { + debug.EnableDebugServer() + akURL, found := os.LookupEnv("AUTHENTIK_HOST") + if !found { + fmt.Println("env AUTHENTIK_HOST not set!") + fmt.Println(helpMessage) + os.Exit(1) + } + akToken, found := os.LookupEnv("AUTHENTIK_TOKEN") + if !found { + fmt.Println("env AUTHENTIK_TOKEN not set!") + fmt.Println(helpMessage) + os.Exit(1) + } + + akURLActual, err := url.Parse(akURL) + if err != nil { + fmt.Println(err) + fmt.Println(helpMessage) + os.Exit(1) + } + + ex := common.Init() + defer common.Defer() + go func() { + for { + <-ex + os.Exit(0) + } + }() + + ac := ak.NewAPIController(*akURLActual, akToken) + if ac == nil { + os.Exit(1) + } + defer ac.Shutdown() + + ac.Server = &SCIMOutpost{ + ac: ac, + log: log.WithField("logger", "authentik.outpost.scim"), + } + + err = ac.Start() + if err != nil { + log.WithError(err).Panic("Failed to run server") + } + + for { + <-ex + } + }, +} + +type HTTPRequest struct { + Uid string `mapstructure:"uid"` + Method string `mapstructure:"method"` + URL string `mapstructure:"url"` + Headers map[string][]string `mapstructure:"headers"` + Body interface{} `mapstructure:"body"` + SSLVerify bool `mapstructure:"ssl_verify"` + Timeout int `mapstructure:"timeout"` +} + +type RequestArgs struct { + Request HTTPRequest `mapstructure:"request"` + ResponseChannel string `mapstructure:"response_channel"` +} + +type SCIMOutpost struct { + ac *ak.APIController + log *log.Entry +} + +func (s *SCIMOutpost) Type() string { return "SCIM" } +func (s *SCIMOutpost) Stop() error { return nil } +func (s *SCIMOutpost) Refresh() error { return nil } +func (s *SCIMOutpost) TimerFlowCacheExpiry(context.Context) {} + +func (s *SCIMOutpost) Start() error { + s.ac.AddWSHandler(func(ctx context.Context, args map[string]interface{}) { + rd := RequestArgs{} + err := mapstructure.Decode(args, &rd) + if err != nil { + s.log.WithError(err).Warning("failed to parse http request") + return + } + s.log.WithField("rd", rd).WithField("raw", args).Debug("request data") + req, err := http.NewRequest(rd.Request.Method, rd.Request.URL, nil) + if err != nil { + s.log.WithError(err).Warning("failed to create request") + return + } + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: !rd.Request.SSLVerify}, + // todo: timeout + } + c := &http.Client{ + Transport: tr, + } + s.log.WithField("url", req.URL.Host).Debug("sending HTTP request") + res, err := c.Do(req) + if err != nil { + s.log.WithError(err).Warning("failed to send request") + return + } + body, err := io.ReadAll(res.Body) + if err != nil { + s.log.WithError(err).Warning("failed to read body") + return + } + s.log.WithField("res", res.StatusCode).Debug("sending HTTP response") + s.ac.SendWS(ak.WebsocketInstructionProviderSpecific, map[string]interface{}{ + "sub_type": "http_response", + "response_channel": rd.ResponseChannel, + "response": map[string]interface{}{ + "status": res.StatusCode, + "final_url": res.Request.URL.String(), + "headers": res.Header, + "body": base64.StdEncoding.EncodeToString(body), + }, + }) + }) + return nil +} + +func main() { + rootCmd.AddCommand(healthcheck.Command) + err := rootCmd.Execute() + if err != nil { + os.Exit(1) + } +} diff --git a/internal/outpost/ak/api.go b/internal/outpost/ak/api.go index fd8050042a..3cfd4bf5da 100644 --- a/internal/outpost/ak/api.go +++ b/internal/outpost/ak/api.go @@ -95,7 +95,7 @@ func NewAPIController(akURL url.URL, token string) *APIController { time.Sleep(time.Second * 3) } if len(outposts.Results) < 1 { - panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") + panic("No outposts found with given token, ensure the given token corresponds to an authentik Outpost") } outpost := outposts.Results[0] diff --git a/internal/outpost/ak/api_ws.go b/internal/outpost/ak/api_ws.go index d92941f760..c0855b25e8 100644 --- a/internal/outpost/ak/api_ws.go +++ b/internal/outpost/ak/api_ws.go @@ -233,15 +233,19 @@ func (a *APIController) AddWSHandler(handler WSHandler) { a.wsHandlers = append(a.wsHandlers, handler) } +func (a *APIController) SendWS(inst WebsocketInstruction, args map[string]interface{}) error { + msg := websocketMessage{ + Instruction: inst, + Args: args, + } + err := a.wsConn.WriteJSON(msg) + return err +} + func (a *APIController) SendWSHello(args map[string]interface{}) error { allArgs := a.getWebsocketPingArgs() for key, value := range args { allArgs[key] = value } - aliveMsg := websocketMessage{ - Instruction: WebsocketInstructionHello, - Args: allArgs, - } - err := a.wsConn.WriteJSON(aliveMsg) - return err + return a.SendWS(WebsocketInstructionHello, args) } diff --git a/internal/outpost/ak/api_ws_msg.go b/internal/outpost/ak/api_ws_msg.go index cedecb93d5..1d11860b3f 100644 --- a/internal/outpost/ak/api_ws_msg.go +++ b/internal/outpost/ak/api_ws_msg.go @@ -1,19 +1,19 @@ package ak -type websocketInstruction int +type WebsocketInstruction int const ( // WebsocketInstructionAck Code used to acknowledge a previous message - WebsocketInstructionAck websocketInstruction = 0 + WebsocketInstructionAck WebsocketInstruction = 0 // WebsocketInstructionHello Code used to send a healthcheck keepalive - WebsocketInstructionHello websocketInstruction = 1 + WebsocketInstructionHello WebsocketInstruction = 1 // WebsocketInstructionTriggerUpdate Code received to trigger a config update - WebsocketInstructionTriggerUpdate websocketInstruction = 2 + WebsocketInstructionTriggerUpdate WebsocketInstruction = 2 // WebsocketInstructionProviderSpecific Code received to trigger some provider specific function - WebsocketInstructionProviderSpecific websocketInstruction = 3 + WebsocketInstructionProviderSpecific WebsocketInstruction = 3 ) type websocketMessage struct { - Instruction websocketInstruction `json:"instruction"` + Instruction WebsocketInstruction `json:"instruction"` Args map[string]interface{} `json:"args"` }