implement adapter using outposts
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -37,7 +37,6 @@ class WebsocketMessageInstruction(IntEnum):
|
|||||||
# Provider specific message
|
# Provider specific message
|
||||||
PROVIDER_SPECIFIC = 3
|
PROVIDER_SPECIFIC = 3
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class WebsocketMessage:
|
class WebsocketMessage:
|
||||||
"""Complete Websocket Message that is being sent"""
|
"""Complete Websocket Message that is being sent"""
|
||||||
@ -128,6 +127,12 @@ class OutpostConsumer(JsonWebsocketConsumer):
|
|||||||
state.args.update(msg.args)
|
state.args.update(msg.args)
|
||||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||||
return
|
return
|
||||||
|
elif msg.instruction == WebsocketMessageInstruction.PROVIDER_SPECIFIC:
|
||||||
|
if "response_channel" not in msg.args:
|
||||||
|
return
|
||||||
|
self.logger.debug("Posted response to channel", msg=msg)
|
||||||
|
async_to_sync(self.channel_layer.send)(msg.args.get("response_channel"), content)
|
||||||
|
return
|
||||||
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
||||||
tenant=connection.schema_name,
|
tenant=connection.schema_name,
|
||||||
outpost=self.outpost.name,
|
outpost=self.outpost.name,
|
||||||
|
|||||||
84
authentik/outposts/http.py
Normal file
84
authentik/outposts/http.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from base64 import b64decode
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from random import choice
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
from channels.layers import get_channel_layer
|
||||||
|
from channels_redis.pubsub import RedisPubSubChannelLayer
|
||||||
|
from requests.adapters import BaseAdapter
|
||||||
|
from requests.models import PreparedRequest, Response
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
from requests.utils import CaseInsensitiveDict
|
||||||
|
from authentik.outposts.models import Outpost
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OutpostPreparedRequest:
|
||||||
|
uid: str
|
||||||
|
method: str
|
||||||
|
url: str
|
||||||
|
headers: dict[str, str]
|
||||||
|
body: Any
|
||||||
|
ssl_verify: bool
|
||||||
|
timeout: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_requests(req: PreparedRequest) -> "OutpostPreparedRequest":
|
||||||
|
return OutpostPreparedRequest(
|
||||||
|
uid=str(uuid4()),
|
||||||
|
method=req.method,
|
||||||
|
url=req.url,
|
||||||
|
headers=req.headers._store,
|
||||||
|
body=req.body,
|
||||||
|
ssl_verify=True,
|
||||||
|
timeout=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_channel(self) -> str:
|
||||||
|
return f"authentik_outpost_http_response_{self.uid}"
|
||||||
|
|
||||||
|
class OutpostHTTPAdapter(BaseAdapter):
|
||||||
|
"""Requests Adapter that sends HTTP requests via a specified Outpost"""
|
||||||
|
|
||||||
|
def __init__(self, outpost: Outpost, default_timeout=10):
|
||||||
|
super().__init__()
|
||||||
|
self.__outpost = outpost
|
||||||
|
self.__logger = get_logger().bind()
|
||||||
|
self.__layer: RedisPubSubChannelLayer = get_channel_layer()
|
||||||
|
self.default_timeout = default_timeout
|
||||||
|
|
||||||
|
def parse_response(self, raw_response: dict, req: PreparedRequest) -> Response:
|
||||||
|
res = Response()
|
||||||
|
res.request = req
|
||||||
|
res.status_code = raw_response.get("status")
|
||||||
|
res.url = raw_response.get("final_url")
|
||||||
|
res.headers = CaseInsensitiveDict(raw_response.get("headers"))
|
||||||
|
res._content = b64decode(raw_response.get("body"))
|
||||||
|
return res
|
||||||
|
|
||||||
|
def send(self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None):
|
||||||
|
# Convert request so we can send it to the outpost
|
||||||
|
converted = OutpostPreparedRequest.from_requests(request)
|
||||||
|
converted.ssl_verify = verify
|
||||||
|
converted.timeout = timeout if timeout else self.default_timeout
|
||||||
|
# Pick one of the outpost instances
|
||||||
|
state = choice(self.__outpost.state)
|
||||||
|
self.__logger.debug("sending HTTP request to outpost", uid=converted.uid)
|
||||||
|
async_to_sync(self.__layer.send)(
|
||||||
|
state.uid,
|
||||||
|
{
|
||||||
|
"type": "event.provider.specific",
|
||||||
|
"sub_type": "http_request",
|
||||||
|
"response_channel": converted.response_channel,
|
||||||
|
"request": asdict(converted),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.__logger.debug("receiving HTTP response from outpost",uid=converted.uid)
|
||||||
|
raw_response = async_to_sync(self.__layer.receive)(
|
||||||
|
converted.response_channel,
|
||||||
|
)
|
||||||
|
self.__logger.debug("received HTTP response from outpost",uid=converted.uid)
|
||||||
|
return self.parse_response(raw_response.get("args", {}).get("response", {}), request)
|
||||||
@ -98,6 +98,7 @@ class OutpostType(models.TextChoices):
|
|||||||
LDAP = "ldap"
|
LDAP = "ldap"
|
||||||
RADIUS = "radius"
|
RADIUS = "radius"
|
||||||
RAC = "rac"
|
RAC = "rac"
|
||||||
|
SCIM = "scim"
|
||||||
|
|
||||||
|
|
||||||
def default_outpost_config(host: str | None = None):
|
def default_outpost_config(host: str | None = None):
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from authentik.lib.sync.outgoing.exceptions import (
|
|||||||
TransientSyncException,
|
TransientSyncException,
|
||||||
)
|
)
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
|
from authentik.outposts.http import OutpostHTTPAdapter
|
||||||
from authentik.providers.scim.clients.exceptions import SCIMRequestException
|
from authentik.providers.scim.clients.exceptions import SCIMRequestException
|
||||||
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
|
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
|
||||||
from authentik.providers.scim.models import SCIMProvider
|
from authentik.providers.scim.models import SCIMProvider
|
||||||
@ -41,8 +42,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
|
|||||||
|
|
||||||
def __init__(self, provider: SCIMProvider):
|
def __init__(self, provider: SCIMProvider):
|
||||||
super().__init__(provider)
|
super().__init__(provider)
|
||||||
self._session = get_http_session()
|
self._session = self.get_session(provider)
|
||||||
self._session.verify = provider.verify_certificates
|
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
# Remove trailing slashes as we assume the URL doesn't have any
|
# Remove trailing slashes as we assume the URL doesn't have any
|
||||||
base_url = provider.url
|
base_url = provider.url
|
||||||
@ -52,6 +52,15 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
|
|||||||
self.token = provider.token
|
self.token = provider.token
|
||||||
self._config = self.get_service_provider_config()
|
self._config = self.get_service_provider_config()
|
||||||
|
|
||||||
|
def get_session(self, provider: SCIMProvider):
|
||||||
|
session = get_http_session()
|
||||||
|
if self.provider.outpost_set.exists():
|
||||||
|
adapter = OutpostHTTPAdapter()
|
||||||
|
session.mount("https://", adapter)
|
||||||
|
session.mount("http://", adapter)
|
||||||
|
session.verify = provider.verify_certificates
|
||||||
|
return session
|
||||||
|
|
||||||
def _request(self, method: str, path: str, **kwargs) -> dict:
|
def _request(self, method: str, path: str, **kwargs) -> dict:
|
||||||
"""Wrapper to send a request to the full URL"""
|
"""Wrapper to send a request to the full URL"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
174
cmd/scim/main.go
Normal file
174
cmd/scim/main.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"goauthentik.io/internal/common"
|
||||||
|
"goauthentik.io/internal/debug"
|
||||||
|
"goauthentik.io/internal/outpost/ak"
|
||||||
|
"goauthentik.io/internal/outpost/ak/healthcheck"
|
||||||
|
)
|
||||||
|
|
||||||
|
const helpMessage = `authentik SCIM
|
||||||
|
|
||||||
|
Required environment variables:
|
||||||
|
- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company")
|
||||||
|
- AUTHENTIK_TOKEN: Token to authenticate with
|
||||||
|
- AUTHENTIK_INSECURE: Skip SSL Certificate verification`
|
||||||
|
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Long: helpMessage,
|
||||||
|
PersistentPreRun: func(cmd *cobra.Command, args []string) {
|
||||||
|
log.SetLevel(log.DebugLevel)
|
||||||
|
log.SetFormatter(&log.JSONFormatter{
|
||||||
|
FieldMap: log.FieldMap{
|
||||||
|
log.FieldKeyMsg: "event",
|
||||||
|
log.FieldKeyTime: "timestamp",
|
||||||
|
},
|
||||||
|
DisableHTMLEscape: true,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
debug.EnableDebugServer()
|
||||||
|
akURL, found := os.LookupEnv("AUTHENTIK_HOST")
|
||||||
|
if !found {
|
||||||
|
fmt.Println("env AUTHENTIK_HOST not set!")
|
||||||
|
fmt.Println(helpMessage)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
akToken, found := os.LookupEnv("AUTHENTIK_TOKEN")
|
||||||
|
if !found {
|
||||||
|
fmt.Println("env AUTHENTIK_TOKEN not set!")
|
||||||
|
fmt.Println(helpMessage)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
akURLActual, err := url.Parse(akURL)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
fmt.Println(helpMessage)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ex := common.Init()
|
||||||
|
defer common.Defer()
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
<-ex
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ac := ak.NewAPIController(*akURLActual, akToken)
|
||||||
|
if ac == nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer ac.Shutdown()
|
||||||
|
|
||||||
|
ac.Server = &SCIMOutpost{
|
||||||
|
ac: ac,
|
||||||
|
log: log.WithField("logger", "authentik.outpost.scim"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ac.Start()
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Panic("Failed to run server")
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
<-ex
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPRequest struct {
|
||||||
|
Uid string `mapstructure:"uid"`
|
||||||
|
Method string `mapstructure:"method"`
|
||||||
|
URL string `mapstructure:"url"`
|
||||||
|
Headers map[string][]string `mapstructure:"headers"`
|
||||||
|
Body interface{} `mapstructure:"body"`
|
||||||
|
SSLVerify bool `mapstructure:"ssl_verify"`
|
||||||
|
Timeout int `mapstructure:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestArgs struct {
|
||||||
|
Request HTTPRequest `mapstructure:"request"`
|
||||||
|
ResponseChannel string `mapstructure:"response_channel"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SCIMOutpost struct {
|
||||||
|
ac *ak.APIController
|
||||||
|
log *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SCIMOutpost) Type() string { return "SCIM" }
|
||||||
|
func (s *SCIMOutpost) Stop() error { return nil }
|
||||||
|
func (s *SCIMOutpost) Refresh() error { return nil }
|
||||||
|
func (s *SCIMOutpost) TimerFlowCacheExpiry(context.Context) {}
|
||||||
|
|
||||||
|
func (s *SCIMOutpost) Start() error {
|
||||||
|
s.ac.AddWSHandler(func(ctx context.Context, args map[string]interface{}) {
|
||||||
|
rd := RequestArgs{}
|
||||||
|
err := mapstructure.Decode(args, &rd)
|
||||||
|
if err != nil {
|
||||||
|
s.log.WithError(err).Warning("failed to parse http request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.log.WithField("rd", rd).WithField("raw", args).Debug("request data")
|
||||||
|
req, err := http.NewRequest(rd.Request.Method, rd.Request.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.log.WithError(err).Warning("failed to create request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: !rd.Request.SSLVerify},
|
||||||
|
// todo: timeout
|
||||||
|
}
|
||||||
|
c := &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
}
|
||||||
|
s.log.WithField("url", req.URL.Host).Debug("sending HTTP request")
|
||||||
|
res, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
s.log.WithError(err).Warning("failed to send request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
s.log.WithError(err).Warning("failed to read body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.log.WithField("res", res.StatusCode).Debug("sending HTTP response")
|
||||||
|
s.ac.SendWS(ak.WebsocketInstructionProviderSpecific, map[string]interface{}{
|
||||||
|
"sub_type": "http_response",
|
||||||
|
"response_channel": rd.ResponseChannel,
|
||||||
|
"response": map[string]interface{}{
|
||||||
|
"status": res.StatusCode,
|
||||||
|
"final_url": res.Request.URL.String(),
|
||||||
|
"headers": res.Header,
|
||||||
|
"body": base64.StdEncoding.EncodeToString(body),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
rootCmd.AddCommand(healthcheck.Command)
|
||||||
|
err := rootCmd.Execute()
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -95,7 +95,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
|
|||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
}
|
}
|
||||||
if len(outposts.Results) < 1 {
|
if len(outposts.Results) < 1 {
|
||||||
panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost")
|
panic("No outposts found with given token, ensure the given token corresponds to an authentik Outpost")
|
||||||
}
|
}
|
||||||
outpost := outposts.Results[0]
|
outpost := outposts.Results[0]
|
||||||
|
|
||||||
|
|||||||
@ -233,15 +233,19 @@ func (a *APIController) AddWSHandler(handler WSHandler) {
|
|||||||
a.wsHandlers = append(a.wsHandlers, handler)
|
a.wsHandlers = append(a.wsHandlers, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *APIController) SendWS(inst WebsocketInstruction, args map[string]interface{}) error {
|
||||||
|
msg := websocketMessage{
|
||||||
|
Instruction: inst,
|
||||||
|
Args: args,
|
||||||
|
}
|
||||||
|
err := a.wsConn.WriteJSON(msg)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (a *APIController) SendWSHello(args map[string]interface{}) error {
|
func (a *APIController) SendWSHello(args map[string]interface{}) error {
|
||||||
allArgs := a.getWebsocketPingArgs()
|
allArgs := a.getWebsocketPingArgs()
|
||||||
for key, value := range args {
|
for key, value := range args {
|
||||||
allArgs[key] = value
|
allArgs[key] = value
|
||||||
}
|
}
|
||||||
aliveMsg := websocketMessage{
|
return a.SendWS(WebsocketInstructionHello, args)
|
||||||
Instruction: WebsocketInstructionHello,
|
|
||||||
Args: allArgs,
|
|
||||||
}
|
|
||||||
err := a.wsConn.WriteJSON(aliveMsg)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,19 +1,19 @@
|
|||||||
package ak
|
package ak
|
||||||
|
|
||||||
type websocketInstruction int
|
type WebsocketInstruction int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// WebsocketInstructionAck Code used to acknowledge a previous message
|
// WebsocketInstructionAck Code used to acknowledge a previous message
|
||||||
WebsocketInstructionAck websocketInstruction = 0
|
WebsocketInstructionAck WebsocketInstruction = 0
|
||||||
// WebsocketInstructionHello Code used to send a healthcheck keepalive
|
// WebsocketInstructionHello Code used to send a healthcheck keepalive
|
||||||
WebsocketInstructionHello websocketInstruction = 1
|
WebsocketInstructionHello WebsocketInstruction = 1
|
||||||
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
|
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
|
||||||
WebsocketInstructionTriggerUpdate websocketInstruction = 2
|
WebsocketInstructionTriggerUpdate WebsocketInstruction = 2
|
||||||
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
|
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
|
||||||
WebsocketInstructionProviderSpecific websocketInstruction = 3
|
WebsocketInstructionProviderSpecific WebsocketInstruction = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
type websocketMessage struct {
|
type websocketMessage struct {
|
||||||
Instruction websocketInstruction `json:"instruction"`
|
Instruction WebsocketInstruction `json:"instruction"`
|
||||||
Args map[string]interface{} `json:"args"`
|
Args map[string]interface{} `json:"args"`
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user