Compare commits

...

4 Commits

Author SHA1 Message Date
396925d1f0 add timeout
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:43:09 +01:00
10a8ed164e small fixes
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:36:56 +01:00
445dc01dca add full outpost support
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:36:56 +01:00
441916703d implement adapter using outposts
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:36:56 +01:00
15 changed files with 349 additions and 17 deletions

View File

@ -128,6 +128,12 @@ class OutpostConsumer(JsonWebsocketConsumer):
state.args.update(msg.args) state.args.update(msg.args)
elif msg.instruction == WebsocketMessageInstruction.ACK: elif msg.instruction == WebsocketMessageInstruction.ACK:
return return
elif msg.instruction == WebsocketMessageInstruction.PROVIDER_SPECIFIC:
if "response_channel" not in msg.args:
return
self.logger.debug("Posted response to channel", msg=msg)
async_to_sync(self.channel_layer.send)(msg.args.get("response_channel"), content)
return
GAUGE_OUTPOSTS_LAST_UPDATE.labels( GAUGE_OUTPOSTS_LAST_UPDATE.labels(
tenant=connection.schema_name, tenant=connection.schema_name,
outpost=self.outpost.name, outpost=self.outpost.name,

View File

@ -0,0 +1,86 @@
from base64 import b64decode
from dataclasses import asdict, dataclass
from random import choice
from typing import Any
from uuid import uuid4
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from channels_redis.pubsub import RedisPubSubChannelLayer
from requests.adapters import BaseAdapter
from requests.models import PreparedRequest, Response
from requests.utils import CaseInsensitiveDict
from structlog.stdlib import get_logger
from authentik.outposts.models import Outpost
@dataclass
class OutpostPreparedRequest:
uid: str
method: str
url: str
headers: dict[str, str]
body: Any
ssl_verify: bool
timeout: int
@staticmethod
def from_requests(req: PreparedRequest) -> "OutpostPreparedRequest":
return OutpostPreparedRequest(
uid=str(uuid4()),
method=req.method,
url=req.url,
headers=req.headers._store,
body=req.body,
ssl_verify=True,
timeout=0,
)
@property
def response_channel(self) -> str:
return f"authentik_outpost_http_response_{self.uid}"
class OutpostHTTPAdapter(BaseAdapter):
"""Requests Adapter that sends HTTP requests via a specified Outpost"""
def __init__(self, outpost: Outpost, default_timeout=10):
super().__init__()
self.__outpost = outpost
self.__logger = get_logger().bind()
self.__layer: RedisPubSubChannelLayer = get_channel_layer()
self.default_timeout = default_timeout
def parse_response(self, raw_response: dict, req: PreparedRequest) -> Response:
res = Response()
res.request = req
res.status_code = raw_response.get("status")
res.url = raw_response.get("final_url")
res.headers = CaseInsensitiveDict(raw_response.get("headers"))
res._content = b64decode(raw_response.get("body"))
return res
def send(self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None):
# Convert request so we can send it to the outpost
converted = OutpostPreparedRequest.from_requests(request)
converted.ssl_verify = verify
converted.timeout = timeout if timeout else self.default_timeout
# Pick one of the outpost instances
state = choice(self.__outpost.state) # nosec
self.__logger.debug("sending HTTP request to outpost", uid=converted.uid)
async_to_sync(self.__layer.send)(
state.uid,
{
"type": "event.provider.specific",
"sub_type": "http_request",
"response_channel": converted.response_channel,
"request": asdict(converted),
},
)
self.__logger.debug("receiving HTTP response from outpost", uid=converted.uid)
raw_response = async_to_sync(self.__layer.receive)(
converted.response_channel,
)
self.__logger.debug("received HTTP response from outpost", uid=converted.uid)
return self.parse_response(raw_response.get("args", {}).get("response", {}), request)

View File

@ -98,6 +98,7 @@ class OutpostType(models.TextChoices):
LDAP = "ldap" LDAP = "ldap"
RADIUS = "radius" RADIUS = "radius"
RAC = "rac" RAC = "rac"
SCIM = "scim"
def default_outpost_config(host: str | None = None): def default_outpost_config(host: str | None = None):

View File

@ -43,13 +43,15 @@ from authentik.providers.proxy.controllers.docker import ProxyDockerController
from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController
from authentik.providers.radius.controllers.docker import RadiusDockerController from authentik.providers.radius.controllers.docker import RadiusDockerController
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
from authentik.providers.scim.controllers.docker import SCIMDockerController
from authentik.providers.scim.controllers.kubernetes import SCIMKubernetesController
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: # noqa: PLR0911
"""Get a controller for the outpost, when a service connection is defined""" """Get a controller for the outpost, when a service connection is defined"""
if not outpost.service_connection: if not outpost.service_connection:
return None return None
@ -74,6 +76,11 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
return RACDockerController return RACDockerController
if isinstance(service_connection, KubernetesServiceConnection): if isinstance(service_connection, KubernetesServiceConnection):
return RACKubernetesController return RACKubernetesController
if outpost.type == OutpostType.SCIM:
if isinstance(service_connection, DockerServiceConnection):
return SCIMDockerController
if isinstance(service_connection, KubernetesServiceConnection):
return SCIMKubernetesController
return None return None

View File

@ -19,6 +19,7 @@ from authentik.lib.sync.outgoing.exceptions import (
TransientSyncException, TransientSyncException,
) )
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.outposts.http import OutpostHTTPAdapter
from authentik.providers.scim.clients.exceptions import SCIMRequestException from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
@ -41,8 +42,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
def __init__(self, provider: SCIMProvider): def __init__(self, provider: SCIMProvider):
super().__init__(provider) super().__init__(provider)
self._session = get_http_session() self._session = self.get_session(provider)
self._session.verify = provider.verify_certificates
self.provider = provider self.provider = provider
# Remove trailing slashes as we assume the URL doesn't have any # Remove trailing slashes as we assume the URL doesn't have any
base_url = provider.url base_url = provider.url
@ -52,6 +52,15 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
self.token = provider.token self.token = provider.token
self._config = self.get_service_provider_config() self._config = self.get_service_provider_config()
def get_session(self, provider: SCIMProvider):
session = get_http_session()
if self.provider.outpost_set.exists():
adapter = OutpostHTTPAdapter()
session.mount("https://", adapter)
session.mount("http://", adapter)
session.verify = provider.verify_certificates
return session
def _request(self, method: str, path: str, **kwargs) -> dict: def _request(self, method: str, path: str, **kwargs) -> dict:
"""Wrapper to send a request to the full URL""" """Wrapper to send a request to the full URL"""
try: try:

View File

@ -0,0 +1,12 @@
"""SCIM Provider Docker Controller"""
from authentik.outposts.controllers.docker import DockerController
from authentik.outposts.models import DockerServiceConnection, Outpost
class SCIMDockerController(DockerController):
"""SCIM Provider Docker Controller"""
def __init__(self, outpost: Outpost, connection: DockerServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []

View File

@ -0,0 +1,14 @@
"""SCIM Provider Kubernetes Controller"""
from authentik.outposts.controllers.k8s.service import ServiceReconciler
from authentik.outposts.controllers.kubernetes import KubernetesController
from authentik.outposts.models import KubernetesServiceConnection, Outpost
class SCIMKubernetesController(KubernetesController):
"""SCIM Provider Kubernetes Controller"""
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []
del self.reconcilers[ServiceReconciler.reconciler_name()]

View File

@ -4381,7 +4381,8 @@
"proxy", "proxy",
"ldap", "ldap",
"radius", "radius",
"rac" "rac",
"scim"
], ],
"title": "Type" "title": "Type"
}, },

185
cmd/scim/main.go Normal file
View File

@ -0,0 +1,185 @@
package main
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"goauthentik.io/internal/common"
"goauthentik.io/internal/debug"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ak/healthcheck"
)
const helpMessage = `authentik SCIM
Required environment variables:
- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company")
- AUTHENTIK_TOKEN: Token to authenticate with
- AUTHENTIK_INSECURE: Skip SSL Certificate verification`
var rootCmd = &cobra.Command{
Long: helpMessage,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
log.SetLevel(log.DebugLevel)
log.SetFormatter(&log.JSONFormatter{
FieldMap: log.FieldMap{
log.FieldKeyMsg: "event",
log.FieldKeyTime: "timestamp",
},
DisableHTMLEscape: true,
})
},
Run: func(cmd *cobra.Command, args []string) {
debug.EnableDebugServer()
akURL, found := os.LookupEnv("AUTHENTIK_HOST")
if !found {
fmt.Println("env AUTHENTIK_HOST not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akToken, found := os.LookupEnv("AUTHENTIK_TOKEN")
if !found {
fmt.Println("env AUTHENTIK_TOKEN not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akURLActual, err := url.Parse(akURL)
if err != nil {
fmt.Println(err)
fmt.Println(helpMessage)
os.Exit(1)
}
ex := common.Init()
defer common.Defer()
go func() {
for {
<-ex
os.Exit(0)
}
}()
ac := ak.NewAPIController(*akURLActual, akToken)
if ac == nil {
os.Exit(1)
}
defer ac.Shutdown()
ac.Server = &SCIMOutpost{
ac: ac,
log: log.WithField("logger", "authentik.outpost.scim"),
}
err = ac.Start()
if err != nil {
log.WithError(err).Panic("Failed to run server")
}
for {
<-ex
}
},
}
type HTTPRequest struct {
Uid string `mapstructure:"uid"`
Method string `mapstructure:"method"`
URL string `mapstructure:"url"`
Headers map[string][]string `mapstructure:"headers"`
Body interface{} `mapstructure:"body"`
SSLVerify bool `mapstructure:"ssl_verify"`
Timeout int `mapstructure:"timeout"`
}
type RequestArgs struct {
Request HTTPRequest `mapstructure:"request"`
ResponseChannel string `mapstructure:"response_channel"`
}
type SCIMOutpost struct {
ac *ak.APIController
log *log.Entry
}
func (s *SCIMOutpost) Type() string { return "SCIM" }
func (s *SCIMOutpost) Stop() error { return nil }
func (s *SCIMOutpost) Refresh() error { return nil }
func (s *SCIMOutpost) TimerFlowCacheExpiry(context.Context) {}
func (s *SCIMOutpost) Start() error {
s.ac.AddWSHandler(func(ctx context.Context, args map[string]interface{}) {
rd := RequestArgs{}
err := mapstructure.Decode(args, &rd)
if err != nil {
s.log.WithError(err).Warning("failed to parse http request")
return
}
s.log.WithField("rd", rd).WithField("raw", args).Debug("request data")
ctx, canc := context.WithTimeout(ctx, time.Duration(rd.Request.Timeout)*time.Second)
defer canc()
req, err := http.NewRequestWithContext(ctx, rd.Request.Method, rd.Request.URL, nil)
if err != nil {
s.log.WithError(err).Warning("failed to create request")
return
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: !rd.Request.SSLVerify},
TLSHandshakeTimeout: time.Duration(rd.Request.Timeout) * time.Second,
IdleConnTimeout: time.Duration(rd.Request.Timeout) * time.Second,
ResponseHeaderTimeout: time.Duration(rd.Request.Timeout) * time.Second,
ExpectContinueTimeout: time.Duration(rd.Request.Timeout) * time.Second,
}
c := &http.Client{
Transport: tr,
}
s.log.WithField("url", req.URL.Host).Debug("sending HTTP request")
res, err := c.Do(req)
if err != nil {
s.log.WithError(err).Warning("failed to send request")
return
}
body, err := io.ReadAll(res.Body)
if err != nil {
s.log.WithError(err).Warning("failed to read body")
return
}
s.log.WithField("res", res.StatusCode).Debug("sending HTTP response")
err = s.ac.SendWS(ak.WebsocketInstructionProviderSpecific, map[string]interface{}{
"sub_type": "http_response",
"response_channel": rd.ResponseChannel,
"response": map[string]interface{}{
"status": res.StatusCode,
"final_url": res.Request.URL.String(),
"headers": res.Header,
"body": base64.StdEncoding.EncodeToString(body),
},
})
if err != nil {
s.log.WithError(err).Warning("failed to send http response")
return
}
})
return nil
}
func main() {
rootCmd.AddCommand(healthcheck.Command)
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}

View File

@ -95,7 +95,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
} }
if len(outposts.Results) < 1 { if len(outposts.Results) < 1 {
panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") panic("No outposts found with given token, ensure the given token corresponds to an authentik Outpost")
} }
outpost := outposts.Results[0] outpost := outposts.Results[0]

View File

@ -233,15 +233,19 @@ func (a *APIController) AddWSHandler(handler WSHandler) {
a.wsHandlers = append(a.wsHandlers, handler) a.wsHandlers = append(a.wsHandlers, handler)
} }
func (a *APIController) SendWS(inst WebsocketInstruction, args map[string]interface{}) error {
msg := websocketMessage{
Instruction: inst,
Args: args,
}
err := a.wsConn.WriteJSON(msg)
return err
}
func (a *APIController) SendWSHello(args map[string]interface{}) error { func (a *APIController) SendWSHello(args map[string]interface{}) error {
allArgs := a.getWebsocketPingArgs() allArgs := a.getWebsocketPingArgs()
for key, value := range args { for key, value := range args {
allArgs[key] = value allArgs[key] = value
} }
aliveMsg := websocketMessage{ return a.SendWS(WebsocketInstructionHello, args)
Instruction: WebsocketInstructionHello,
Args: allArgs,
}
err := a.wsConn.WriteJSON(aliveMsg)
return err
} }

View File

@ -1,19 +1,19 @@
package ak package ak
type websocketInstruction int type WebsocketInstruction int
const ( const (
// WebsocketInstructionAck Code used to acknowledge a previous message // WebsocketInstructionAck Code used to acknowledge a previous message
WebsocketInstructionAck websocketInstruction = 0 WebsocketInstructionAck WebsocketInstruction = 0
// WebsocketInstructionHello Code used to send a healthcheck keepalive // WebsocketInstructionHello Code used to send a healthcheck keepalive
WebsocketInstructionHello websocketInstruction = 1 WebsocketInstructionHello WebsocketInstruction = 1
// WebsocketInstructionTriggerUpdate Code received to trigger a config update // WebsocketInstructionTriggerUpdate Code received to trigger a config update
WebsocketInstructionTriggerUpdate websocketInstruction = 2 WebsocketInstructionTriggerUpdate WebsocketInstruction = 2
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function // WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
WebsocketInstructionProviderSpecific websocketInstruction = 3 WebsocketInstructionProviderSpecific WebsocketInstruction = 3
) )
type websocketMessage struct { type websocketMessage struct {
Instruction websocketInstruction `json:"instruction"` Instruction WebsocketInstruction `json:"instruction"`
Args map[string]interface{} `json:"args"` Args map[string]interface{} `json:"args"`
} }

View File

@ -46705,6 +46705,7 @@ components:
- ldap - ldap
- radius - radius
- rac - rac
- scim
type: string type: string
PaginatedApplicationEntitlementList: PaginatedApplicationEntitlementList:
type: object type: object

View File

@ -73,6 +73,9 @@ const radiusListFetch = async (page: number, search = "") =>
const racListProvider = async (page: number, search = "") => const racListProvider = async (page: number, search = "") =>
provisionMaker(await api().providersRacList(providerListArgs(page, search))); provisionMaker(await api().providersRacList(providerListArgs(page, search)));
const scimListProvider = async (page: number, search = "") =>
provisionMaker(await api().providersScimList(providerListArgs(page, search)));
function providerProvider(type: OutpostTypeEnum): DataProvider { function providerProvider(type: OutpostTypeEnum): DataProvider {
switch (type) { switch (type) {
case OutpostTypeEnum.Proxy: case OutpostTypeEnum.Proxy:
@ -83,6 +86,8 @@ function providerProvider(type: OutpostTypeEnum): DataProvider {
return radiusListFetch; return radiusListFetch;
case OutpostTypeEnum.Rac: case OutpostTypeEnum.Rac:
return racListProvider; return racListProvider;
case OutpostTypeEnum.Scim:
return scimListProvider;
default: default:
throw new Error(`Unrecognized OutputType: ${type}`); throw new Error(`Unrecognized OutputType: ${type}`);
} }
@ -142,6 +147,7 @@ export class OutpostForm extends ModelForm<Outpost, string> {
[OutpostTypeEnum.Ldap, msg("LDAP")], [OutpostTypeEnum.Ldap, msg("LDAP")],
[OutpostTypeEnum.Radius, msg("Radius")], [OutpostTypeEnum.Radius, msg("Radius")],
[OutpostTypeEnum.Rac, msg("RAC")], [OutpostTypeEnum.Rac, msg("RAC")],
[OutpostTypeEnum.Scim, msg("SCIM")],
]; ];
return html` <ak-form-element-horizontal label=${msg("Name")} ?required=${true} name="name"> return html` <ak-form-element-horizontal label=${msg("Name")} ?required=${true} name="name">