Compare commits

...

4 Commits

Author SHA1 Message Date
396925d1f0 add timeout
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:43:09 +01:00
10a8ed164e small fixes
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:36:56 +01:00
445dc01dca add full outpost support
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:36:56 +01:00
441916703d implement adapter using outposts
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-02-17 02:36:56 +01:00
15 changed files with 349 additions and 17 deletions

View File

@ -128,6 +128,12 @@ class OutpostConsumer(JsonWebsocketConsumer):
state.args.update(msg.args)
elif msg.instruction == WebsocketMessageInstruction.ACK:
return
elif msg.instruction == WebsocketMessageInstruction.PROVIDER_SPECIFIC:
if "response_channel" not in msg.args:
return
self.logger.debug("Posted response to channel", msg=msg)
async_to_sync(self.channel_layer.send)(msg.args.get("response_channel"), content)
return
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
tenant=connection.schema_name,
outpost=self.outpost.name,

View File

@ -0,0 +1,86 @@
from base64 import b64decode
from dataclasses import asdict, dataclass
from random import choice
from typing import Any
from uuid import uuid4
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from channels_redis.pubsub import RedisPubSubChannelLayer
from requests.adapters import BaseAdapter
from requests.models import PreparedRequest, Response
from requests.utils import CaseInsensitiveDict
from structlog.stdlib import get_logger
from authentik.outposts.models import Outpost
@dataclass
class OutpostPreparedRequest:
uid: str
method: str
url: str
headers: dict[str, str]
body: Any
ssl_verify: bool
timeout: int
@staticmethod
def from_requests(req: PreparedRequest) -> "OutpostPreparedRequest":
return OutpostPreparedRequest(
uid=str(uuid4()),
method=req.method,
url=req.url,
headers=req.headers._store,
body=req.body,
ssl_verify=True,
timeout=0,
)
@property
def response_channel(self) -> str:
return f"authentik_outpost_http_response_{self.uid}"
class OutpostHTTPAdapter(BaseAdapter):
"""Requests Adapter that sends HTTP requests via a specified Outpost"""
def __init__(self, outpost: Outpost, default_timeout=10):
super().__init__()
self.__outpost = outpost
self.__logger = get_logger().bind()
self.__layer: RedisPubSubChannelLayer = get_channel_layer()
self.default_timeout = default_timeout
def parse_response(self, raw_response: dict, req: PreparedRequest) -> Response:
res = Response()
res.request = req
res.status_code = raw_response.get("status")
res.url = raw_response.get("final_url")
res.headers = CaseInsensitiveDict(raw_response.get("headers"))
res._content = b64decode(raw_response.get("body"))
return res
def send(self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None):
# Convert request so we can send it to the outpost
converted = OutpostPreparedRequest.from_requests(request)
converted.ssl_verify = verify
converted.timeout = timeout if timeout else self.default_timeout
# Pick one of the outpost instances
state = choice(self.__outpost.state) # nosec
self.__logger.debug("sending HTTP request to outpost", uid=converted.uid)
async_to_sync(self.__layer.send)(
state.uid,
{
"type": "event.provider.specific",
"sub_type": "http_request",
"response_channel": converted.response_channel,
"request": asdict(converted),
},
)
self.__logger.debug("receiving HTTP response from outpost", uid=converted.uid)
raw_response = async_to_sync(self.__layer.receive)(
converted.response_channel,
)
self.__logger.debug("received HTTP response from outpost", uid=converted.uid)
return self.parse_response(raw_response.get("args", {}).get("response", {}), request)

View File

@ -98,6 +98,7 @@ class OutpostType(models.TextChoices):
LDAP = "ldap"
RADIUS = "radius"
RAC = "rac"
SCIM = "scim"
def default_outpost_config(host: str | None = None):

View File

@ -43,13 +43,15 @@ from authentik.providers.proxy.controllers.docker import ProxyDockerController
from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController
from authentik.providers.radius.controllers.docker import RadiusDockerController
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
from authentik.providers.scim.controllers.docker import SCIMDockerController
from authentik.providers.scim.controllers.kubernetes import SCIMKubernetesController
from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: # noqa: PLR0911
"""Get a controller for the outpost, when a service connection is defined"""
if not outpost.service_connection:
return None
@ -74,6 +76,11 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
return RACDockerController
if isinstance(service_connection, KubernetesServiceConnection):
return RACKubernetesController
if outpost.type == OutpostType.SCIM:
if isinstance(service_connection, DockerServiceConnection):
return SCIMDockerController
if isinstance(service_connection, KubernetesServiceConnection):
return SCIMKubernetesController
return None

View File

@ -19,6 +19,7 @@ from authentik.lib.sync.outgoing.exceptions import (
TransientSyncException,
)
from authentik.lib.utils.http import get_http_session
from authentik.outposts.http import OutpostHTTPAdapter
from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMProvider
@ -41,8 +42,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
def __init__(self, provider: SCIMProvider):
super().__init__(provider)
self._session = get_http_session()
self._session.verify = provider.verify_certificates
self._session = self.get_session(provider)
self.provider = provider
# Remove trailing slashes as we assume the URL doesn't have any
base_url = provider.url
@ -52,6 +52,15 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
self.token = provider.token
self._config = self.get_service_provider_config()
def get_session(self, provider: SCIMProvider):
session = get_http_session()
if self.provider.outpost_set.exists():
adapter = OutpostHTTPAdapter()
session.mount("https://", adapter)
session.mount("http://", adapter)
session.verify = provider.verify_certificates
return session
def _request(self, method: str, path: str, **kwargs) -> dict:
"""Wrapper to send a request to the full URL"""
try:

View File

@ -0,0 +1,12 @@
"""SCIM Provider Docker Controller"""
from authentik.outposts.controllers.docker import DockerController
from authentik.outposts.models import DockerServiceConnection, Outpost
class SCIMDockerController(DockerController):
"""SCIM Provider Docker Controller"""
def __init__(self, outpost: Outpost, connection: DockerServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []

View File

@ -0,0 +1,14 @@
"""SCIM Provider Kubernetes Controller"""
from authentik.outposts.controllers.k8s.service import ServiceReconciler
from authentik.outposts.controllers.kubernetes import KubernetesController
from authentik.outposts.models import KubernetesServiceConnection, Outpost
class SCIMKubernetesController(KubernetesController):
"""SCIM Provider Kubernetes Controller"""
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []
del self.reconcilers[ServiceReconciler.reconciler_name()]

View File

@ -4381,7 +4381,8 @@
"proxy",
"ldap",
"radius",
"rac"
"rac",
"scim"
],
"title": "Type"
},

185
cmd/scim/main.go Normal file
View File

@ -0,0 +1,185 @@
package main
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"goauthentik.io/internal/common"
"goauthentik.io/internal/debug"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ak/healthcheck"
)
const helpMessage = `authentik SCIM
Required environment variables:
- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company")
- AUTHENTIK_TOKEN: Token to authenticate with
- AUTHENTIK_INSECURE: Skip SSL Certificate verification`
var rootCmd = &cobra.Command{
Long: helpMessage,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
log.SetLevel(log.DebugLevel)
log.SetFormatter(&log.JSONFormatter{
FieldMap: log.FieldMap{
log.FieldKeyMsg: "event",
log.FieldKeyTime: "timestamp",
},
DisableHTMLEscape: true,
})
},
Run: func(cmd *cobra.Command, args []string) {
debug.EnableDebugServer()
akURL, found := os.LookupEnv("AUTHENTIK_HOST")
if !found {
fmt.Println("env AUTHENTIK_HOST not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akToken, found := os.LookupEnv("AUTHENTIK_TOKEN")
if !found {
fmt.Println("env AUTHENTIK_TOKEN not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akURLActual, err := url.Parse(akURL)
if err != nil {
fmt.Println(err)
fmt.Println(helpMessage)
os.Exit(1)
}
ex := common.Init()
defer common.Defer()
go func() {
for {
<-ex
os.Exit(0)
}
}()
ac := ak.NewAPIController(*akURLActual, akToken)
if ac == nil {
os.Exit(1)
}
defer ac.Shutdown()
ac.Server = &SCIMOutpost{
ac: ac,
log: log.WithField("logger", "authentik.outpost.scim"),
}
err = ac.Start()
if err != nil {
log.WithError(err).Panic("Failed to run server")
}
for {
<-ex
}
},
}
type HTTPRequest struct {
Uid string `mapstructure:"uid"`
Method string `mapstructure:"method"`
URL string `mapstructure:"url"`
Headers map[string][]string `mapstructure:"headers"`
Body interface{} `mapstructure:"body"`
SSLVerify bool `mapstructure:"ssl_verify"`
Timeout int `mapstructure:"timeout"`
}
type RequestArgs struct {
Request HTTPRequest `mapstructure:"request"`
ResponseChannel string `mapstructure:"response_channel"`
}
type SCIMOutpost struct {
ac *ak.APIController
log *log.Entry
}
func (s *SCIMOutpost) Type() string { return "SCIM" }
func (s *SCIMOutpost) Stop() error { return nil }
func (s *SCIMOutpost) Refresh() error { return nil }
func (s *SCIMOutpost) TimerFlowCacheExpiry(context.Context) {}
func (s *SCIMOutpost) Start() error {
s.ac.AddWSHandler(func(ctx context.Context, args map[string]interface{}) {
rd := RequestArgs{}
err := mapstructure.Decode(args, &rd)
if err != nil {
s.log.WithError(err).Warning("failed to parse http request")
return
}
s.log.WithField("rd", rd).WithField("raw", args).Debug("request data")
ctx, canc := context.WithTimeout(ctx, time.Duration(rd.Request.Timeout)*time.Second)
defer canc()
req, err := http.NewRequestWithContext(ctx, rd.Request.Method, rd.Request.URL, nil)
if err != nil {
s.log.WithError(err).Warning("failed to create request")
return
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: !rd.Request.SSLVerify},
TLSHandshakeTimeout: time.Duration(rd.Request.Timeout) * time.Second,
IdleConnTimeout: time.Duration(rd.Request.Timeout) * time.Second,
ResponseHeaderTimeout: time.Duration(rd.Request.Timeout) * time.Second,
ExpectContinueTimeout: time.Duration(rd.Request.Timeout) * time.Second,
}
c := &http.Client{
Transport: tr,
}
s.log.WithField("url", req.URL.Host).Debug("sending HTTP request")
res, err := c.Do(req)
if err != nil {
s.log.WithError(err).Warning("failed to send request")
return
}
body, err := io.ReadAll(res.Body)
if err != nil {
s.log.WithError(err).Warning("failed to read body")
return
}
s.log.WithField("res", res.StatusCode).Debug("sending HTTP response")
err = s.ac.SendWS(ak.WebsocketInstructionProviderSpecific, map[string]interface{}{
"sub_type": "http_response",
"response_channel": rd.ResponseChannel,
"response": map[string]interface{}{
"status": res.StatusCode,
"final_url": res.Request.URL.String(),
"headers": res.Header,
"body": base64.StdEncoding.EncodeToString(body),
},
})
if err != nil {
s.log.WithError(err).Warning("failed to send http response")
return
}
})
return nil
}
func main() {
rootCmd.AddCommand(healthcheck.Command)
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}

View File

@ -95,7 +95,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
time.Sleep(time.Second * 3)
}
if len(outposts.Results) < 1 {
panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost")
panic("No outposts found with given token, ensure the given token corresponds to an authentik Outpost")
}
outpost := outposts.Results[0]

View File

@ -233,15 +233,19 @@ func (a *APIController) AddWSHandler(handler WSHandler) {
a.wsHandlers = append(a.wsHandlers, handler)
}
func (a *APIController) SendWS(inst WebsocketInstruction, args map[string]interface{}) error {
msg := websocketMessage{
Instruction: inst,
Args: args,
}
err := a.wsConn.WriteJSON(msg)
return err
}
func (a *APIController) SendWSHello(args map[string]interface{}) error {
allArgs := a.getWebsocketPingArgs()
for key, value := range args {
allArgs[key] = value
}
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: allArgs,
}
err := a.wsConn.WriteJSON(aliveMsg)
return err
return a.SendWS(WebsocketInstructionHello, args)
}

View File

@ -1,19 +1,19 @@
package ak
type websocketInstruction int
type WebsocketInstruction int
const (
// WebsocketInstructionAck Code used to acknowledge a previous message
WebsocketInstructionAck websocketInstruction = 0
WebsocketInstructionAck WebsocketInstruction = 0
// WebsocketInstructionHello Code used to send a healthcheck keepalive
WebsocketInstructionHello websocketInstruction = 1
WebsocketInstructionHello WebsocketInstruction = 1
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
WebsocketInstructionTriggerUpdate websocketInstruction = 2
WebsocketInstructionTriggerUpdate WebsocketInstruction = 2
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
WebsocketInstructionProviderSpecific websocketInstruction = 3
WebsocketInstructionProviderSpecific WebsocketInstruction = 3
)
type websocketMessage struct {
Instruction websocketInstruction `json:"instruction"`
Instruction WebsocketInstruction `json:"instruction"`
Args map[string]interface{} `json:"args"`
}

View File

@ -46705,6 +46705,7 @@ components:
- ldap
- radius
- rac
- scim
type: string
PaginatedApplicationEntitlementList:
type: object

View File

@ -73,6 +73,9 @@ const radiusListFetch = async (page: number, search = "") =>
const racListProvider = async (page: number, search = "") =>
provisionMaker(await api().providersRacList(providerListArgs(page, search)));
const scimListProvider = async (page: number, search = "") =>
provisionMaker(await api().providersScimList(providerListArgs(page, search)));
function providerProvider(type: OutpostTypeEnum): DataProvider {
switch (type) {
case OutpostTypeEnum.Proxy:
@ -83,6 +86,8 @@ function providerProvider(type: OutpostTypeEnum): DataProvider {
return radiusListFetch;
case OutpostTypeEnum.Rac:
return racListProvider;
case OutpostTypeEnum.Scim:
return scimListProvider;
default:
throw new Error(`Unrecognized OutputType: ${type}`);
}
@ -142,6 +147,7 @@ export class OutpostForm extends ModelForm<Outpost, string> {
[OutpostTypeEnum.Ldap, msg("LDAP")],
[OutpostTypeEnum.Radius, msg("Radius")],
[OutpostTypeEnum.Rac, msg("RAC")],
[OutpostTypeEnum.Scim, msg("SCIM")],
];
return html` <ak-form-element-horizontal label=${msg("Name")} ?required=${true} name="name">