Compare commits
150 Commits
version/20
...
version/20
Author | SHA1 | Date | |
---|---|---|---|
ab2b13938e | |||
5c97a3aef3 | |||
e6963c543d | |||
9ca15983a2 | |||
99ef94b7aa | |||
133bedafba | |||
c3faa61ed9 | |||
da74304221 | |||
ed6659a46d | |||
0abb1f94a4 | |||
c7e299e0bf | |||
8a6590bac8 | |||
ed717dcfa2 | |||
b6df42f580 | |||
2ea85bd0c4 | |||
68fa8105e1 | |||
79db0ce4c1 | |||
5e23b11764 | |||
c4e029ffe2 | |||
61b5b36192 | |||
c6cc1b1728 | |||
77dd652160 | |||
1144944adb | |||
7751be284e | |||
74382c6287 | |||
011babbbd9 | |||
3c01a1dd7b | |||
6e832be2de | |||
46017f2f86 | |||
da50eb0369 | |||
b996e3cee7 | |||
12735cc14c | |||
4d36699b78 | |||
8110d2861b | |||
1cc60f572d | |||
90151a13ae | |||
f958aa6930 | |||
13fbac30a2 | |||
4f4cdf16f1 | |||
7d75599627 | |||
924a13e832 | |||
ae83c35dfd | |||
e9102f4e28 | |||
9b8c1cbea5 | |||
6424bf98da | |||
74fb0f9e2a | |||
4380f37a77 | |||
17fccd44e6 | |||
217a8b5610 | |||
2cef220a3e | |||
5a8c66d325 | |||
8de13d3f67 | |||
5c22bedbaf | |||
8a0f993f0b | |||
abcf515a69 | |||
894f704c27 | |||
7798292aa8 | |||
3005ca17bd | |||
909461e533 | |||
df838a4023 | |||
0f86b62dd3 | |||
a40c3aeb68 | |||
4080738ded | |||
4a89be3048 | |||
e587c53e18 | |||
023b97aa69 | |||
51365dba74 | |||
0d3705685e | |||
738e4d5c74 | |||
b14b9cb0dd | |||
2a21ebf7b0 | |||
5bc1301043 | |||
e0e4bf6972 | |||
337677ad12 | |||
3712d5aee2 | |||
dd82d55725 | |||
8d766efecb | |||
9ac3b29418 | |||
5000c5b061 | |||
b362d2af03 | |||
bcd42fce13 | |||
6deddd038f | |||
3b47cb64da | |||
cf5e70c759 | |||
20bc38a54b | |||
672a4ab1f4 | |||
47dd667261 | |||
d1ac69789b | |||
08abf81c6d | |||
76bd987e6f | |||
5374352411 | |||
08eff4cc5d | |||
c87a9f9489 | |||
8f6d700aa8 | |||
c6843b026c | |||
3769c33ef0 | |||
8982afaf44 | |||
58c221e867 | |||
108d3e56e3 | |||
145b32c480 | |||
c788504bb0 | |||
34782b31e5 | |||
5a3ca13d76 | |||
5dc0f3b91b | |||
f51515f3de | |||
f978575293 | |||
cb64eed90d | |||
db1f7f0400 | |||
0d02dbf55c | |||
6da78b8c32 | |||
3a80bc8bda | |||
1aa9c0f9ca | |||
2da7a8fede | |||
89cb402f42 | |||
b617fd213f | |||
97b0f58f25 | |||
49a98bb744 | |||
f93a00d773 | |||
8de40a8a21 | |||
b9c54e97fa | |||
f1c55465f7 | |||
40c2b2860b | |||
a92bce322d | |||
af83308fd4 | |||
73d991e75a | |||
1eba3f1334 | |||
b86251255d | |||
ccab41a6ca | |||
0e051031b1 | |||
aecbe8c585 | |||
da98022704 | |||
e13f9c0b38 | |||
7941fb9d95 | |||
d2392b0881 | |||
b2044d75fb | |||
617b64b7db | |||
2bf5f2709a | |||
f03325df28 | |||
2b71e5bdfd | |||
f861737b85 | |||
6036d88392 | |||
bfc8a56a0b | |||
8d995011b8 | |||
5646141fe2 | |||
96b0bc324e | |||
335d6edd11 | |||
5d9bed130a | |||
0a1ab74707 | |||
ef24b94585 | |||
77b0438aa4 |
@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 2021.10.1
|
||||
current_version = 2021.10.3
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)\-?(?P<release>.*)
|
||||
|
2
.github/workflows/ci-main.yml
vendored
2
.github/workflows/ci-main.yml
vendored
@ -147,6 +147,8 @@ jobs:
|
||||
# Copy current, latest config to local
|
||||
cp authentik/lib/default.yml local.env.yml
|
||||
git checkout $(git describe --abbrev=0 --match 'version/*')
|
||||
git checkout ${{ steps.ev.outputs.branchName }} -- .github
|
||||
git checkout ${{ steps.ev.outputs.branchName }} -- scripts
|
||||
- name: prepare
|
||||
env:
|
||||
INSTALL: ${{ steps.cache-pipenv.outputs.cache-hit }}
|
||||
|
20
.github/workflows/release-publish.yml
vendored
20
.github/workflows/release-publish.yml
vendored
@ -30,14 +30,14 @@ jobs:
|
||||
with:
|
||||
push: ${{ github.event_name == 'release' }}
|
||||
tags: |
|
||||
beryju/authentik:2021.10.1,
|
||||
beryju/authentik:2021.10.3,
|
||||
beryju/authentik:latest,
|
||||
ghcr.io/goauthentik/server:2021.10.1,
|
||||
ghcr.io/goauthentik/server:2021.10.3,
|
||||
ghcr.io/goauthentik/server:latest
|
||||
platforms: linux/amd64,linux/arm64
|
||||
context: .
|
||||
- name: Building Docker Image (stable)
|
||||
if: ${{ github.event_name == 'release' && !contains('2021.10.1', 'rc') }}
|
||||
if: ${{ github.event_name == 'release' && !contains('2021.10.3', 'rc') }}
|
||||
run: |
|
||||
docker pull beryju/authentik:latest
|
||||
docker tag beryju/authentik:latest beryju/authentik:stable
|
||||
@ -72,14 +72,14 @@ jobs:
|
||||
with:
|
||||
push: ${{ github.event_name == 'release' }}
|
||||
tags: |
|
||||
beryju/authentik-proxy:2021.10.1,
|
||||
beryju/authentik-proxy:2021.10.3,
|
||||
beryju/authentik-proxy:latest,
|
||||
ghcr.io/goauthentik/proxy:2021.10.1,
|
||||
ghcr.io/goauthentik/proxy:2021.10.3,
|
||||
ghcr.io/goauthentik/proxy:latest
|
||||
file: proxy.Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- name: Building Docker Image (stable)
|
||||
if: ${{ github.event_name == 'release' && !contains('2021.10.1', 'rc') }}
|
||||
if: ${{ github.event_name == 'release' && !contains('2021.10.3', 'rc') }}
|
||||
run: |
|
||||
docker pull beryju/authentik-proxy:latest
|
||||
docker tag beryju/authentik-proxy:latest beryju/authentik-proxy:stable
|
||||
@ -114,14 +114,14 @@ jobs:
|
||||
with:
|
||||
push: ${{ github.event_name == 'release' }}
|
||||
tags: |
|
||||
beryju/authentik-ldap:2021.10.1,
|
||||
beryju/authentik-ldap:2021.10.3,
|
||||
beryju/authentik-ldap:latest,
|
||||
ghcr.io/goauthentik/ldap:2021.10.1,
|
||||
ghcr.io/goauthentik/ldap:2021.10.3,
|
||||
ghcr.io/goauthentik/ldap:latest
|
||||
file: ldap.Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- name: Building Docker Image (stable)
|
||||
if: ${{ github.event_name == 'release' && !contains('2021.10.1', 'rc') }}
|
||||
if: ${{ github.event_name == 'release' && !contains('2021.10.3', 'rc') }}
|
||||
run: |
|
||||
docker pull beryju/authentik-ldap:latest
|
||||
docker tag beryju/authentik-ldap:latest beryju/authentik-ldap:stable
|
||||
@ -170,7 +170,7 @@ jobs:
|
||||
SENTRY_PROJECT: authentik
|
||||
SENTRY_URL: https://sentry.beryju.org
|
||||
with:
|
||||
version: authentik@2021.10.1
|
||||
version: authentik@2021.10.3
|
||||
environment: beryjuorg-prod
|
||||
sourcemaps: './web/dist'
|
||||
url_prefix: '~/static/dist'
|
||||
|
10
Dockerfile
10
Dockerfile
@ -1,12 +1,12 @@
|
||||
# Stage 1: Lock python dependencies
|
||||
FROM docker.io/python:3.9-slim-buster as locker
|
||||
FROM docker.io/python:3.9-bullseye as locker
|
||||
|
||||
COPY ./Pipfile /app/
|
||||
COPY ./Pipfile.lock /app/
|
||||
|
||||
WORKDIR /app/
|
||||
|
||||
RUN pip install pipenv && \
|
||||
RUN pip install pipenv==2021.5.29 && \
|
||||
pipenv lock -r > requirements.txt && \
|
||||
pipenv lock -r --dev-only > requirements-dev.txt
|
||||
|
||||
@ -27,7 +27,7 @@ ENV NODE_ENV=production
|
||||
RUN cd /static && npm i && npm run build
|
||||
|
||||
# Stage 4: Build go proxy
|
||||
FROM docker.io/golang:1.17.2 AS builder
|
||||
FROM docker.io/golang:1.17.3-bullseye AS builder
|
||||
|
||||
WORKDIR /work
|
||||
|
||||
@ -47,7 +47,7 @@ COPY ./go.sum /work/go.sum
|
||||
RUN go build -o /work/authentik ./cmd/server/main.go
|
||||
|
||||
# Stage 5: Run
|
||||
FROM docker.io/python:3.9-slim-buster
|
||||
FROM docker.io/python:3.9-bullseye
|
||||
|
||||
WORKDIR /
|
||||
COPY --from=locker /app/requirements.txt /
|
||||
@ -59,7 +59,7 @@ ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git runit && \
|
||||
curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
|
||||
echo "deb http://apt.postgresql.org/pub/repos/apt buster-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \
|
||||
echo "deb http://apt.postgresql.org/pub/repos/apt bullseye-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends libpq-dev postgresql-client build-essential libxmlsec1-dev pkg-config libmaxminddb0 && \
|
||||
pip install -r /requirements.txt --no-cache-dir && \
|
||||
|
14
Makefile
14
Makefile
@ -30,7 +30,6 @@ lint-fix:
|
||||
website/developer-docs
|
||||
|
||||
lint:
|
||||
pyright authentik tests lifecycle
|
||||
bandit -r authentik tests lifecycle -x node_modules
|
||||
pylint authentik tests lifecycle
|
||||
|
||||
@ -49,7 +48,7 @@ gen-web:
|
||||
docker run \
|
||||
--rm -v ${PWD}:/local \
|
||||
--user ${UID}:${GID} \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
ghcr.io/beryju/openapi-generator generate \
|
||||
-i /local/schema.yml \
|
||||
-g typescript-fetch \
|
||||
-o /local/web-api \
|
||||
@ -61,18 +60,19 @@ gen-web:
|
||||
\cp -rfv web-api/* web/node_modules/@goauthentik/api
|
||||
|
||||
gen-outpost:
|
||||
wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O config.yaml
|
||||
mkdir -p templates
|
||||
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O templates/README.mustache
|
||||
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O templates/go.mod.mustache
|
||||
docker run \
|
||||
--rm -v ${PWD}:/local \
|
||||
--user ${UID}:${GID} \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
--git-host goauthentik.io \
|
||||
--git-repo-id outpost \
|
||||
--git-user-id api \
|
||||
-i /local/schema.yml \
|
||||
-g go \
|
||||
-o /local/api \
|
||||
--additional-properties=packageName=api,enumClassPrefix=true,useOneOfDiscriminatorLookup=true,disallowAdditionalPropertiesIfNotPresent=false
|
||||
rm -f api/go.mod api/go.sum
|
||||
-c /local/config.yaml
|
||||
go mod edit -replace goauthentik.io/api=./api
|
||||
|
||||
gen: gen-build gen-clean gen-web
|
||||
|
||||
|
885
Pipfile.lock
generated
885
Pipfile.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,3 +1,3 @@
|
||||
"""authentik"""
|
||||
__version__ = "2021.10.1"
|
||||
__version__ = "2021.10.3"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
@ -27,6 +27,7 @@ VERSION_CACHE_TIMEOUT = 8 * 60 * 60 # 8 hours
|
||||
# Chop of the first ^ because we want to search the entire string
|
||||
URL_FINDER = URLValidator.regex.pattern[1:]
|
||||
PROM_INFO = Info("authentik_version", "Currently running authentik version")
|
||||
LOCAL_VERSION = parse(__version__)
|
||||
|
||||
|
||||
def _set_prom_info():
|
||||
@ -48,7 +49,7 @@ def clear_update_notifications():
|
||||
if "new_version" not in notification.event.context:
|
||||
continue
|
||||
notification_version = notification.event.context["new_version"]
|
||||
if notification_version == __version__:
|
||||
if LOCAL_VERSION >= parse(notification_version):
|
||||
notification.delete()
|
||||
|
||||
|
||||
@ -74,8 +75,7 @@ def update_latest_version(self: MonitoredTask):
|
||||
_set_prom_info()
|
||||
# Check if upstream version is newer than what we're running,
|
||||
# and if no event exists yet, create one.
|
||||
local_version = parse(__version__)
|
||||
if local_version < parse(upstream_version):
|
||||
if LOCAL_VERSION < parse(upstream_version):
|
||||
# Event has already been created, don't create duplicate
|
||||
if Event.objects.filter(
|
||||
action=EventAction.UPDATE_AVAILABLE,
|
||||
|
@ -1,19 +0,0 @@
|
||||
"""API tasks"""
|
||||
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
SENTRY_SESSION = get_http_session()
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def sentry_proxy(payload: str):
|
||||
"""Relay data to sentry"""
|
||||
SENTRY_SESSION.post(
|
||||
"https://sentry.beryju.org/api/8/envelope/",
|
||||
data=payload,
|
||||
headers={
|
||||
"Content-Type": "application/octet-stream",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
@ -4,7 +4,7 @@ from django.urls import include, path
|
||||
from authentik.api.v3.urls import urlpatterns as v3_urls
|
||||
|
||||
urlpatterns = [
|
||||
# Remove in 2022.1
|
||||
# TODO: Remove in 2022.1
|
||||
path("v2beta/", include(v3_urls)),
|
||||
path("v3/", include(v3_urls)),
|
||||
]
|
||||
|
@ -1,65 +0,0 @@
|
||||
"""Sentry tunnel"""
|
||||
from json import loads
|
||||
|
||||
from django.conf import settings
|
||||
from django.http.request import HttpRequest
|
||||
from django.http.response import HttpResponse
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework.parsers import BaseParser
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.throttling import AnonRateThrottle
|
||||
from rest_framework.views import APIView
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.tasks import sentry_proxy
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class PlainTextParser(BaseParser):
|
||||
"""Plain text parser."""
|
||||
|
||||
media_type = "text/plain"
|
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None) -> str:
|
||||
"""Simply return a string representing the body of the request."""
|
||||
return stream.read()
|
||||
|
||||
|
||||
class CsrfExemptSessionAuthentication(SessionAuthentication):
|
||||
"""CSRF-exempt Session authentication"""
|
||||
|
||||
def enforce_csrf(self, request: Request):
|
||||
return # To not perform the csrf check previously happening
|
||||
|
||||
|
||||
class SentryTunnelView(APIView):
|
||||
"""Sentry tunnel, to prevent ad blockers from blocking sentry"""
|
||||
|
||||
serializer_class = None
|
||||
parser_classes = [PlainTextParser]
|
||||
throttle_classes = [AnonRateThrottle]
|
||||
permission_classes = [AllowAny]
|
||||
authentication_classes = [CsrfExemptSessionAuthentication]
|
||||
|
||||
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
"""Sentry tunnel, to prevent ad blockers from blocking sentry"""
|
||||
# Only allow usage of this endpoint when error reporting is enabled
|
||||
if not CONFIG.y_bool("error_reporting.enabled", False):
|
||||
LOGGER.debug("error reporting disabled")
|
||||
return HttpResponse(status=400)
|
||||
# Body is 2 json objects separated by \n
|
||||
full_body = request.body
|
||||
lines = full_body.splitlines()
|
||||
if len(lines) < 1:
|
||||
return HttpResponse(status=400)
|
||||
header = loads(lines[0])
|
||||
# Check that the DSN is what we expect
|
||||
dsn = header.get("dsn", "")
|
||||
if dsn != settings.SENTRY_DSN:
|
||||
LOGGER.debug("Invalid dsn", have=dsn, expected=settings.SENTRY_DSN)
|
||||
return HttpResponse(status=400)
|
||||
sentry_proxy.delay(full_body.decode())
|
||||
return HttpResponse(status=204)
|
@ -11,7 +11,6 @@ from authentik.admin.api.tasks import TaskViewSet
|
||||
from authentik.admin.api.version import VersionView
|
||||
from authentik.admin.api.workers import WorkerView
|
||||
from authentik.api.v3.config import ConfigView
|
||||
from authentik.api.v3.sentry import SentryTunnelView
|
||||
from authentik.api.views import APIBrowserView
|
||||
from authentik.core.api.applications import ApplicationViewSet
|
||||
from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet
|
||||
@ -249,7 +248,6 @@ urlpatterns = (
|
||||
FlowInspectorView.as_view(),
|
||||
name="flow-inspector",
|
||||
),
|
||||
path("sentry/", SentryTunnelView.as_view(), name="sentry"),
|
||||
path("schema/", cache_page(86400)(SpectacularAPIView.as_view()), name="schema"),
|
||||
]
|
||||
)
|
||||
|
@ -42,6 +42,7 @@ class GroupSerializer(ModelSerializer):
|
||||
users_obj = ListSerializer(
|
||||
child=GroupMemberSerializer(), read_only=True, source="users", required=False
|
||||
)
|
||||
parent_name = CharField(source="parent.name", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
@ -51,6 +52,7 @@ class GroupSerializer(ModelSerializer):
|
||||
"name",
|
||||
"is_superuser",
|
||||
"parent",
|
||||
"parent_name",
|
||||
"users",
|
||||
"attributes",
|
||||
"users_obj",
|
||||
|
@ -55,5 +55,5 @@ class TokenBackend(InbuiltBackend):
|
||||
if not tokens.exists():
|
||||
return None
|
||||
token = tokens.first()
|
||||
self.set_method("password", request, token=token)
|
||||
self.set_method("token", request, token=token)
|
||||
return token.user
|
||||
|
0
authentik/core/management/__init__.py
Normal file
0
authentik/core/management/__init__.py
Normal file
0
authentik/core/management/commands/__init__.py
Normal file
0
authentik/core/management/commands/__init__.py
Normal file
15
authentik/core/management/commands/dump_config.py
Normal file
15
authentik/core/management/commands/dump_config.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Output full config"""
|
||||
from json import dumps
|
||||
|
||||
from django.core.management.base import BaseCommand, no_translations
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
|
||||
class Command(BaseCommand): # pragma: no cover
|
||||
"""Output full config"""
|
||||
|
||||
@no_translations
|
||||
def handle(self, *args, **options):
|
||||
"""Check permissions for all apps"""
|
||||
print(dumps(CONFIG.raw, indent=4))
|
@ -81,6 +81,27 @@ class Group(models.Model):
|
||||
)
|
||||
attributes = models.JSONField(default=dict, blank=True)
|
||||
|
||||
def is_member(self, user: "User") -> bool:
|
||||
"""Recursively check if `user` is member of us, or any parent."""
|
||||
query = """
|
||||
WITH RECURSIVE parents AS (
|
||||
SELECT authentik_core_group.*, 0 AS relative_depth
|
||||
FROM authentik_core_group
|
||||
WHERE authentik_core_group.group_uuid = %s
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT authentik_core_group.*, parents.relative_depth - 1
|
||||
FROM authentik_core_group,parents
|
||||
WHERE authentik_core_group.parent_id = parents.group_uuid
|
||||
)
|
||||
SELECT group_uuid
|
||||
FROM parents
|
||||
GROUP BY group_uuid;
|
||||
"""
|
||||
groups = Group.objects.raw(query, [self.group_uuid])
|
||||
return user.ak_groups.filter(pk__in=[group.pk for group in groups]).exists()
|
||||
|
||||
def __str__(self):
|
||||
return f"Group {self.name}"
|
||||
|
||||
|
40
authentik/core/tests/test_groups.py
Normal file
40
authentik/core/tests/test_groups.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""group tests"""
|
||||
from django.test.testcases import TestCase
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
|
||||
|
||||
class TestGroups(TestCase):
|
||||
"""Test group membership"""
|
||||
|
||||
def test_group_membership_simple(self):
|
||||
"""Test simple membership"""
|
||||
user = User.objects.create(username="user")
|
||||
user2 = User.objects.create(username="user2")
|
||||
group = Group.objects.create(name="group")
|
||||
group.users.add(user)
|
||||
self.assertTrue(group.is_member(user))
|
||||
self.assertFalse(group.is_member(user2))
|
||||
|
||||
def test_group_membership_parent(self):
|
||||
"""Test parent membership"""
|
||||
user = User.objects.create(username="user")
|
||||
user2 = User.objects.create(username="user2")
|
||||
first = Group.objects.create(name="first")
|
||||
second = Group.objects.create(name="second", parent=first)
|
||||
second.users.add(user)
|
||||
self.assertTrue(first.is_member(user))
|
||||
self.assertFalse(first.is_member(user2))
|
||||
|
||||
def test_group_membership_parent_extra(self):
|
||||
"""Test parent membership"""
|
||||
user = User.objects.create(username="user")
|
||||
user2 = User.objects.create(username="user2")
|
||||
first = Group.objects.create(name="first")
|
||||
second = Group.objects.create(name="second", parent=first)
|
||||
third = Group.objects.create(name="third", parent=second)
|
||||
second.users.add(user)
|
||||
self.assertTrue(first.is_member(user))
|
||||
self.assertFalse(first.is_member(user2))
|
||||
self.assertFalse(third.is_member(user))
|
||||
self.assertFalse(third.is_member(user2))
|
@ -7,16 +7,25 @@ from django.core.exceptions import SuspiciousOperation
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import post_save, pre_delete
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django_otp.plugins.otp_static.models import StaticToken
|
||||
from guardian.models import UserObjectPermission
|
||||
|
||||
from authentik.core.middleware import LOCAL
|
||||
from authentik.core.models import User
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.events.models import Event, EventAction, Notification
|
||||
from authentik.events.signals import EventNewThread
|
||||
from authentik.events.utils import model_to_dict
|
||||
from authentik.lib.sentry import before_send
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
IGNORED_MODELS = (
|
||||
Event,
|
||||
Notification,
|
||||
UserObjectPermission,
|
||||
AuthenticatedSession,
|
||||
StaticToken,
|
||||
)
|
||||
|
||||
|
||||
class AuditMiddleware:
|
||||
"""Register handlers for duration of request-response that log creation/update/deletion
|
||||
@ -82,7 +91,7 @@ class AuditMiddleware:
|
||||
user: User, request: HttpRequest, sender, instance: Model, created: bool, **_
|
||||
):
|
||||
"""Signal handler for all object's post_save"""
|
||||
if isinstance(instance, (Event, Notification, UserObjectPermission)):
|
||||
if isinstance(instance, IGNORED_MODELS):
|
||||
return
|
||||
|
||||
action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED
|
||||
@ -92,7 +101,7 @@ class AuditMiddleware:
|
||||
# pylint: disable=unused-argument
|
||||
def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_):
|
||||
"""Signal handler for all object's pre_delete"""
|
||||
if isinstance(instance, (Event, Notification, UserObjectPermission)): # pragma: no cover
|
||||
if isinstance(instance, IGNORED_MODELS): # pragma: no cover
|
||||
return
|
||||
|
||||
EventNewThread(
|
||||
|
@ -545,6 +545,7 @@ class TestFlowExecutor(APITestCase):
|
||||
"password_fields": False,
|
||||
"primary_action": "Log in",
|
||||
"sources": [],
|
||||
"show_source_labels": False,
|
||||
"user_fields": [UserFields.E_MAIL],
|
||||
},
|
||||
)
|
||||
|
@ -60,6 +60,7 @@ class TestFlowInspector(APITestCase):
|
||||
"password_fields": False,
|
||||
"primary_action": "Log in",
|
||||
"sources": [],
|
||||
"show_source_labels": False,
|
||||
"user_fields": ["username"],
|
||||
},
|
||||
)
|
||||
|
@ -68,6 +68,7 @@ outposts:
|
||||
|
||||
cookie_domain: null
|
||||
disable_update_check: false
|
||||
disable_startup_analytics: false
|
||||
avatars: env://AUTHENTIK_AUTHENTIK__AVATARS?gravatar
|
||||
geoip: "./GeoLite2-City.mmdb"
|
||||
|
||||
|
@ -13,6 +13,7 @@ from django.db import InternalError, OperationalError, ProgrammingError
|
||||
from django.http.response import Http404
|
||||
from django_redis.exceptions import ConnectionInterrupted
|
||||
from docker.errors import DockerException
|
||||
from h11 import LocalProtocolError
|
||||
from ldap3.core.exceptions import LDAPException
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||
from redis.exceptions import RedisError, ResponseError
|
||||
@ -72,6 +73,7 @@ def before_send(event: dict, hint: dict) -> Optional[dict]:
|
||||
# websocket errors
|
||||
ChannelFull,
|
||||
WebSocketException,
|
||||
LocalProtocolError,
|
||||
# rest_framework error
|
||||
APIException,
|
||||
# celery errors
|
||||
|
@ -1,8 +1,13 @@
|
||||
"""authentik lib reflection utilities"""
|
||||
import os
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from django.conf import settings
|
||||
from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
|
||||
def all_subclasses(cls, sort=True):
|
||||
@ -42,3 +47,16 @@ def get_apps():
|
||||
for _app in apps.get_app_configs():
|
||||
if _app.name.startswith("authentik"):
|
||||
yield _app
|
||||
|
||||
|
||||
def get_env() -> str:
|
||||
"""Get environment in which authentik is currently running"""
|
||||
if SERVICE_HOST_ENV_NAME in os.environ:
|
||||
return "kubernetes"
|
||||
if "CI" in os.environ:
|
||||
return "ci"
|
||||
if Path("/tmp/authentik-mode").exists(): # nosec
|
||||
return "compose"
|
||||
if CONFIG.y_bool("debug"):
|
||||
return "dev"
|
||||
return "custom"
|
||||
|
@ -65,14 +65,14 @@ class PolicyBinding(SerializerModel):
|
||||
# This is quite an ugly hack to prevent pylint from trying
|
||||
# to resolve authentik_core.models.Group
|
||||
# as python import path
|
||||
"authentik_core." + "Group",
|
||||
"authentik_core.Group",
|
||||
on_delete=models.CASCADE,
|
||||
default=None,
|
||||
null=True,
|
||||
blank=True,
|
||||
)
|
||||
user = models.ForeignKey(
|
||||
"authentik_core." + "User",
|
||||
"authentik_core.User",
|
||||
on_delete=models.CASCADE,
|
||||
default=None,
|
||||
null=True,
|
||||
@ -96,7 +96,7 @@ class PolicyBinding(SerializerModel):
|
||||
self.policy: Policy
|
||||
return self.policy.passes(request)
|
||||
if self.group:
|
||||
return PolicyResult(self.group.users.filter(pk=request.user.pk).exists())
|
||||
return PolicyResult(self.group.is_member(request.user))
|
||||
if self.user:
|
||||
return PolicyResult(request.user == self.user)
|
||||
return PolicyResult(False)
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""LDAPProvider API Views"""
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.fields import CharField, ListField
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
||||
|
||||
@ -11,6 +11,8 @@ from authentik.providers.ldap.models import LDAPProvider
|
||||
class LDAPProviderSerializer(ProviderSerializer):
|
||||
"""LDAPProvider Serializer"""
|
||||
|
||||
outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
|
||||
|
||||
class Meta:
|
||||
|
||||
model = LDAPProvider
|
||||
@ -21,6 +23,8 @@ class LDAPProviderSerializer(ProviderSerializer):
|
||||
"tls_server_name",
|
||||
"uid_start_number",
|
||||
"gid_start_number",
|
||||
"outpost_set",
|
||||
"search_mode",
|
||||
]
|
||||
|
||||
|
||||
@ -65,6 +69,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
|
||||
"tls_server_name",
|
||||
"uid_start_number",
|
||||
"gid_start_number",
|
||||
"search_mode",
|
||||
]
|
||||
|
||||
|
||||
|
@ -0,0 +1,93 @@
|
||||
# Generated by Django 3.2.8 on 2021-11-05 09:41
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
replaces = [
|
||||
("authentik_providers_ldap", "0001_initial"),
|
||||
("authentik_providers_ldap", "0002_ldapprovider_search_group"),
|
||||
("authentik_providers_ldap", "0003_auto_20210713_1138"),
|
||||
("authentik_providers_ldap", "0004_auto_20210713_2115"),
|
||||
("authentik_providers_ldap", "0005_ldapprovider_search_mode"),
|
||||
]
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0019_source_managed"),
|
||||
("authentik_crypto", "0002_create_self_signed_kp"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="LDAPProvider",
|
||||
fields=[
|
||||
(
|
||||
"provider_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.provider",
|
||||
),
|
||||
),
|
||||
(
|
||||
"base_dn",
|
||||
models.TextField(
|
||||
default="DC=ldap,DC=goauthentik,DC=io",
|
||||
help_text="DN under which objects are accessible.",
|
||||
),
|
||||
),
|
||||
(
|
||||
"search_group",
|
||||
models.ForeignKey(
|
||||
default=None,
|
||||
help_text="Users in this group can do search queries. If not set, every user can execute search queries.",
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_DEFAULT,
|
||||
to="authentik_core.group",
|
||||
),
|
||||
),
|
||||
(
|
||||
"certificate",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="authentik_crypto.certificatekeypair",
|
||||
),
|
||||
),
|
||||
("tls_server_name", models.TextField(blank=True, default="")),
|
||||
(
|
||||
"gid_start_number",
|
||||
models.IntegerField(
|
||||
default=4000,
|
||||
help_text="The start for gidNumbers, this number is added to a number generated from the group.Pk to make sure that the numbers aren't too low for POSIX groups. Default is 4000 to ensure that we don't collide with local groups or users primary groups gidNumber",
|
||||
),
|
||||
),
|
||||
(
|
||||
"uid_start_number",
|
||||
models.IntegerField(
|
||||
default=2000,
|
||||
help_text="The start for uidNumbers, this number is added to the user.Pk to make sure that the numbers aren't too low for POSIX users. Default is 2000 to ensure that we don't collide with local users uidNumber",
|
||||
),
|
||||
),
|
||||
(
|
||||
"search_mode",
|
||||
models.TextField(
|
||||
choices=[("direct", "Direct"), ("cached", "Cached")], default="direct"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "LDAP Provider",
|
||||
"verbose_name_plural": "LDAP Providers",
|
||||
},
|
||||
bases=("authentik_core.provider", models.Model),
|
||||
),
|
||||
]
|
@ -0,0 +1,20 @@
|
||||
# Generated by Django 3.2.8 on 2021-11-05 09:40
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_ldap", "0004_auto_20210713_2115"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="ldapprovider",
|
||||
name="search_mode",
|
||||
field=models.TextField(
|
||||
choices=[("direct", "Direct"), ("cached", "Cached")], default="direct"
|
||||
),
|
||||
),
|
||||
]
|
@ -10,6 +10,13 @@ from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.outposts.models import OutpostModel
|
||||
|
||||
|
||||
class SearchModes(models.TextChoices):
|
||||
"""Search modes"""
|
||||
|
||||
DIRECT = "direct"
|
||||
CACHED = "cached"
|
||||
|
||||
|
||||
class LDAPProvider(OutpostModel, Provider):
|
||||
"""Allow applications to authenticate against authentik's users using LDAP."""
|
||||
|
||||
@ -59,6 +66,8 @@ class LDAPProvider(OutpostModel, Provider):
|
||||
),
|
||||
)
|
||||
|
||||
search_mode = models.TextField(default=SearchModes.DIRECT, choices=SearchModes.choices)
|
||||
|
||||
@property
|
||||
def launch_url(self) -> Optional[str]:
|
||||
"""LDAP never has a launch URL"""
|
||||
|
@ -36,6 +36,7 @@ class ProxyProviderSerializer(ProviderSerializer):
|
||||
"""ProxyProvider Serializer"""
|
||||
|
||||
redirect_uris = CharField(read_only=True)
|
||||
outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
|
||||
|
||||
def validate(self, attrs) -> dict[Any, str]:
|
||||
"""Check that internal_host is set when mode is Proxy"""
|
||||
@ -74,6 +75,7 @@ class ProxyProviderSerializer(ProviderSerializer):
|
||||
"redirect_uris",
|
||||
"cookie_domain",
|
||||
"token_validity",
|
||||
"outpost_set",
|
||||
]
|
||||
|
||||
|
||||
|
@ -138,7 +138,7 @@ class ProxyProvider(OutpostModel, OAuth2Provider):
|
||||
SCOPE_AK_PROXY,
|
||||
]
|
||||
)
|
||||
self.property_mappings.set(scopes)
|
||||
self.property_mappings.add(*list(scopes))
|
||||
self.redirect_uris = _get_callback_url(self.external_host)
|
||||
|
||||
def __str__(self):
|
||||
|
@ -59,11 +59,13 @@ class AuthNRequestParser:
|
||||
) -> AuthNRequest:
|
||||
root = ElementTree.fromstring(decoded_xml)
|
||||
|
||||
# http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf
|
||||
# `AssertionConsumerServiceURL` can be omitted, and we should fallback to the
|
||||
# default ACS URL
|
||||
if "AssertionConsumerServiceURL" not in root.attrib:
|
||||
msg = "Missing 'AssertionConsumerServiceURL' attribute"
|
||||
LOGGER.warning(msg)
|
||||
raise CannotHandleAssertion(msg)
|
||||
request_acs_url = root.attrib["AssertionConsumerServiceURL"]
|
||||
request_acs_url = self.provider.acs_url.lower()
|
||||
else:
|
||||
request_acs_url = root.attrib["AssertionConsumerServiceURL"]
|
||||
|
||||
if self.provider.acs_url.lower() != request_acs_url.lower():
|
||||
msg = (
|
||||
|
30
authentik/recovery/management/commands/create_admin_group.py
Normal file
30
authentik/recovery/management/commands/create_admin_group.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""authentik recovery create_admin_group"""
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Create admin group if the default group gets deleted"""
|
||||
|
||||
help = _("Create admin group if the default group gets deleted.")
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument("user", action="store", help="User to add to the admin group.")
|
||||
|
||||
def handle(self, *args, **options):
|
||||
"""Create admin group if the default group gets deleted"""
|
||||
username = options.get("user")
|
||||
user = User.objects.filter(username=username).first()
|
||||
if not user:
|
||||
self.stderr.write(f"User '{username}' not found.")
|
||||
return
|
||||
group, _ = Group.objects.update_or_create(
|
||||
name="authentik Admins",
|
||||
defaults={
|
||||
"is_superuser": True,
|
||||
},
|
||||
)
|
||||
group.users.add(user)
|
||||
self.stdout.write(f"User '{username}' successfully added to the group 'authentik Admins'.")
|
@ -7,12 +7,9 @@ from django.urls import reverse
|
||||
from django.utils.text import slugify
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Token, TokenIntents, User
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Create Token used to recover access"""
|
||||
|
@ -14,13 +14,13 @@ import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from hashlib import sha512
|
||||
from json import dumps
|
||||
from tempfile import gettempdir
|
||||
from time import time
|
||||
|
||||
import structlog
|
||||
from celery.schedules import crontab
|
||||
from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME
|
||||
from sentry_sdk import init as sentry_init
|
||||
from sentry_sdk.api import set_tag
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
@ -32,6 +32,8 @@ from authentik.core.middleware import structlog_add_request_id
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.logging import add_process_id
|
||||
from authentik.lib.sentry import before_send
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.reflection import get_env
|
||||
from authentik.stages.password import BACKEND_APP_PASSWORD, BACKEND_INBUILT, BACKEND_LDAP
|
||||
|
||||
|
||||
@ -176,6 +178,7 @@ SPECTACULAR_SETTINGS = {
|
||||
"FlowDesignationEnum": "authentik.flows.models.FlowDesignation",
|
||||
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
|
||||
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
|
||||
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
||||
},
|
||||
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
|
||||
"POSTPROCESSING_HOOKS": [
|
||||
@ -307,7 +310,7 @@ EMAIL_HOST = CONFIG.y("email.host")
|
||||
EMAIL_PORT = int(CONFIG.y("email.port"))
|
||||
EMAIL_HOST_USER = CONFIG.y("email.username")
|
||||
EMAIL_HOST_PASSWORD = CONFIG.y("email.password")
|
||||
EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", True)
|
||||
EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", False)
|
||||
EMAIL_USE_SSL = CONFIG.y_bool("email.use_ssl", False)
|
||||
EMAIL_TIMEOUT = int(CONFIG.y("email.timeout"))
|
||||
DEFAULT_FROM_EMAIL = CONFIG.y("email.from")
|
||||
@ -380,7 +383,8 @@ DBBACKUP_CONNECTOR_MAPPING = {
|
||||
"django_prometheus.db.backends.postgresql": "dbbackup.db.postgresql.PgDumpConnector",
|
||||
}
|
||||
DBBACKUP_TMP_DIR = gettempdir() if DEBUG else "/tmp" # nosec
|
||||
if CONFIG.y("postgresql.s3_backup"):
|
||||
DBBACKUP_CLEANUP_KEEP = 30
|
||||
if CONFIG.y("postgresql.s3_backup.bucket", "") != "":
|
||||
DBBACKUP_STORAGE = "storages.backends.s3boto3.S3Boto3Storage"
|
||||
DBBACKUP_STORAGE_OPTIONS = {
|
||||
"access_key": CONFIG.y("postgresql.s3_backup.access_key"),
|
||||
@ -399,6 +403,12 @@ if CONFIG.y("postgresql.s3_backup"):
|
||||
|
||||
# Sentry integration
|
||||
SENTRY_DSN = "https://a579bb09306d4f8b8d8847c052d3a1d3@sentry.beryju.org/8"
|
||||
# Default to empty string as that is what docker has
|
||||
build_hash = os.environ.get(ENV_GIT_HASH_KEY, "")
|
||||
if build_hash == "":
|
||||
build_hash = "tagged"
|
||||
|
||||
env = get_env()
|
||||
_ERROR_REPORTING = CONFIG.y_bool("error_reporting.enabled", False)
|
||||
if _ERROR_REPORTING:
|
||||
# pylint: disable=abstract-class-instantiated
|
||||
@ -415,18 +425,29 @@ if _ERROR_REPORTING:
|
||||
environment=CONFIG.y("error_reporting.environment", "customer"),
|
||||
send_default_pii=CONFIG.y_bool("error_reporting.send_pii", False),
|
||||
)
|
||||
# Default to empty string as that is what docker has
|
||||
build_hash = os.environ.get(ENV_GIT_HASH_KEY, "")
|
||||
if build_hash == "":
|
||||
build_hash = "tagged"
|
||||
set_tag("authentik.build_hash", build_hash)
|
||||
set_tag("authentik.env", "kubernetes" if SERVICE_HOST_ENV_NAME in os.environ else "compose")
|
||||
set_tag("authentik.env", env)
|
||||
set_tag("authentik.component", "backend")
|
||||
j_print(
|
||||
"Error reporting is enabled",
|
||||
env=CONFIG.y("error_reporting.environment", "customer"),
|
||||
)
|
||||
|
||||
if not CONFIG.y_bool("disable_startup_analytics", False):
|
||||
should_send = env not in ["dev", "ci"]
|
||||
if should_send:
|
||||
get_http_session().post(
|
||||
"https://goauthentik.io/api/event",
|
||||
json={
|
||||
"domain": "authentik",
|
||||
"name": "pageview",
|
||||
"url": f"http://localhost/{env}",
|
||||
"referrer": f"{__version__} ({build_hash})",
|
||||
},
|
||||
headers={
|
||||
"User-Agent": sha512(SECRET_KEY.encode("ascii")).hexdigest()[:16],
|
||||
"Content-Type": "text/plain",
|
||||
},
|
||||
)
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/2.1/howto/static-files/
|
||||
|
@ -1,6 +1,4 @@
|
||||
"""LDAP Sync tasks"""
|
||||
from typing import Optional
|
||||
|
||||
from django.utils.text import slugify
|
||||
from ldap3.core.exceptions import LDAPException
|
||||
from structlog.stdlib import get_logger
|
||||
@ -31,8 +29,7 @@ def ldap_sync_all():
|
||||
@CELERY_APP.task(
|
||||
bind=True, base=MonitoredTask, soft_time_limit=60 * 60 * 2, task_time_limit=60 * 60 * 2
|
||||
)
|
||||
# TODO: remove Optional[str] in 2021.10
|
||||
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: Optional[str] = None):
|
||||
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str):
|
||||
"""Synchronization of an LDAP Source"""
|
||||
self.result_timeout_hours = 2
|
||||
try:
|
||||
@ -41,8 +38,6 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: Optional[str] = N
|
||||
# Because the source couldn't be found, we don't have a UID
|
||||
# to set the state with
|
||||
return
|
||||
if not sync_class:
|
||||
return
|
||||
sync = path_to_class(sync_class)
|
||||
self.set_uid(f"{slugify(source.name)}-{sync.__name__}")
|
||||
try:
|
||||
|
@ -12,6 +12,7 @@ class DiscordOAuthRedirect(OAuthRedirect):
|
||||
def get_additional_parameters(self, source): # pragma: no cover
|
||||
return {
|
||||
"scope": "email identify",
|
||||
"prompt": "none",
|
||||
}
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_sche
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ValidationError
|
||||
@ -18,7 +18,7 @@ from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.flows.challenge import RedirectChallenge
|
||||
from authentik.flows.views.executor import to_stage_response
|
||||
from authentik.sources.plex.models import PlexSource
|
||||
from authentik.sources.plex.models import PlexSource, PlexSourceConnection
|
||||
from authentik.sources.plex.plex import PlexAuth, PlexSourceFlowManager
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -98,21 +98,11 @@ class PlexSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
user_info, identifier = auth_api.get_user_info()
|
||||
# Check friendship first, then check server overlay
|
||||
friends_allowed = False
|
||||
owner_id = None
|
||||
if source.allow_friends:
|
||||
owner_api = PlexAuth(source, source.plex_token)
|
||||
owner_id = owner_api.get_user_info
|
||||
owner_friends = owner_api.get_friends()
|
||||
for friend in owner_friends:
|
||||
if int(friend.get("id", "0")) == int(identifier):
|
||||
friends_allowed = True
|
||||
LOGGER.info(
|
||||
"allowing user for plex because of friend",
|
||||
user=user_info["username"],
|
||||
)
|
||||
friends_allowed = owner_api.check_friends_overlap(identifier)
|
||||
servers_allowed = auth_api.check_server_overlap()
|
||||
owner_allowed = owner_id == identifier
|
||||
if any([friends_allowed, servers_allowed, owner_allowed]):
|
||||
if any([friends_allowed, servers_allowed]):
|
||||
sfm = PlexSourceFlowManager(
|
||||
source=source,
|
||||
request=request,
|
||||
@ -125,3 +115,57 @@ class PlexSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
user=user_info["username"],
|
||||
)
|
||||
raise PermissionDenied("Access denied.")
|
||||
|
||||
@extend_schema(
|
||||
request=PlexTokenRedeemSerializer(),
|
||||
responses={
|
||||
204: OpenApiResponse(),
|
||||
400: OpenApiResponse(description="Token not found"),
|
||||
403: OpenApiResponse(description="Access denied"),
|
||||
},
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="slug",
|
||||
location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.STR,
|
||||
)
|
||||
],
|
||||
)
|
||||
@action(
|
||||
methods=["POST"],
|
||||
detail=False,
|
||||
pagination_class=None,
|
||||
filter_backends=[],
|
||||
permission_classes=[IsAuthenticated],
|
||||
)
|
||||
def redeem_token_authenticated(self, request: Request) -> Response:
|
||||
"""Redeem a plex token for an authenticated user, creating a connection"""
|
||||
source: PlexSource = get_object_or_404(
|
||||
PlexSource, slug=request.query_params.get("slug", "")
|
||||
)
|
||||
plex_token = request.data.get("plex_token", None)
|
||||
if not plex_token:
|
||||
raise ValidationError("No plex token given")
|
||||
auth_api = PlexAuth(source, plex_token)
|
||||
user_info, identifier = auth_api.get_user_info()
|
||||
# Check friendship first, then check server overlay
|
||||
friends_allowed = False
|
||||
if source.allow_friends:
|
||||
owner_api = PlexAuth(source, source.plex_token)
|
||||
friends_allowed = owner_api.check_friends_overlap(identifier)
|
||||
servers_allowed = auth_api.check_server_overlap()
|
||||
if any([friends_allowed, servers_allowed]):
|
||||
PlexSourceConnection.objects.create(
|
||||
plex_token=plex_token,
|
||||
user=request.user,
|
||||
identifier=identifier,
|
||||
source=source,
|
||||
)
|
||||
return Response(status=204)
|
||||
LOGGER.warning(
|
||||
"Denying plex connection because no server overlay and no friends and not owner",
|
||||
user=user_info["username"],
|
||||
friends_allowed=friends_allowed,
|
||||
servers_allowed=servers_allowed,
|
||||
)
|
||||
raise PermissionDenied("Access denied.")
|
||||
|
@ -42,3 +42,4 @@ class PlexSourceConnectionViewSet(
|
||||
filterset_fields = ["source__slug"]
|
||||
permission_classes = [OwnerPermissions]
|
||||
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
|
||||
ordering = ["pk"]
|
||||
|
@ -83,6 +83,7 @@ class PlexSource(Source):
|
||||
data={
|
||||
"title": f"Plex {self.name}",
|
||||
"component": "ak-user-settings-source-plex",
|
||||
"configure_url": self.client_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -36,7 +36,7 @@ class PlexAuth:
|
||||
return {
|
||||
"X-Plex-Product": "authentik",
|
||||
"X-Plex-Version": __version__,
|
||||
"X-Plex-Device-Vendor": "BeryJu.org",
|
||||
"X-Plex-Device-Vendor": "goauthentik.io",
|
||||
}
|
||||
|
||||
def get_resources(self) -> list[dict]:
|
||||
@ -96,6 +96,21 @@ class PlexAuth:
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_friends_overlap(self, user_ident: int) -> bool:
|
||||
"""Check if the user is a friend of the owner, or the owner themselves"""
|
||||
friends_allowed = False
|
||||
_, owner_id = self.get_user_info()
|
||||
owner_friends = self.get_friends()
|
||||
for friend in owner_friends:
|
||||
if int(friend.get("id", "0")) == user_ident:
|
||||
friends_allowed = True
|
||||
LOGGER.info(
|
||||
"allowing user for plex because of friend",
|
||||
user=user_ident,
|
||||
)
|
||||
owner_allowed = owner_id == user_ident
|
||||
return any([friends_allowed, owner_allowed])
|
||||
|
||||
|
||||
class PlexSourceFlowManager(SourceFlowManager):
|
||||
"""Flow manager for plex sources"""
|
||||
|
@ -3,6 +3,7 @@ from requests import RequestException
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.sources.plex.models import PlexSource
|
||||
from authentik.sources.plex.plex import PlexAuth
|
||||
@ -31,7 +32,7 @@ def check_plex_token(self: MonitoredTask, source_slug: int):
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.ERROR,
|
||||
["Plex token is invalid/an error occurred:", str(exc)],
|
||||
["Plex token is invalid/an error occurred:", exception_to_string(exc)],
|
||||
)
|
||||
)
|
||||
Event.new(
|
||||
|
@ -0,0 +1,18 @@
|
||||
# Generated by Django 3.2.8 on 2021-10-31 16:44
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_stages_authenticator_sms", "0001_squashed_0004_auto_20211014_0936"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="authenticatorsmsstage",
|
||||
name="from_number",
|
||||
field=models.TextField(),
|
||||
),
|
||||
]
|
@ -101,7 +101,7 @@ class AuthenticatorSMSStageView(ChallengeStageView):
|
||||
stage: AuthenticatorSMSStage = self.executor.current_stage
|
||||
|
||||
if SESSION_SMS_DEVICE not in self.request.session:
|
||||
device = SMSDevice(user=user, confirmed=False, stage=stage)
|
||||
device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device")
|
||||
device.generate_token(commit=False)
|
||||
if phone_number := self._has_phone_number():
|
||||
device.phone_number = phone_number
|
||||
|
@ -55,7 +55,7 @@ class AuthenticatorStaticStageView(ChallengeStageView):
|
||||
stage: AuthenticatorStaticStage = self.executor.current_stage
|
||||
|
||||
if SESSION_STATIC_DEVICE not in self.request.session:
|
||||
device = StaticDevice(user=user, confirmed=True)
|
||||
device = StaticDevice(user=user, confirmed=True, name="Static Token")
|
||||
tokens = []
|
||||
for _ in range(0, stage.token_count):
|
||||
tokens.append(StaticToken(device=device, token=StaticToken.random_token()))
|
||||
|
@ -81,7 +81,9 @@ class AuthenticatorTOTPStageView(ChallengeStageView):
|
||||
stage: AuthenticatorTOTPStage = self.executor.current_stage
|
||||
|
||||
if SESSION_TOTP_DEVICE not in self.request.session:
|
||||
device = TOTPDevice(user=user, confirmed=True, digits=stage.digits)
|
||||
device = TOTPDevice(
|
||||
user=user, confirmed=True, digits=stage.digits, name="TOTP Authenticator"
|
||||
)
|
||||
|
||||
self.request.session[SESSION_TOTP_DEVICE] = device
|
||||
return super().get(request, *args, **kwargs)
|
||||
|
@ -75,6 +75,7 @@ class AuthenticatorValidateStageTests(APITestCase):
|
||||
},
|
||||
"user_fields": ["username"],
|
||||
"sources": [],
|
||||
"show_source_labels": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -136,6 +136,7 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
|
||||
credential_id=bytes_to_base64url(webauthn_credential.credential_id),
|
||||
sign_count=webauthn_credential.sign_count,
|
||||
rp_id=get_rp_id(self.request),
|
||||
name="WebAuthn Device",
|
||||
)
|
||||
else:
|
||||
return self.executor.stage_invalid("Device with Credential ID already exists.")
|
||||
|
@ -20,6 +20,7 @@ class IdentificationStageSerializer(StageSerializer):
|
||||
"enrollment_flow",
|
||||
"recovery_flow",
|
||||
"sources",
|
||||
"show_source_labels",
|
||||
]
|
||||
|
||||
|
||||
@ -35,5 +36,6 @@ class IdentificationStageViewSet(UsedByMixin, ModelViewSet):
|
||||
"show_matched_user",
|
||||
"enrollment_flow",
|
||||
"recovery_flow",
|
||||
"show_source_labels",
|
||||
]
|
||||
ordering = ["name"]
|
||||
|
@ -0,0 +1,18 @@
|
||||
# Generated by Django 3.2.8 on 2021-10-31 16:44
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_stages_identification", "0011_alter_identificationstage_user_fields"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="identificationstage",
|
||||
name="show_source_labels",
|
||||
field=models.BooleanField(default=False),
|
||||
),
|
||||
]
|
@ -81,6 +81,7 @@ class IdentificationStage(Stage):
|
||||
sources = models.ManyToManyField(
|
||||
Source, default=list, help_text=_("Specify which sources should be shown.")
|
||||
)
|
||||
show_source_labels = models.BooleanField(default=False)
|
||||
|
||||
@property
|
||||
def serializer(self) -> BaseSerializer:
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Identification stage logic"""
|
||||
from dataclasses import asdict
|
||||
from random import SystemRandom
|
||||
from time import sleep
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -15,10 +16,16 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.core.models import Application, Source, User
|
||||
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
|
||||
from authentik.flows.challenge import (
|
||||
Challenge,
|
||||
ChallengeResponse,
|
||||
ChallengeTypes,
|
||||
RedirectChallenge,
|
||||
)
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
||||
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, ChallengeStageView
|
||||
from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE, challenge_types
|
||||
from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE
|
||||
from authentik.sources.plex.models import PlexAuthenticationChallenge
|
||||
from authentik.stages.identification.models import IdentificationStage
|
||||
from authentik.stages.identification.signals import identification_failed
|
||||
from authentik.stages.password.stage import authenticate
|
||||
@ -28,8 +35,11 @@ LOGGER = get_logger()
|
||||
|
||||
@extend_schema_field(
|
||||
PolymorphicProxySerializer(
|
||||
component_name="ChallengeTypes",
|
||||
serializers=challenge_types(),
|
||||
component_name="LoginChallengeTypes",
|
||||
serializers={
|
||||
RedirectChallenge().fields["component"].default: RedirectChallenge,
|
||||
PlexAuthenticationChallenge().fields["component"].default: PlexAuthenticationChallenge,
|
||||
},
|
||||
resource_type_field_name="component",
|
||||
)
|
||||
)
|
||||
@ -57,6 +67,7 @@ class IdentificationChallenge(Challenge):
|
||||
recovery_url = CharField(required=False)
|
||||
primary_action = CharField()
|
||||
sources = LoginSourceSerializer(many=True, required=False)
|
||||
show_source_labels = BooleanField()
|
||||
|
||||
component = CharField(default="ak-stage-identification")
|
||||
|
||||
@ -77,7 +88,8 @@ class IdentificationChallengeResponse(ChallengeResponse):
|
||||
|
||||
pre_user = self.stage.get_user(uid_field)
|
||||
if not pre_user:
|
||||
sleep(0.150)
|
||||
# Sleep a random time (between 90 and 210ms) to "prevent" user enumeration attacks
|
||||
sleep(0.30 * SystemRandom().randint(3, 7))
|
||||
LOGGER.debug("invalid_login", identifier=uid_field)
|
||||
identification_failed.send(sender=self, request=self.stage.request, uid_field=uid_field)
|
||||
# We set the pending_user even on failure so it's part of the context, even
|
||||
@ -152,6 +164,7 @@ class IdentificationStageView(ChallengeStageView):
|
||||
"component": "ak-stage-identification",
|
||||
"user_fields": current_stage.user_fields,
|
||||
"password_fields": bool(current_stage.password_stage),
|
||||
"show_source_labels": current_stage.show_source_labels,
|
||||
}
|
||||
)
|
||||
# If the user has been redirected to us whilst trying to access an
|
||||
|
@ -123,6 +123,7 @@ class TestIdentificationStage(APITestCase):
|
||||
"name": "test",
|
||||
}
|
||||
],
|
||||
"show_source_labels": False,
|
||||
"user_fields": ["email"],
|
||||
},
|
||||
)
|
||||
@ -158,6 +159,7 @@ class TestIdentificationStage(APITestCase):
|
||||
{"code": "invalid", "string": "Failed to " "authenticate."}
|
||||
]
|
||||
},
|
||||
"show_source_labels": False,
|
||||
"flow_info": {
|
||||
"background": self.flow.background_url,
|
||||
"cancel_url": reverse("authentik_flows:cancel"),
|
||||
@ -218,6 +220,7 @@ class TestIdentificationStage(APITestCase):
|
||||
"authentik_core:if-flow",
|
||||
kwargs={"flow_slug": "unique-enrollment-string"},
|
||||
),
|
||||
"show_source_labels": False,
|
||||
"primary_action": "Log in",
|
||||
"flow_info": {
|
||||
"background": flow.background_url,
|
||||
@ -267,6 +270,7 @@ class TestIdentificationStage(APITestCase):
|
||||
"authentik_core:if-flow",
|
||||
kwargs={"flow_slug": "unique-recovery-string"},
|
||||
),
|
||||
"show_source_labels": False,
|
||||
"primary_action": "Log in",
|
||||
"flow_info": {
|
||||
"background": flow.background_url,
|
||||
|
@ -1,4 +1,6 @@
|
||||
"""Invitation Stage API Views"""
|
||||
from django_filters.filters import BooleanFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from rest_framework.fields import JSONField
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
@ -21,12 +23,23 @@ class InvitationStageSerializer(StageSerializer):
|
||||
]
|
||||
|
||||
|
||||
class InvitationStageFilter(FilterSet):
|
||||
"""invitation filter"""
|
||||
|
||||
no_flows = BooleanFilter("flow", "isnull")
|
||||
|
||||
class Meta:
|
||||
|
||||
model = InvitationStage
|
||||
fields = ["name", "no_flows", "continue_flow_without_invitation", "stage_uuid"]
|
||||
|
||||
|
||||
class InvitationStageViewSet(UsedByMixin, ModelViewSet):
|
||||
"""InvitationStage Viewset"""
|
||||
|
||||
queryset = InvitationStage.objects.all()
|
||||
serializer_class = InvitationStageSerializer
|
||||
filterset_fields = "__all__"
|
||||
filterset_class = InvitationStageFilter
|
||||
ordering = ["name"]
|
||||
|
||||
|
||||
@ -53,7 +66,7 @@ class InvitationViewSet(UsedByMixin, ModelViewSet):
|
||||
|
||||
queryset = Invitation.objects.all()
|
||||
serializer_class = InvitationSerializer
|
||||
order = ["-expires"]
|
||||
ordering = ["-expires"]
|
||||
search_fields = ["created_by__username", "expires"]
|
||||
filterset_fields = ["created_by__username", "expires"]
|
||||
|
||||
|
@ -4,7 +4,6 @@ from typing import Optional
|
||||
from deepmerge import always_merger
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.http.response import HttpResponseBadRequest
|
||||
from django.shortcuts import get_object_or_404
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.flows.models import in_memory_stage
|
||||
@ -50,7 +49,12 @@ class InvitationStageView(StageView):
|
||||
return self.executor.stage_ok()
|
||||
return self.executor.stage_invalid()
|
||||
|
||||
invite: Invitation = get_object_or_404(Invitation, pk=token)
|
||||
invite: Invitation = Invitation.objects.filter(pk=token).first()
|
||||
if not invite:
|
||||
LOGGER.debug("invalid invitation", token=token)
|
||||
if stage.continue_flow_without_invitation:
|
||||
return self.executor.stage_ok()
|
||||
return self.executor.stage_invalid()
|
||||
self.executor.plan.context[INVITATION_IN_EFFECT] = True
|
||||
self.executor.plan.context[INVITATION] = invite
|
||||
|
||||
@ -79,7 +83,9 @@ class InvitationFinalStageView(StageView):
|
||||
if not invitation:
|
||||
LOGGER.warning("InvitationFinalStageView stage called without invitation")
|
||||
return HttpResponseBadRequest
|
||||
if not invitation.single_use:
|
||||
return self.executor.stage_ok()
|
||||
invitation.delete()
|
||||
token = invitation.invite_uuid.hex
|
||||
if invitation.single_use:
|
||||
invitation.delete()
|
||||
LOGGER.debug("Deleted invitation", token=token)
|
||||
del self.executor.plan.context[INVITATION]
|
||||
return self.executor.stage_ok()
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""prompt models"""
|
||||
from typing import Type
|
||||
from typing import Any, Optional, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from django.db import models
|
||||
@ -13,6 +13,7 @@ from rest_framework.fields import (
|
||||
EmailField,
|
||||
HiddenField,
|
||||
IntegerField,
|
||||
ReadOnlyField,
|
||||
)
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
|
||||
@ -26,6 +27,10 @@ class FieldTypes(models.TextChoices):
|
||||
|
||||
# Simple text field
|
||||
TEXT = "text", _("Text: Simple Text input")
|
||||
# Simple text field
|
||||
TEXT_READ_ONLY = "text_read_only", _(
|
||||
"Text (read-only): Simple Text input, but cannot be edited."
|
||||
)
|
||||
# Same as text, but has autocomplete for password managers
|
||||
USERNAME = (
|
||||
"username",
|
||||
@ -74,13 +79,15 @@ class Prompt(SerializerModel):
|
||||
|
||||
return PromptSerializer
|
||||
|
||||
@property
|
||||
def field(self) -> CharField:
|
||||
def field(self, default: Optional[Any]) -> CharField:
|
||||
"""Get field type for Challenge and response"""
|
||||
field_class = CharField
|
||||
kwargs = {
|
||||
"required": self.required,
|
||||
}
|
||||
|
||||
if self.type == FieldTypes.TEXT_READ_ONLY:
|
||||
field_class = ReadOnlyField
|
||||
if self.type == FieldTypes.EMAIL:
|
||||
field_class = EmailField
|
||||
if self.type == FieldTypes.NUMBER:
|
||||
@ -97,12 +104,14 @@ class Prompt(SerializerModel):
|
||||
if self.type == FieldTypes.DATE_TIME:
|
||||
field_class = DateTimeField
|
||||
if self.type == FieldTypes.STATIC:
|
||||
kwargs["initial"] = self.placeholder
|
||||
kwargs["default"] = self.placeholder
|
||||
kwargs["required"] = False
|
||||
kwargs["label"] = ""
|
||||
if self.type == FieldTypes.SEPARATOR:
|
||||
kwargs["required"] = False
|
||||
kwargs["label"] = ""
|
||||
if default:
|
||||
kwargs["default"] = default
|
||||
return field_class(**kwargs)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
|
@ -8,7 +8,7 @@ from django.http import HttpRequest, HttpResponse
|
||||
from django.http.request import QueryDict
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from rest_framework.fields import BooleanField, CharField, IntegerField
|
||||
from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField
|
||||
from rest_framework.serializers import ValidationError
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
@ -31,7 +31,7 @@ class StagePromptSerializer(PassiveSerializer):
|
||||
|
||||
field_key = CharField()
|
||||
label = CharField(allow_blank=True)
|
||||
type = CharField()
|
||||
type = ChoiceField(choices=FieldTypes.choices)
|
||||
required = BooleanField()
|
||||
placeholder = CharField(allow_blank=True)
|
||||
order = IntegerField()
|
||||
@ -65,7 +65,8 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||
fields = list(self.stage.fields.all())
|
||||
for field in fields:
|
||||
field: Prompt
|
||||
self.fields[field.field_key] = field.field
|
||||
current = plan.context.get(PLAN_CONTEXT_PROMPT, {}).get(field.field_key)
|
||||
self.fields[field.field_key] = field.field(current)
|
||||
# Special handling for fields with username type
|
||||
# these check for existing users with the same username
|
||||
if field.type == FieldTypes.USERNAME:
|
||||
@ -96,10 +97,11 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||
# Check if we have any static or hidden fields, and ensure they
|
||||
# still have the same value
|
||||
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
||||
type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC]
|
||||
type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY]
|
||||
)
|
||||
for static_hidden in static_hidden_fields:
|
||||
attrs[static_hidden.field_key] = static_hidden.placeholder
|
||||
field = self.fields[static_hidden.field_key]
|
||||
attrs[static_hidden.field_key] = field.default
|
||||
|
||||
# Check if we have two password fields, and make sure they are the same
|
||||
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
|
||||
@ -163,10 +165,17 @@ class PromptStageView(ChallengeStageView):
|
||||
|
||||
def get_challenge(self, *args, **kwargs) -> Challenge:
|
||||
fields = list(self.executor.current_stage.fields.all().order_by("order"))
|
||||
serializers = []
|
||||
context_prompt = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
|
||||
for field in fields:
|
||||
data = StagePromptSerializer(field).data
|
||||
if field.field_key in context_prompt:
|
||||
data["placeholder"] = context_prompt.get(field.field_key)
|
||||
serializers.append(data)
|
||||
challenge = PromptChallenge(
|
||||
data={
|
||||
"type": ChallengeTypes.NATIVE.value,
|
||||
"fields": [StagePromptSerializer(field).data for field in fields],
|
||||
"fields": serializers,
|
||||
},
|
||||
)
|
||||
return challenge
|
||||
|
@ -21,31 +21,32 @@ var running = true
|
||||
func main() {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
l := log.WithField("logger", "authentik.root")
|
||||
config.DefaultConfig()
|
||||
err := config.LoadConfig("./authentik/lib/default.yml")
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("failed to load default config")
|
||||
l.WithError(err).Warning("failed to load default config")
|
||||
}
|
||||
err = config.LoadConfig("./local.env.yml")
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("no local config to load")
|
||||
l.WithError(err).Debug("no local config to load")
|
||||
}
|
||||
err = config.FromEnv()
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("failed to environment variables")
|
||||
l.WithError(err).Debug("failed to environment variables")
|
||||
}
|
||||
config.ConfigureLogger()
|
||||
|
||||
if config.G.ErrorReporting.Enabled {
|
||||
err := sentry.Init(sentry.ClientOptions{
|
||||
Dsn: "https://a579bb09306d4f8b8d8847c052d3a1d3@sentry.beryju.org/8",
|
||||
Dsn: config.G.ErrorReporting.DSN,
|
||||
AttachStacktrace: true,
|
||||
TracesSampleRate: 0.6,
|
||||
Release: fmt.Sprintf("authentik@%s", constants.VERSION),
|
||||
Environment: config.G.ErrorReporting.Environment,
|
||||
})
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("failed to init sentry")
|
||||
l.WithError(err).Warning("failed to init sentry")
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,31 +57,31 @@ func main() {
|
||||
|
||||
g := gounicorn.NewGoUnicorn()
|
||||
ws := web.NewWebServer(g)
|
||||
defer g.Kill()
|
||||
defer ws.Shutdown()
|
||||
g.HealthyCallback = func() {
|
||||
if !config.G.Web.DisableEmbeddedOutpost {
|
||||
go attemptProxyStart(ws, u)
|
||||
}
|
||||
}
|
||||
go web.RunMetricsServer()
|
||||
for {
|
||||
go attemptStartBackend(g)
|
||||
ws.Start()
|
||||
if !config.G.Web.DisableEmbeddedOutpost {
|
||||
go attemptProxyStart(ws, u)
|
||||
}
|
||||
|
||||
<-ex
|
||||
running = false
|
||||
log.WithField("logger", "authentik").Info("shutting down webserver")
|
||||
l.WithField("logger", "authentik").Info("shutting down gunicorn")
|
||||
go g.Kill()
|
||||
l.WithField("logger", "authentik").Info("shutting down webserver")
|
||||
go ws.Shutdown()
|
||||
log.WithField("logger", "authentik").Info("killing gunicorn")
|
||||
g.Kill()
|
||||
}
|
||||
}
|
||||
|
||||
func attemptStartBackend(g *gounicorn.GoUnicorn) {
|
||||
for {
|
||||
err := g.Start()
|
||||
if !running {
|
||||
return
|
||||
}
|
||||
err := g.Start()
|
||||
log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting")
|
||||
}
|
||||
}
|
||||
@ -88,8 +89,6 @@ func attemptStartBackend(g *gounicorn.GoUnicorn) {
|
||||
func attemptProxyStart(ws *web.WebServer, u *url.URL) {
|
||||
maxTries := 100
|
||||
attempt := 0
|
||||
// Sleep to wait for the app server to start
|
||||
time.Sleep(30 * time.Second)
|
||||
for {
|
||||
log.WithField("logger", "authentik").Debug("attempting to init outpost")
|
||||
ac := ak.NewAPIController(*u, config.G.SecretKey)
|
||||
|
@ -17,7 +17,7 @@ services:
|
||||
image: redis:alpine
|
||||
restart: unless-stopped
|
||||
server:
|
||||
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.10.1}
|
||||
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.10.3}
|
||||
restart: unless-stopped
|
||||
command: server
|
||||
environment:
|
||||
@ -38,7 +38,7 @@ services:
|
||||
- "0.0.0.0:9000:9000"
|
||||
- "0.0.0.0:9443:9443"
|
||||
worker:
|
||||
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.10.1}
|
||||
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.10.3}
|
||||
restart: unless-stopped
|
||||
command: worker
|
||||
environment:
|
||||
|
2
go.mod
2
go.mod
@ -29,7 +29,7 @@ require (
|
||||
github.com/prometheus/client_golang v1.11.0
|
||||
github.com/recws-org/recws v1.3.1
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
goauthentik.io/api v0.2021101.5
|
||||
goauthentik.io/api v0.2021102.6
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect
|
||||
golang.org/x/net v0.0.0-20210510120150-4163338589ed // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20210323180902-22b0adad7558
|
||||
|
4
go.sum
4
go.sum
@ -559,8 +559,8 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
goauthentik.io/api v0.2021101.5 h1:OLaI37B+2GrjvIuK7XcacCsV9LSyKJft2e7ghTVqaJ0=
|
||||
goauthentik.io/api v0.2021101.5/go.mod h1:02nnD4FRd8lu8A1+ZuzqownBgvAhdCKzqkKX8v7JMTE=
|
||||
goauthentik.io/api v0.2021102.6 h1:jDah4AH28snsmFl9RwRMKrQ1ayW4zXrQWinryQdlJwA=
|
||||
goauthentik.io/api v0.2021102.6/go.mod h1:02nnD4FRd8lu8A1+ZuzqownBgvAhdCKzqkKX8v7JMTE=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
|
@ -26,6 +26,7 @@ func DefaultConfig() {
|
||||
LogLevel: "info",
|
||||
ErrorReporting: ErrorReportingConfig{
|
||||
Enabled: false,
|
||||
DSN: "https://a579bb09306d4f8b8d8847c052d3a1d3@sentry.beryju.org/8",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -42,4 +42,5 @@ type ErrorReportingConfig struct {
|
||||
Enabled bool `yaml:"enabled" env:"AUTHENTIK_ERROR_REPORTING__ENABLED"`
|
||||
Environment string `yaml:"environment" env:"AUTHENTIK_ERROR_REPORTING__ENVIRONMENT"`
|
||||
SendPII bool `yaml:"send_pii" env:"AUTHENTIK_ERROR_REPORTING__SEND_PII"`
|
||||
DSN string
|
||||
}
|
||||
|
@ -17,4 +17,4 @@ func OutpostUserAgent() string {
|
||||
return fmt.Sprintf("authentik-outpost@%s (build=%s)", VERSION, BUILD())
|
||||
}
|
||||
|
||||
const VERSION = "2021.10.1"
|
||||
const VERSION = "2021.10.3"
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -11,6 +13,8 @@ import (
|
||||
)
|
||||
|
||||
type GoUnicorn struct {
|
||||
HealthyCallback func()
|
||||
|
||||
log *log.Entry
|
||||
p *exec.Cmd
|
||||
started bool
|
||||
@ -21,10 +25,11 @@ type GoUnicorn struct {
|
||||
func NewGoUnicorn() *GoUnicorn {
|
||||
logger := log.WithField("logger", "authentik.router.unicorn")
|
||||
g := &GoUnicorn{
|
||||
log: logger,
|
||||
started: false,
|
||||
killed: false,
|
||||
alive: false,
|
||||
log: logger,
|
||||
started: false,
|
||||
killed: false,
|
||||
alive: false,
|
||||
HealthyCallback: func() {},
|
||||
}
|
||||
g.initCmd()
|
||||
return g
|
||||
@ -46,7 +51,7 @@ func (g *GoUnicorn) IsRunning() bool {
|
||||
|
||||
func (g *GoUnicorn) Start() error {
|
||||
if g.killed {
|
||||
g.log.Debug("Not restarting gunicorn since we're killed")
|
||||
g.log.Debug("Not restarting gunicorn since we're shutdown")
|
||||
return nil
|
||||
}
|
||||
if g.started {
|
||||
@ -76,6 +81,7 @@ func (g *GoUnicorn) healthcheck() {
|
||||
for range time.Tick(time.Second) {
|
||||
if check() {
|
||||
g.log.Info("backend is alive, backing off with healthchecks")
|
||||
g.HealthyCallback()
|
||||
break
|
||||
}
|
||||
g.log.Debug("backend not alive yet")
|
||||
@ -87,8 +93,15 @@ func (g *GoUnicorn) healthcheck() {
|
||||
|
||||
func (g *GoUnicorn) Kill() {
|
||||
g.killed = true
|
||||
err := g.p.Process.Kill()
|
||||
var err error
|
||||
if runtime.GOOS == "darwin" {
|
||||
g.log.WithField("method", "kill").Warning("stopping gunicorn")
|
||||
err = g.p.Process.Kill()
|
||||
} else {
|
||||
g.log.WithField("method", "sigterm").Warning("stopping gunicorn")
|
||||
err = syscall.Kill(g.p.Process.Pid, syscall.SIGTERM)
|
||||
}
|
||||
if err != nil {
|
||||
g.log.WithError(err).Warning("failed to kill gunicorn")
|
||||
g.log.WithError(err).Warning("failed to stop gunicorn")
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
@ -18,7 +17,6 @@ import (
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/constants"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/utils"
|
||||
)
|
||||
|
||||
type StageComponent string
|
||||
@ -103,8 +101,8 @@ type ChallengeInt interface {
|
||||
GetResponseErrors() map[string][]api.ErrorDetail
|
||||
}
|
||||
|
||||
func (fe *FlowExecutor) DelegateClientIP(a net.Addr) {
|
||||
fe.cip = utils.GetIP(a)
|
||||
func (fe *FlowExecutor) DelegateClientIP(a string) {
|
||||
fe.cip = a
|
||||
fe.api.GetConfig().AddDefaultHeader(HeaderAuthentikRemoteIP, fe.cip)
|
||||
}
|
||||
|
||||
|
@ -1,23 +0,0 @@
|
||||
package ldap
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if len(ls.providers) == 1 {
|
||||
if ls.providers[0].cert != nil {
|
||||
ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert")
|
||||
return ls.providers[0].cert, nil
|
||||
}
|
||||
}
|
||||
for _, provider := range ls.providers {
|
||||
if provider.tlsServerName == &info.ServerName {
|
||||
if provider.cert == nil {
|
||||
ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
|
||||
return ls.defaultCert, nil
|
||||
}
|
||||
return provider.cert, nil
|
||||
}
|
||||
}
|
||||
ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert")
|
||||
return ls.defaultCert, nil
|
||||
}
|
@ -1,44 +1,18 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nmcclain/ldap"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ldap/bind"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/utils"
|
||||
)
|
||||
|
||||
type BindRequest struct {
|
||||
BindDN string
|
||||
BindPW string
|
||||
id string
|
||||
conn net.Conn
|
||||
log *log.Entry
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LDAPResultCode, error) {
|
||||
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind",
|
||||
sentry.TransactionName("authentik.providers.ldap.bind"))
|
||||
rid := uuid.New().String()
|
||||
span.SetTag("request_uid", rid)
|
||||
span.SetTag("user.username", bindDN)
|
||||
req, span := bind.NewRequest(bindDN, bindPW, conn)
|
||||
|
||||
bindDN = strings.ToLower(bindDN)
|
||||
req := BindRequest{
|
||||
BindDN: bindDN,
|
||||
BindPW: bindPW,
|
||||
conn: conn,
|
||||
log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())),
|
||||
id: rid,
|
||||
ctx: span.Context(),
|
||||
}
|
||||
defer func() {
|
||||
span.Finish()
|
||||
metrics.Requests.With(prometheus.Labels{
|
||||
@ -46,19 +20,19 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
|
||||
"type": "bind",
|
||||
"filter": "",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"client": req.RemoteAddr(),
|
||||
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
|
||||
req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request")
|
||||
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request")
|
||||
}()
|
||||
for _, instance := range ls.providers {
|
||||
username, err := instance.getUsername(bindDN)
|
||||
username, err := instance.binder.GetUsername(bindDN)
|
||||
if err == nil {
|
||||
return instance.Bind(username, req)
|
||||
return instance.binder.Bind(username, req)
|
||||
} else {
|
||||
req.log.WithError(err).Debug("Username not for instance")
|
||||
req.Log().WithError(err).Debug("Username not for instance")
|
||||
}
|
||||
}
|
||||
req.log.WithField("request", "bind").Warning("No provider found for request")
|
||||
req.Log().WithField("request", "bind").Warning("No provider found for request")
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ls.ac.Outpost.Name,
|
||||
"type": "bind",
|
||||
@ -68,10 +42,3 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
|
||||
}).Inc()
|
||||
return ldap.LDAPResultOperationsError, nil
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) TimerFlowCacheExpiry() {
|
||||
for _, p := range ls.providers {
|
||||
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
|
||||
p.TimerFlowCacheExpiry()
|
||||
}
|
||||
}
|
||||
|
9
internal/outpost/ldap/bind/binder.go
Normal file
9
internal/outpost/ldap/bind/binder.go
Normal file
@ -0,0 +1,9 @@
|
||||
package bind
|
||||
|
||||
import "github.com/nmcclain/ldap"
|
||||
|
||||
type Binder interface {
|
||||
GetUsername(string) (string, error)
|
||||
Bind(username string, req *Request) (ldap.LDAPResultCode, error)
|
||||
TimerFlowCacheExpiry()
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package ldap
|
||||
package direct
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -12,14 +12,30 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost"
|
||||
"goauthentik.io/internal/outpost/ldap/bind"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/utils"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
)
|
||||
|
||||
const ContextUserKey = "ak_user"
|
||||
|
||||
func (pi *ProviderInstance) getUsername(dn string) (string, error) {
|
||||
if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(pi.BaseDN)) {
|
||||
type DirectBinder struct {
|
||||
si server.LDAPServerInstance
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
func NewDirectBinder(si server.LDAPServerInstance) *DirectBinder {
|
||||
db := &DirectBinder{
|
||||
si: si,
|
||||
log: log.WithField("logger", "authentik.outpost.ldap.binder.direct"),
|
||||
}
|
||||
db.log.Info("initialised direct binder")
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DirectBinder) GetUsername(dn string) (string, error) {
|
||||
if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(db.si.GetBaseDN())) {
|
||||
return "", errors.New("invalid base DN")
|
||||
}
|
||||
dns, err := goldap.ParseDN(dn)
|
||||
@ -36,13 +52,13 @@ func (pi *ProviderInstance) getUsername(dn string) (string, error) {
|
||||
return "", errors.New("failed to find cn")
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPResultCode, error) {
|
||||
fe := outpost.NewFlowExecutor(req.ctx, pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{
|
||||
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
|
||||
fe := outpost.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
|
||||
"bindDN": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"requestId": req.id,
|
||||
"client": req.RemoteAddr(),
|
||||
"requestId": req.ID(),
|
||||
})
|
||||
fe.DelegateClientIP(req.conn.RemoteAddr())
|
||||
fe.DelegateClientIP(req.RemoteAddr())
|
||||
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
|
||||
|
||||
fe.Answers[outpost.StageIdentification] = username
|
||||
@ -51,83 +67,82 @@ func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPRes
|
||||
passed, err := fe.Execute()
|
||||
if !passed {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"outpost_name": db.si.GetOutpostName(),
|
||||
"type": "bind",
|
||||
"reason": "invalid_credentials",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.LDAPResultInvalidCredentials, nil
|
||||
}
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"outpost_name": db.si.GetOutpostName(),
|
||||
"type": "bind",
|
||||
"reason": "flow_error",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
req.log.WithError(err).Warning("failed to execute flow")
|
||||
req.Log().WithError(err).Warning("failed to execute flow")
|
||||
return ldap.LDAPResultOperationsError, nil
|
||||
}
|
||||
|
||||
access, err := fe.CheckApplicationAccess(pi.appSlug)
|
||||
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
|
||||
if !access {
|
||||
req.log.Info("Access denied for user")
|
||||
req.Log().Info("Access denied for user")
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"outpost_name": db.si.GetOutpostName(),
|
||||
"type": "bind",
|
||||
"reason": "access_denied",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.LDAPResultInsufficientAccessRights, nil
|
||||
}
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"outpost_name": db.si.GetOutpostName(),
|
||||
"type": "bind",
|
||||
"reason": "access_check_fail",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
req.log.WithError(err).Warning("failed to check access")
|
||||
req.Log().WithError(err).Warning("failed to check access")
|
||||
return ldap.LDAPResultOperationsError, nil
|
||||
}
|
||||
req.log.Info("User has access")
|
||||
uisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.bind.user_info")
|
||||
req.Log().Info("User has access")
|
||||
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
|
||||
// Get user info to store in context
|
||||
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"outpost_name": db.si.GetOutpostName(),
|
||||
"type": "bind",
|
||||
"reason": "user_info_fail",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
req.log.WithError(err).Warning("failed to get user info")
|
||||
req.Log().WithError(err).Warning("failed to get user info")
|
||||
return ldap.LDAPResultOperationsError, nil
|
||||
}
|
||||
pi.boundUsersMutex.Lock()
|
||||
cs := pi.SearchAccessCheck(userInfo.User)
|
||||
pi.boundUsers[req.BindDN] = UserFlags{
|
||||
cs := db.SearchAccessCheck(userInfo.User)
|
||||
flags := flags.UserFlags{
|
||||
UserPk: userInfo.User.Pk,
|
||||
CanSearch: cs != nil,
|
||||
}
|
||||
if pi.boundUsers[req.BindDN].CanSearch {
|
||||
req.log.WithField("group", cs).Info("Allowed access to search")
|
||||
db.si.SetFlags(req.BindDN, flags)
|
||||
if flags.CanSearch {
|
||||
req.Log().WithField("group", cs).Info("Allowed access to search")
|
||||
}
|
||||
uisp.Finish()
|
||||
defer pi.boundUsersMutex.Unlock()
|
||||
return ldap.LDAPResultSuccess, nil
|
||||
}
|
||||
|
||||
// SearchAccessCheck Check if the current user is allowed to search
|
||||
func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string {
|
||||
func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
|
||||
for _, group := range user.Groups {
|
||||
for _, allowedGroup := range pi.searchAllowedGroups {
|
||||
pi.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access")
|
||||
for _, allowedGroup := range db.si.GetSearchAllowedGroups() {
|
||||
db.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access")
|
||||
if group.Pk == allowedGroup.String() {
|
||||
return &group.Name
|
||||
}
|
||||
@ -136,13 +151,13 @@ func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) TimerFlowCacheExpiry() {
|
||||
fe := outpost.NewFlowExecutor(context.Background(), pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{})
|
||||
func (db *DirectBinder) TimerFlowCacheExpiry() {
|
||||
fe := outpost.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
|
||||
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
|
||||
fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true")
|
||||
|
||||
err := fe.WarmUp()
|
||||
if err != nil {
|
||||
pi.log.WithError(err).Warning("failed to warm up flow cache")
|
||||
db.log.WithError(err).Warning("failed to warm up flow cache")
|
||||
}
|
||||
}
|
55
internal/outpost/ldap/bind/request.go
Normal file
55
internal/outpost/ldap/bind/request.go
Normal file
@ -0,0 +1,55 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/utils"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
BindDN string
|
||||
BindPW string
|
||||
id string
|
||||
conn net.Conn
|
||||
log *log.Entry
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewRequest(bindDN string, bindPW string, conn net.Conn) (*Request, *sentry.Span) {
|
||||
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind",
|
||||
sentry.TransactionName("authentik.providers.ldap.bind"))
|
||||
rid := uuid.New().String()
|
||||
span.SetTag("request_uid", rid)
|
||||
span.SetTag("user.username", bindDN)
|
||||
|
||||
bindDN = strings.ToLower(bindDN)
|
||||
return &Request{
|
||||
BindDN: bindDN,
|
||||
BindPW: bindPW,
|
||||
conn: conn,
|
||||
log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())),
|
||||
id: rid,
|
||||
ctx: span.Context(),
|
||||
}, span
|
||||
}
|
||||
|
||||
func (r *Request) Context() context.Context {
|
||||
return r.ctx
|
||||
}
|
||||
|
||||
func (r *Request) Log() *log.Entry {
|
||||
return r.log
|
||||
}
|
||||
|
||||
func (r *Request) RemoteAddr() string {
|
||||
return utils.GetIP(r.conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (r *Request) ID() string {
|
||||
return r.id
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (ls *LDAPServer) Close(boundDN string, conn net.Conn) error {
|
||||
for _, p := range ls.providers {
|
||||
p.delayDeleteUserInfo(boundDN)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) delayDeleteUserInfo(dn string) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
quit := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
pi.boundUsersMutex.Lock()
|
||||
delete(pi.boundUsers, dn)
|
||||
pi.boundUsersMutex.Unlock()
|
||||
close(quit)
|
||||
case <-quit:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
21
internal/outpost/ldap/constants/constants.go
Normal file
21
internal/outpost/ldap/constants/constants.go
Normal file
@ -0,0 +1,21 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
OCGroup = "group"
|
||||
OCGroupOfUniqueNames = "groupOfUniqueNames"
|
||||
OCAKGroup = "goauthentik.io/ldap/group"
|
||||
OCAKVirtualGroup = "goauthentik.io/ldap/virtual-group"
|
||||
)
|
||||
|
||||
const (
|
||||
OCUser = "user"
|
||||
OCOrgPerson = "organizationalPerson"
|
||||
OCInetOrgPerson = "inetOrgPerson"
|
||||
OCAKUser = "goauthentik.io/ldap/user"
|
||||
)
|
||||
|
||||
const (
|
||||
OUUsers = "users"
|
||||
OUGroups = "groups"
|
||||
OUVirtualGroups = "virtual-groups"
|
||||
)
|
33
internal/outpost/ldap/entries.go
Normal file
33
internal/outpost/ldap/entries.go
Normal file
@ -0,0 +1,33 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"github.com/nmcclain/ldap"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/utils"
|
||||
)
|
||||
|
||||
func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
|
||||
dn := pi.GetUserDN(u.Username)
|
||||
attrs := utils.AKAttrsToLDAP(u.Attributes)
|
||||
|
||||
attrs = utils.EnsureAttributes(attrs, map[string][]string{
|
||||
"memberOf": pi.GroupsForUser(u),
|
||||
// Old fields for backwards compatibility
|
||||
"accountStatus": {utils.BoolToString(*u.IsActive)},
|
||||
"superuser": {utils.BoolToString(u.IsSuperuser)},
|
||||
// End old fields
|
||||
"goauthentik.io/ldap/active": {utils.BoolToString(*u.IsActive)},
|
||||
"goauthentik.io/ldap/superuser": {utils.BoolToString(u.IsSuperuser)},
|
||||
"cn": {u.Username},
|
||||
"sAMAccountName": {u.Username},
|
||||
"uid": {u.Uid},
|
||||
"name": {u.Name},
|
||||
"displayName": {u.Name},
|
||||
"mail": {*u.Email},
|
||||
"objectClass": {constants.OCUser, constants.OCOrgPerson, constants.OCInetOrgPerson, constants.OCAKUser},
|
||||
"uidNumber": {pi.GetUidNumber(u)},
|
||||
"gidNumber": {pi.GetUidNumber(u)},
|
||||
})
|
||||
return &ldap.Entry{DN: dn, Attributes: attrs}
|
||||
}
|
9
internal/outpost/ldap/flags/flags.go
Normal file
9
internal/outpost/ldap/flags/flags.go
Normal file
@ -0,0 +1,9 @@
|
||||
package flags
|
||||
|
||||
import "goauthentik.io/api"
|
||||
|
||||
type UserFlags struct {
|
||||
UserInfo *api.User
|
||||
UserPk int32
|
||||
CanSearch bool
|
||||
}
|
66
internal/outpost/ldap/group/group.go
Normal file
66
internal/outpost/ldap/group/group.go
Normal file
@ -0,0 +1,66 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"github.com/nmcclain/ldap"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
"goauthentik.io/internal/outpost/ldap/utils"
|
||||
)
|
||||
|
||||
type LDAPGroup struct {
|
||||
DN string
|
||||
CN string
|
||||
Uid string
|
||||
GidNumber string
|
||||
Member []string
|
||||
IsSuperuser bool
|
||||
IsVirtualGroup bool
|
||||
AKAttributes interface{}
|
||||
}
|
||||
|
||||
func (lg *LDAPGroup) Entry() *ldap.Entry {
|
||||
attrs := utils.AKAttrsToLDAP(lg.AKAttributes)
|
||||
|
||||
objectClass := []string{constants.OCGroup, constants.OCGroupOfUniqueNames, constants.OCAKGroup}
|
||||
if lg.IsVirtualGroup {
|
||||
objectClass = append(objectClass, constants.OCAKVirtualGroup)
|
||||
}
|
||||
|
||||
attrs = utils.EnsureAttributes(attrs, map[string][]string{
|
||||
"objectClass": objectClass,
|
||||
"member": lg.Member,
|
||||
"goauthentik.io/ldap/superuser": {utils.BoolToString(lg.IsSuperuser)},
|
||||
"cn": {lg.CN},
|
||||
"uid": {lg.Uid},
|
||||
"sAMAccountName": {lg.CN},
|
||||
"gidNumber": {lg.GidNumber},
|
||||
})
|
||||
return &ldap.Entry{DN: lg.DN, Attributes: attrs}
|
||||
}
|
||||
|
||||
func FromAPIGroup(g api.Group, si server.LDAPServerInstance) *LDAPGroup {
|
||||
return &LDAPGroup{
|
||||
DN: si.GetGroupDN(g.Name),
|
||||
CN: g.Name,
|
||||
Uid: string(g.Pk),
|
||||
GidNumber: si.GetGidNumber(g),
|
||||
Member: si.UsersForGroup(g),
|
||||
IsVirtualGroup: false,
|
||||
IsSuperuser: *g.IsSuperuser,
|
||||
AKAttributes: g.Attributes,
|
||||
}
|
||||
}
|
||||
|
||||
func FromAPIUser(u api.User, si server.LDAPServerInstance) *LDAPGroup {
|
||||
return &LDAPGroup{
|
||||
DN: si.GetVirtualGroupDN(u.Username),
|
||||
CN: u.Username,
|
||||
Uid: u.Uid,
|
||||
GidNumber: si.GetUidNumber(u),
|
||||
Member: []string{si.GetUserDN(u.Username)},
|
||||
IsVirtualGroup: true,
|
||||
IsSuperuser: false,
|
||||
AKAttributes: nil,
|
||||
}
|
||||
}
|
4
internal/outpost/ldap/handler/handler.go
Normal file
4
internal/outpost/ldap/handler/handler.go
Normal file
@ -0,0 +1,4 @@
|
||||
package handler
|
||||
|
||||
type Handler interface {
|
||||
}
|
83
internal/outpost/ldap/instance.go
Normal file
83
internal/outpost/ldap/instance.go
Normal file
@ -0,0 +1,83 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/bind"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
)
|
||||
|
||||
type ProviderInstance struct {
|
||||
BaseDN string
|
||||
UserDN string
|
||||
VirtualGroupDN string
|
||||
GroupDN string
|
||||
|
||||
searcher search.Searcher
|
||||
binder bind.Binder
|
||||
|
||||
appSlug string
|
||||
flowSlug string
|
||||
s *LDAPServer
|
||||
log *log.Entry
|
||||
|
||||
tlsServerName *string
|
||||
cert *tls.Certificate
|
||||
outpostName string
|
||||
searchAllowedGroups []*strfmt.UUID
|
||||
boundUsersMutex sync.RWMutex
|
||||
boundUsers map[string]flags.UserFlags
|
||||
|
||||
uidStartNumber int32
|
||||
gidStartNumber int32
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetAPIClient() *api.APIClient {
|
||||
return pi.s.ac.Client
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetBaseDN() string {
|
||||
return pi.BaseDN
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetBaseGroupDN() string {
|
||||
return pi.GroupDN
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetBaseUserDN() string {
|
||||
return pi.UserDN
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetOutpostName() string {
|
||||
return pi.outpostName
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetFlags(dn string) (flags.UserFlags, bool) {
|
||||
pi.boundUsersMutex.RLock()
|
||||
flags, ok := pi.boundUsers[dn]
|
||||
pi.boundUsersMutex.RUnlock()
|
||||
return flags, ok
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) {
|
||||
pi.boundUsersMutex.Lock()
|
||||
pi.boundUsers[dn] = flag
|
||||
pi.boundUsersMutex.Unlock()
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetAppSlug() string {
|
||||
return pi.appSlug
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetFlowSlug() string {
|
||||
return pi.flowSlug
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {
|
||||
return pi.searchAllowedGroups
|
||||
}
|
@ -1,244 +0,0 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/nmcclain/ldap"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/utils"
|
||||
)
|
||||
|
||||
func (pi *ProviderInstance) SearchMe(req SearchRequest, f UserFlags) (ldap.ServerSearchResult, error) {
|
||||
if f.UserInfo == nil {
|
||||
u, _, err := pi.s.ac.Client.CoreApi.CoreUsersRetrieve(req.ctx, f.UserPk).Execute()
|
||||
if err != nil {
|
||||
req.log.WithError(err).Warning("Failed to get user info")
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
|
||||
}
|
||||
f.UserInfo = &u
|
||||
}
|
||||
entries := make([]*ldap.Entry, 1)
|
||||
entries[0] = pi.UserEntry(*f.UserInfo)
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, error) {
|
||||
accsp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.check_access")
|
||||
baseDN := strings.ToLower("," + pi.BaseDN)
|
||||
|
||||
entries := []*ldap.Entry{}
|
||||
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"type": "search",
|
||||
"reason": "filter_parse_fail",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
|
||||
}
|
||||
if len(req.BindDN) < 1 {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"type": "search",
|
||||
"reason": "empty_bind_dn",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
|
||||
}
|
||||
if !strings.HasSuffix(req.BindDN, baseDN) {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"type": "search",
|
||||
"reason": "invalid_bind_dn",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, pi.BaseDN)
|
||||
}
|
||||
|
||||
pi.boundUsersMutex.RLock()
|
||||
flags, ok := pi.boundUsers[req.BindDN]
|
||||
pi.boundUsersMutex.RUnlock()
|
||||
if !ok {
|
||||
pi.log.Debug("User info not cached")
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"type": "search",
|
||||
"reason": "user_info_not_cached",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
|
||||
}
|
||||
|
||||
if req.SearchRequest.Scope == ldap.ScopeBaseObject {
|
||||
pi.log.Debug("base scope, showing domain info")
|
||||
return pi.SearchBase(req, flags.CanSearch)
|
||||
}
|
||||
if !flags.CanSearch {
|
||||
pi.log.Debug("User can't search, showing info about user")
|
||||
return pi.SearchMe(req, flags)
|
||||
}
|
||||
accsp.Finish()
|
||||
|
||||
parsedFilter, err := ldap.CompileFilter(req.Filter)
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"type": "search",
|
||||
"reason": "filter_parse_fail",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
|
||||
}
|
||||
|
||||
// Create a custom client to set additional headers
|
||||
c := api.NewAPIClient(pi.s.ac.Client.GetConfig())
|
||||
c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)
|
||||
|
||||
switch filterEntity {
|
||||
default:
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": pi.outpostName,
|
||||
"type": "search",
|
||||
"reason": "unhandled_filter_type",
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
|
||||
case "groupOfUniqueNames":
|
||||
fallthrough
|
||||
case "goauthentik.io/ldap/group":
|
||||
fallthrough
|
||||
case "goauthentik.io/ldap/virtual-group":
|
||||
fallthrough
|
||||
case GroupObjectClass:
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
gEntries := make([]*ldap.Entry, 0)
|
||||
uEntries := make([]*ldap.Entry, 0)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
gapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_group")
|
||||
searchReq, skip := parseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false)
|
||||
if skip {
|
||||
pi.log.Trace("Skip backend request")
|
||||
return
|
||||
}
|
||||
groups, _, err := searchReq.Execute()
|
||||
gapisp.Finish()
|
||||
if err != nil {
|
||||
req.log.WithError(err).Warning("failed to get groups")
|
||||
return
|
||||
}
|
||||
pi.log.WithField("count", len(groups.Results)).Trace("Got results from API")
|
||||
|
||||
for _, g := range groups.Results {
|
||||
gEntries = append(gEntries, pi.GroupEntry(pi.APIGroupToLDAPGroup(g)))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user")
|
||||
searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
|
||||
if skip {
|
||||
pi.log.Trace("Skip backend request")
|
||||
return
|
||||
}
|
||||
users, _, err := searchReq.Execute()
|
||||
uapisp.Finish()
|
||||
if err != nil {
|
||||
req.log.WithError(err).Warning("failed to get users")
|
||||
return
|
||||
}
|
||||
|
||||
for _, u := range users.Results {
|
||||
uEntries = append(uEntries, pi.GroupEntry(pi.APIUserToLDAPGroup(u)))
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
entries = append(gEntries, uEntries...)
|
||||
case "":
|
||||
fallthrough
|
||||
case "organizationalPerson":
|
||||
fallthrough
|
||||
case "inetOrgPerson":
|
||||
fallthrough
|
||||
case "goauthentik.io/ldap/user":
|
||||
fallthrough
|
||||
case UserObjectClass:
|
||||
uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user")
|
||||
searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
|
||||
if skip {
|
||||
pi.log.Trace("Skip backend request")
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
||||
users, _, err := searchReq.Execute()
|
||||
uapisp.Finish()
|
||||
|
||||
if err != nil {
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err)
|
||||
}
|
||||
for _, u := range users.Results {
|
||||
entries = append(entries, pi.UserEntry(u))
|
||||
}
|
||||
}
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
|
||||
dn := pi.GetUserDN(u.Username)
|
||||
attrs := AKAttrsToLDAP(u.Attributes)
|
||||
|
||||
attrs = pi.ensureAttributes(attrs, map[string][]string{
|
||||
"memberOf": pi.GroupsForUser(u),
|
||||
// Old fields for backwards compatibility
|
||||
"accountStatus": {BoolToString(*u.IsActive)},
|
||||
"superuser": {BoolToString(u.IsSuperuser)},
|
||||
"goauthentik.io/ldap/active": {BoolToString(*u.IsActive)},
|
||||
"goauthentik.io/ldap/superuser": {BoolToString(u.IsSuperuser)},
|
||||
"cn": {u.Username},
|
||||
"sAMAccountName": {u.Username},
|
||||
"uid": {u.Uid},
|
||||
"name": {u.Name},
|
||||
"displayName": {u.Name},
|
||||
"mail": {*u.Email},
|
||||
"objectClass": {UserObjectClass, "organizationalPerson", "inetOrgPerson", "goauthentik.io/ldap/user"},
|
||||
"uidNumber": {pi.GetUidNumber(u)},
|
||||
"gidNumber": {pi.GetUidNumber(u)},
|
||||
})
|
||||
return &ldap.Entry{DN: dn, Attributes: attrs}
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GroupEntry(g LDAPGroup) *ldap.Entry {
|
||||
attrs := AKAttrsToLDAP(g.akAttributes)
|
||||
|
||||
objectClass := []string{GroupObjectClass, "groupOfUniqueNames", "goauthentik.io/ldap/group"}
|
||||
if g.isVirtualGroup {
|
||||
objectClass = append(objectClass, "goauthentik.io/ldap/virtual-group")
|
||||
}
|
||||
|
||||
attrs = pi.ensureAttributes(attrs, map[string][]string{
|
||||
"objectClass": objectClass,
|
||||
"member": g.member,
|
||||
"goauthentik.io/ldap/superuser": {BoolToString(g.isSuperuser)},
|
||||
"cn": {g.cn},
|
||||
"uid": {g.uid},
|
||||
"sAMAccountName": {g.cn},
|
||||
"gidNumber": {g.gidNumber},
|
||||
})
|
||||
return &ldap.Entry{DN: g.dn, Attributes: attrs}
|
||||
}
|
@ -2,50 +2,18 @@ package ldap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/pires/go-proxyproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/crypto"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
|
||||
"github.com/nmcclain/ldap"
|
||||
)
|
||||
|
||||
const GroupObjectClass = "group"
|
||||
const UserObjectClass = "user"
|
||||
|
||||
type ProviderInstance struct {
|
||||
BaseDN string
|
||||
|
||||
UserDN string
|
||||
|
||||
VirtualGroupDN string
|
||||
GroupDN string
|
||||
|
||||
appSlug string
|
||||
flowSlug string
|
||||
s *LDAPServer
|
||||
log *log.Entry
|
||||
|
||||
tlsServerName *string
|
||||
cert *tls.Certificate
|
||||
outpostName string
|
||||
searchAllowedGroups []*strfmt.UUID
|
||||
boundUsersMutex sync.RWMutex
|
||||
boundUsers map[string]UserFlags
|
||||
|
||||
uidStartNumber int32
|
||||
gidStartNumber int32
|
||||
}
|
||||
|
||||
type UserFlags struct {
|
||||
UserInfo *api.User
|
||||
UserPk int32
|
||||
CanSearch bool
|
||||
}
|
||||
|
||||
type LDAPServer struct {
|
||||
s *ldap.Server
|
||||
log *log.Entry
|
||||
@ -55,17 +23,6 @@ type LDAPServer struct {
|
||||
providers []*ProviderInstance
|
||||
}
|
||||
|
||||
type LDAPGroup struct {
|
||||
dn string
|
||||
cn string
|
||||
uid string
|
||||
gidNumber string
|
||||
member []string
|
||||
isSuperuser bool
|
||||
isVirtualGroup bool
|
||||
akAttributes interface{}
|
||||
}
|
||||
|
||||
func NewServer(ac *ak.APIController) *LDAPServer {
|
||||
s := ldap.NewServer()
|
||||
s.EnforceLDAP = true
|
||||
@ -83,10 +40,60 @@ func NewServer(ac *ak.APIController) *LDAPServer {
|
||||
ls.defaultCert = &defaultCert
|
||||
s.BindFunc("", ls)
|
||||
s.SearchFunc("", ls)
|
||||
s.CloseFunc("", ls)
|
||||
return ls
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) Type() string {
|
||||
return "ldap"
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) StartLDAPServer() error {
|
||||
listen := "0.0.0.0:3389"
|
||||
|
||||
ln, err := net.Listen("tcp", listen)
|
||||
if err != nil {
|
||||
ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
|
||||
}
|
||||
proxyListener := &proxyproto.Listener{Listener: ln}
|
||||
defer proxyListener.Close()
|
||||
|
||||
ls.log.WithField("listen", listen).Info("Starting ldap server")
|
||||
err = ls.s.Serve(proxyListener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ls.log.Printf("closing %s", ln.Addr())
|
||||
return ls.s.ListenAndServe(listen)
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) Start() error {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
metrics.RunServer()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := ls.StartLDAPServer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := ls.StartLDAPTLSServer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) TimerFlowCacheExpiry() {
|
||||
for _, p := range ls.providers {
|
||||
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
|
||||
p.binder.TimerFlowCacheExpiry()
|
||||
}
|
||||
}
|
||||
|
55
internal/outpost/ldap/ldap_tls.go
Normal file
55
internal/outpost/ldap/ldap_tls.go
Normal file
@ -0,0 +1,55 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
)
|
||||
|
||||
func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if len(ls.providers) == 1 {
|
||||
if ls.providers[0].cert != nil {
|
||||
ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert")
|
||||
return ls.providers[0].cert, nil
|
||||
}
|
||||
}
|
||||
for _, provider := range ls.providers {
|
||||
if provider.tlsServerName == &info.ServerName {
|
||||
if provider.cert == nil {
|
||||
ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
|
||||
return ls.defaultCert, nil
|
||||
}
|
||||
return provider.cert, nil
|
||||
}
|
||||
}
|
||||
ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert")
|
||||
return ls.defaultCert, nil
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) StartLDAPTLSServer() error {
|
||||
listen := "0.0.0.0:6636"
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
GetCertificate: ls.getCertificates,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", listen)
|
||||
if err != nil {
|
||||
ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
|
||||
}
|
||||
|
||||
proxyListener := &proxyproto.Listener{Listener: ln}
|
||||
defer proxyListener.Close()
|
||||
|
||||
tln := tls.NewListener(proxyListener, tlsConfig)
|
||||
|
||||
ls.log.WithField("listen", listen).Info("Starting ldap tls server")
|
||||
err = ls.s.Serve(tln)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ls.log.Printf("closing %s", ln.Addr())
|
||||
return ls.s.ListenAndServe(listen)
|
||||
}
|
@ -2,23 +2,19 @@ package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/pires/go-proxyproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
UsersOU = "users"
|
||||
GroupsOU = "groups"
|
||||
VirtualGroupsOU = "virtual-groups"
|
||||
"goauthentik.io/api"
|
||||
directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
directsearch "goauthentik.io/internal/outpost/ldap/search/direct"
|
||||
memorysearch "goauthentik.io/internal/outpost/ldap/search/memory"
|
||||
)
|
||||
|
||||
func (ls *LDAPServer) Refresh() error {
|
||||
@ -31,9 +27,9 @@ func (ls *LDAPServer) Refresh() error {
|
||||
}
|
||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
||||
for idx, provider := range outposts.Results {
|
||||
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", UsersOU, *provider.BaseDn))
|
||||
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", GroupsOU, *provider.BaseDn))
|
||||
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", VirtualGroupsOU, *provider.BaseDn))
|
||||
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
|
||||
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
|
||||
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
|
||||
logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name)
|
||||
providers[idx] = &ProviderInstance{
|
||||
BaseDN: *provider.BaseDn,
|
||||
@ -44,7 +40,7 @@ func (ls *LDAPServer) Refresh() error {
|
||||
flowSlug: provider.BindFlowSlug,
|
||||
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
|
||||
boundUsersMutex: sync.RWMutex{},
|
||||
boundUsers: make(map[string]UserFlags),
|
||||
boundUsers: make(map[string]flags.UserFlags),
|
||||
s: ls,
|
||||
log: logger,
|
||||
tlsServerName: provider.TlsServerName,
|
||||
@ -60,79 +56,14 @@ func (ls *LDAPServer) Refresh() error {
|
||||
}
|
||||
providers[idx].cert = ls.cs.Get(*kp)
|
||||
}
|
||||
if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_CACHED {
|
||||
providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx])
|
||||
} else if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_DIRECT {
|
||||
providers[idx].searcher = directsearch.NewDirectSearcher(providers[idx])
|
||||
}
|
||||
providers[idx].binder = directbind.NewDirectBinder(providers[idx])
|
||||
}
|
||||
ls.providers = providers
|
||||
ls.log.Info("Update providers")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) StartLDAPServer() error {
|
||||
listen := "0.0.0.0:3389"
|
||||
|
||||
ln, err := net.Listen("tcp", listen)
|
||||
if err != nil {
|
||||
ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err)
|
||||
}
|
||||
proxyListener := &proxyproto.Listener{Listener: ln}
|
||||
defer proxyListener.Close()
|
||||
|
||||
ls.log.WithField("listen", listen).Info("Starting ldap server")
|
||||
err = ls.s.Serve(proxyListener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ls.log.Printf("closing %s", ln.Addr())
|
||||
return ls.s.ListenAndServe(listen)
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) StartLDAPTLSServer() error {
|
||||
listen := "0.0.0.0:6636"
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
GetCertificate: ls.getCertificates,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", listen)
|
||||
if err != nil {
|
||||
ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err)
|
||||
}
|
||||
|
||||
proxyListener := &proxyproto.Listener{Listener: ln}
|
||||
defer proxyListener.Close()
|
||||
|
||||
tln := tls.NewListener(proxyListener, tlsConfig)
|
||||
|
||||
ls.log.WithField("listen", listen).Info("Starting ldap tls server")
|
||||
err = ls.s.Serve(tln)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ls.log.Printf("closing %s", ln.Addr())
|
||||
return ls.s.ListenAndServe(listen)
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) Start() error {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
metrics.RunServer()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := ls.StartLDAPServer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := ls.StartLDAPTLSServer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
@ -1,47 +1,21 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
goldap "github.com/go-ldap/ldap/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nmcclain/ldap"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
"goauthentik.io/internal/utils"
|
||||
)
|
||||
|
||||
type SearchRequest struct {
|
||||
ldap.SearchRequest
|
||||
BindDN string
|
||||
id string
|
||||
conn net.Conn
|
||||
log *log.Entry
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) {
|
||||
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search"))
|
||||
rid := uuid.New().String()
|
||||
span.SetTag("request_uid", rid)
|
||||
span.SetTag("user.username", bindDN)
|
||||
span.SetTag("ak_filter", searchReq.Filter)
|
||||
span.SetTag("ak_base_dn", searchReq.BaseDN)
|
||||
|
||||
bindDN = strings.ToLower(bindDN)
|
||||
req := SearchRequest{
|
||||
SearchRequest: searchReq,
|
||||
BindDN: bindDN,
|
||||
conn: conn,
|
||||
log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN),
|
||||
id: rid,
|
||||
ctx: span.Context(),
|
||||
}
|
||||
req, span := search.NewRequest(bindDN, searchReq, conn)
|
||||
|
||||
defer func() {
|
||||
span.Finish()
|
||||
@ -50,9 +24,9 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n
|
||||
"type": "search",
|
||||
"filter": req.Filter,
|
||||
"dn": req.BindDN,
|
||||
"client": utils.GetIP(req.conn.RemoteAddr()),
|
||||
"client": utils.GetIP(conn.RemoteAddr()),
|
||||
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
|
||||
req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request")
|
||||
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request")
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
@ -69,13 +43,13 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n
|
||||
}
|
||||
bd, err := goldap.ParseDN(searchReq.BaseDN)
|
||||
if err != nil {
|
||||
req.log.WithError(err).Info("failed to parse basedn")
|
||||
req.Log().WithError(err).Info("failed to parse basedn")
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("invalid DN")
|
||||
}
|
||||
for _, provider := range ls.providers {
|
||||
providerBase, _ := goldap.ParseDN(provider.BaseDN)
|
||||
if providerBase.AncestorOf(bd) || providerBase.Equal(bd) {
|
||||
return provider.Search(req)
|
||||
return provider.searcher.Search(req)
|
||||
}
|
||||
}
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("no provider could handle request")
|
||||
|
@ -1,13 +1,14 @@
|
||||
package ldap
|
||||
package direct
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/nmcclain/ldap"
|
||||
"goauthentik.io/internal/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
)
|
||||
|
||||
func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.ServerSearchResult, error) {
|
||||
func (ds *DirectSearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) {
|
||||
dn := ""
|
||||
if authz {
|
||||
dn = req.SearchRequest.BaseDN
|
||||
@ -19,7 +20,7 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{
|
||||
Name: "distinguishedName",
|
||||
Values: []string{pi.BaseDN},
|
||||
Values: []string{ds.si.GetBaseDN()},
|
||||
},
|
||||
{
|
||||
Name: "objectClass",
|
||||
@ -32,9 +33,9 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv
|
||||
{
|
||||
Name: "namingContexts",
|
||||
Values: []string{
|
||||
pi.BaseDN,
|
||||
pi.GroupDN,
|
||||
pi.UserDN,
|
||||
ds.si.GetBaseDN(),
|
||||
ds.si.GetBaseUserDN(),
|
||||
ds.si.GetBaseGroupDN(),
|
||||
},
|
||||
},
|
||||
{
|
219
internal/outpost/ldap/search/direct/direct.go
Normal file
219
internal/outpost/ldap/search/direct/direct.go
Normal file
@ -0,0 +1,219 @@
|
||||
package direct
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/nmcclain/ldap"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/ldap/group"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
"goauthentik.io/internal/outpost/ldap/utils"
|
||||
)
|
||||
|
||||
type DirectSearcher struct {
|
||||
si server.LDAPServerInstance
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
func NewDirectSearcher(si server.LDAPServerInstance) *DirectSearcher {
|
||||
ds := &DirectSearcher{
|
||||
si: si,
|
||||
log: log.WithField("logger", "authentik.outpost.ldap.searcher.direct"),
|
||||
}
|
||||
ds.log.Info("initialised direct searcher")
|
||||
return ds
|
||||
}
|
||||
|
||||
func (ds *DirectSearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
|
||||
if f.UserInfo == nil {
|
||||
u, _, err := ds.si.GetAPIClient().CoreApi.CoreUsersRetrieve(req.Context(), f.UserPk).Execute()
|
||||
if err != nil {
|
||||
req.Log().WithError(err).Warning("Failed to get user info")
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
|
||||
}
|
||||
f.UserInfo = &u
|
||||
}
|
||||
entries := make([]*ldap.Entry, 1)
|
||||
entries[0] = ds.si.UserEntry(*f.UserInfo)
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
||||
|
||||
func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
|
||||
accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
|
||||
baseDN := strings.ToLower("," + ds.si.GetBaseDN())
|
||||
|
||||
entries := []*ldap.Entry{}
|
||||
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ds.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "filter_parse_fail",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
|
||||
}
|
||||
if len(req.BindDN) < 1 {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ds.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "empty_bind_dn",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
|
||||
}
|
||||
if !strings.HasSuffix(req.BindDN, baseDN) {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ds.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "invalid_bind_dn",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ds.si.GetBaseDN())
|
||||
}
|
||||
|
||||
flags, ok := ds.si.GetFlags(req.BindDN)
|
||||
if !ok {
|
||||
req.Log().Debug("User info not cached")
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ds.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "user_info_not_cached",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
|
||||
}
|
||||
|
||||
if req.Scope == ldap.ScopeBaseObject {
|
||||
req.Log().Debug("base scope, showing domain info")
|
||||
return ds.SearchBase(req, flags.CanSearch)
|
||||
}
|
||||
if !flags.CanSearch {
|
||||
req.Log().Debug("User can't search, showing info about user")
|
||||
return ds.SearchMe(req, flags)
|
||||
}
|
||||
accsp.Finish()
|
||||
|
||||
parsedFilter, err := ldap.CompileFilter(req.Filter)
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ds.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "filter_parse_fail",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
|
||||
}
|
||||
|
||||
// Create a custom client to set additional headers
|
||||
c := api.NewAPIClient(ds.si.GetAPIClient().GetConfig())
|
||||
c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)
|
||||
|
||||
switch filterEntity {
|
||||
default:
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ds.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "unhandled_filter_type",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
|
||||
case constants.OCGroupOfUniqueNames:
|
||||
fallthrough
|
||||
case constants.OCAKGroup:
|
||||
fallthrough
|
||||
case constants.OCAKVirtualGroup:
|
||||
fallthrough
|
||||
case constants.OCGroup:
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
gEntries := make([]*ldap.Entry, 0)
|
||||
uEntries := make([]*ldap.Entry, 0)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
gapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_group")
|
||||
searchReq, skip := utils.ParseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false)
|
||||
if skip {
|
||||
req.Log().Trace("Skip backend request")
|
||||
return
|
||||
}
|
||||
groups, _, err := searchReq.Execute()
|
||||
gapisp.Finish()
|
||||
if err != nil {
|
||||
req.Log().WithError(err).Warning("failed to get groups")
|
||||
return
|
||||
}
|
||||
req.Log().WithField("count", len(groups.Results)).Trace("Got results from API")
|
||||
|
||||
for _, g := range groups.Results {
|
||||
gEntries = append(gEntries, group.FromAPIGroup(g, ds.si).Entry())
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
|
||||
searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
|
||||
if skip {
|
||||
req.Log().Trace("Skip backend request")
|
||||
return
|
||||
}
|
||||
users, _, err := searchReq.Execute()
|
||||
uapisp.Finish()
|
||||
if err != nil {
|
||||
req.Log().WithError(err).Warning("failed to get users")
|
||||
return
|
||||
}
|
||||
|
||||
for _, u := range users.Results {
|
||||
uEntries = append(uEntries, group.FromAPIUser(u, ds.si).Entry())
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
entries = append(gEntries, uEntries...)
|
||||
case "":
|
||||
fallthrough
|
||||
case constants.OCOrgPerson:
|
||||
fallthrough
|
||||
case constants.OCInetOrgPerson:
|
||||
fallthrough
|
||||
case constants.OCAKUser:
|
||||
fallthrough
|
||||
case constants.OCUser:
|
||||
uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
|
||||
searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
|
||||
if skip {
|
||||
req.Log().Trace("Skip backend request")
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
||||
users, _, err := searchReq.Execute()
|
||||
uapisp.Finish()
|
||||
|
||||
if err != nil {
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err)
|
||||
}
|
||||
for _, u := range users.Results {
|
||||
entries = append(entries, ds.si.UserEntry(u))
|
||||
}
|
||||
}
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
54
internal/outpost/ldap/search/memory/base.go
Normal file
54
internal/outpost/ldap/search/memory/base.go
Normal file
@ -0,0 +1,54 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/nmcclain/ldap"
|
||||
"goauthentik.io/internal/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
)
|
||||
|
||||
func (ms *MemorySearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) {
|
||||
dn := ""
|
||||
if authz {
|
||||
dn = req.SearchRequest.BaseDN
|
||||
}
|
||||
return ldap.ServerSearchResult{
|
||||
Entries: []*ldap.Entry{
|
||||
{
|
||||
DN: dn,
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{
|
||||
Name: "distinguishedName",
|
||||
Values: []string{ms.si.GetBaseDN()},
|
||||
},
|
||||
{
|
||||
Name: "objectClass",
|
||||
Values: []string{"top", "domain"},
|
||||
},
|
||||
{
|
||||
Name: "supportedLDAPVersion",
|
||||
Values: []string{"3"},
|
||||
},
|
||||
{
|
||||
Name: "namingContexts",
|
||||
Values: []string{
|
||||
ms.si.GetBaseDN(),
|
||||
ms.si.GetBaseUserDN(),
|
||||
ms.si.GetBaseGroupDN(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "vendorName",
|
||||
Values: []string{"goauthentik.io"},
|
||||
},
|
||||
{
|
||||
Name: "vendorVersion",
|
||||
Values: []string{fmt.Sprintf("authentik LDAP Outpost Version %s (build %s)", constants.VERSION, constants.BUILD())},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess,
|
||||
}, nil
|
||||
}
|
63
internal/outpost/ldap/search/memory/fetch.go
Normal file
63
internal/outpost/ldap/search/memory/fetch.go
Normal file
@ -0,0 +1,63 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"goauthentik.io/api"
|
||||
)
|
||||
|
||||
const pageSize = 100
|
||||
|
||||
func (ms *MemorySearcher) FetchUsers() []api.User {
|
||||
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
|
||||
users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
|
||||
if err != nil {
|
||||
ms.log.WithError(err).Warning("failed to update users")
|
||||
return nil, err
|
||||
}
|
||||
ms.log.WithField("page", page).Debug("fetched users")
|
||||
return &users, nil
|
||||
}
|
||||
page := 1
|
||||
users := make([]api.User, 0)
|
||||
for {
|
||||
apiUsers, err := fetchUsersOffset(page)
|
||||
if err != nil {
|
||||
return users
|
||||
}
|
||||
if apiUsers.Pagination.Next > 0 {
|
||||
page += 1
|
||||
} else {
|
||||
break
|
||||
}
|
||||
users = append(users, apiUsers.Results...)
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func (ms *MemorySearcher) FetchGroups() []api.Group {
|
||||
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
|
||||
groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
|
||||
if err != nil {
|
||||
ms.log.WithError(err).Warning("failed to update groups")
|
||||
return nil, err
|
||||
}
|
||||
ms.log.WithField("page", page).Debug("fetched groups")
|
||||
return &groups, nil
|
||||
}
|
||||
page := 1
|
||||
groups := make([]api.Group, 0)
|
||||
for {
|
||||
apiGroups, err := fetchGroupsOffset(page)
|
||||
if err != nil {
|
||||
return groups
|
||||
}
|
||||
if apiGroups.Pagination.Next > 0 {
|
||||
page += 1
|
||||
} else {
|
||||
break
|
||||
}
|
||||
groups = append(groups, apiGroups.Results...)
|
||||
}
|
||||
return groups
|
||||
}
|
182
internal/outpost/ldap/search/memory/memory.go
Normal file
182
internal/outpost/ldap/search/memory/memory.go
Normal file
@ -0,0 +1,182 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/nmcclain/ldap"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/ldap/group"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
)
|
||||
|
||||
type MemorySearcher struct {
|
||||
si server.LDAPServerInstance
|
||||
log *log.Entry
|
||||
|
||||
users []api.User
|
||||
groups []api.Group
|
||||
}
|
||||
|
||||
func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
|
||||
ms := &MemorySearcher{
|
||||
si: si,
|
||||
log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"),
|
||||
}
|
||||
ms.log.Info("initialised memory searcher")
|
||||
ms.users = ms.FetchUsers()
|
||||
ms.groups = ms.FetchGroups()
|
||||
return ms
|
||||
}
|
||||
|
||||
func (ms *MemorySearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
|
||||
if f.UserInfo == nil {
|
||||
for _, u := range ms.users {
|
||||
if u.Pk == f.UserPk {
|
||||
f.UserInfo = &u
|
||||
}
|
||||
}
|
||||
if f.UserInfo == nil {
|
||||
req.Log().WithField("pk", f.UserPk).Warning("User with pk is not in local cache")
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
|
||||
}
|
||||
}
|
||||
entries := make([]*ldap.Entry, 1)
|
||||
entries[0] = ms.si.UserEntry(*f.UserInfo)
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
||||
|
||||
func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
|
||||
accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
|
||||
baseDN := strings.ToLower("," + ms.si.GetBaseDN())
|
||||
|
||||
entries := []*ldap.Entry{}
|
||||
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ms.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "filter_parse_fail",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
|
||||
}
|
||||
if len(req.BindDN) < 1 {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ms.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "empty_bind_dn",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
|
||||
}
|
||||
if !strings.HasSuffix(req.BindDN, baseDN) {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ms.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "invalid_bind_dn",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN())
|
||||
}
|
||||
|
||||
flags, ok := ms.si.GetFlags(req.BindDN)
|
||||
if !ok {
|
||||
req.Log().Debug("User info not cached")
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ms.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "user_info_not_cached",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
|
||||
}
|
||||
|
||||
if req.Scope == ldap.ScopeBaseObject {
|
||||
req.Log().Debug("base scope, showing domain info")
|
||||
return ms.SearchBase(req, flags.CanSearch)
|
||||
}
|
||||
if !flags.CanSearch {
|
||||
req.Log().Debug("User can't search, showing info about user")
|
||||
return ms.SearchMe(req, flags)
|
||||
}
|
||||
accsp.Finish()
|
||||
|
||||
// parsedFilter, err := ldap.CompileFilter(req.Filter)
|
||||
// if err != nil {
|
||||
// metrics.RequestsRejected.With(prometheus.Labels{
|
||||
// "outpost_name": ms.si.GetOutpostName(),
|
||||
// "type": "search",
|
||||
// "reason": "filter_parse_fail",
|
||||
// "dn": req.BindDN,
|
||||
// "client": req.RemoteAddr(),
|
||||
// }).Inc()
|
||||
// return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
|
||||
// }
|
||||
|
||||
switch filterEntity {
|
||||
default:
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": ms.si.GetOutpostName(),
|
||||
"type": "search",
|
||||
"reason": "unhandled_filter_type",
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
|
||||
case constants.OCGroupOfUniqueNames:
|
||||
fallthrough
|
||||
case constants.OCAKGroup:
|
||||
fallthrough
|
||||
case constants.OCAKVirtualGroup:
|
||||
fallthrough
|
||||
case constants.OCGroup:
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
gEntries := make([]*ldap.Entry, 0)
|
||||
uEntries := make([]*ldap.Entry, 0)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for _, g := range ms.groups {
|
||||
gEntries = append(gEntries, group.FromAPIGroup(g, ms.si).Entry())
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for _, u := range ms.users {
|
||||
uEntries = append(uEntries, group.FromAPIUser(u, ms.si).Entry())
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
entries = append(gEntries, uEntries...)
|
||||
case "":
|
||||
fallthrough
|
||||
case constants.OCOrgPerson:
|
||||
fallthrough
|
||||
case constants.OCInetOrgPerson:
|
||||
fallthrough
|
||||
case constants.OCAKUser:
|
||||
fallthrough
|
||||
case constants.OCUser:
|
||||
for _, u := range ms.users {
|
||||
entries = append(entries, ms.si.UserEntry(u))
|
||||
}
|
||||
}
|
||||
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
|
||||
}
|
53
internal/outpost/ldap/search/request.go
Normal file
53
internal/outpost/ldap/search/request.go
Normal file
@ -0,0 +1,53 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nmcclain/ldap"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/utils"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
ldap.SearchRequest
|
||||
BindDN string
|
||||
log *log.Entry
|
||||
|
||||
id string
|
||||
conn net.Conn
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewRequest(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (*Request, *sentry.Span) {
|
||||
rid := uuid.New().String()
|
||||
bindDN = strings.ToLower(bindDN)
|
||||
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search"))
|
||||
span.SetTag("request_uid", rid)
|
||||
span.SetTag("user.username", bindDN)
|
||||
span.SetTag("ak_filter", searchReq.Filter)
|
||||
span.SetTag("ak_base_dn", searchReq.BaseDN)
|
||||
return &Request{
|
||||
SearchRequest: searchReq,
|
||||
BindDN: bindDN,
|
||||
conn: conn,
|
||||
log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN),
|
||||
id: rid,
|
||||
ctx: span.Context(),
|
||||
}, span
|
||||
}
|
||||
|
||||
func (r *Request) Context() context.Context {
|
||||
return r.ctx
|
||||
}
|
||||
|
||||
func (r *Request) Log() *log.Entry {
|
||||
return r.log
|
||||
}
|
||||
|
||||
func (r *Request) RemoteAddr() string {
|
||||
return utils.GetIP(r.conn.RemoteAddr())
|
||||
}
|
7
internal/outpost/ldap/search/searcher.go
Normal file
7
internal/outpost/ldap/search/searcher.go
Normal file
@ -0,0 +1,7 @@
|
||||
package search
|
||||
|
||||
import "github.com/nmcclain/ldap"
|
||||
|
||||
type Searcher interface {
|
||||
Search(req *Request) (ldap.ServerSearchResult, error)
|
||||
}
|
35
internal/outpost/ldap/server/base.go
Normal file
35
internal/outpost/ldap/server/base.go
Normal file
@ -0,0 +1,35 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/nmcclain/ldap"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
)
|
||||
|
||||
type LDAPServerInstance interface {
|
||||
GetAPIClient() *api.APIClient
|
||||
GetOutpostName() string
|
||||
|
||||
GetFlowSlug() string
|
||||
GetAppSlug() string
|
||||
GetSearchAllowedGroups() []*strfmt.UUID
|
||||
|
||||
UserEntry(u api.User) *ldap.Entry
|
||||
|
||||
GetBaseDN() string
|
||||
GetBaseGroupDN() string
|
||||
GetBaseUserDN() string
|
||||
|
||||
GetUserDN(string) string
|
||||
GetGroupDN(string) string
|
||||
GetVirtualGroupDN(string) string
|
||||
|
||||
GetUidNumber(api.User) string
|
||||
GetGidNumber(api.Group) string
|
||||
|
||||
UsersForGroup(api.Group) []string
|
||||
|
||||
GetFlags(string) (flags.UserFlags, bool)
|
||||
SetFlags(string, flags.UserFlags)
|
||||
}
|
@ -3,70 +3,12 @@ package ldap
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/nmcclain/ldap"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api"
|
||||
)
|
||||
|
||||
func BoolToString(in bool) string {
|
||||
if in {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
func ldapResolveTypeSingle(in interface{}) *string {
|
||||
switch t := in.(type) {
|
||||
case string:
|
||||
return &t
|
||||
case *string:
|
||||
return t
|
||||
case bool:
|
||||
s := BoolToString(t)
|
||||
return &s
|
||||
case *bool:
|
||||
s := BoolToString(*t)
|
||||
return &s
|
||||
default:
|
||||
log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute {
|
||||
attrList := []*ldap.EntryAttribute{}
|
||||
if attrs == nil {
|
||||
return attrList
|
||||
}
|
||||
a := attrs.(*map[string]interface{})
|
||||
for attrKey, attrValue := range *a {
|
||||
entry := &ldap.EntryAttribute{Name: attrKey}
|
||||
switch t := attrValue.(type) {
|
||||
case []string:
|
||||
entry.Values = t
|
||||
case *[]string:
|
||||
entry.Values = *t
|
||||
case []interface{}:
|
||||
entry.Values = make([]string, len(t))
|
||||
for idx, v := range t {
|
||||
v := ldapResolveTypeSingle(v)
|
||||
entry.Values[idx] = *v
|
||||
}
|
||||
default:
|
||||
v := ldapResolveTypeSingle(t)
|
||||
if v != nil {
|
||||
entry.Values = []string{*v}
|
||||
}
|
||||
}
|
||||
attrList = append(attrList, entry)
|
||||
}
|
||||
return attrList
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GroupsForUser(user api.User) []string {
|
||||
groups := make([]string, len(user.Groups))
|
||||
for i, group := range user.GroupsObj {
|
||||
@ -83,32 +25,6 @@ func (pi *ProviderInstance) UsersForGroup(group api.Group) []string {
|
||||
return users
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) APIGroupToLDAPGroup(g api.Group) LDAPGroup {
|
||||
return LDAPGroup{
|
||||
dn: pi.GetGroupDN(g.Name),
|
||||
cn: g.Name,
|
||||
uid: string(g.Pk),
|
||||
gidNumber: pi.GetGidNumber(g),
|
||||
member: pi.UsersForGroup(g),
|
||||
isVirtualGroup: false,
|
||||
isSuperuser: *g.IsSuperuser,
|
||||
akAttributes: g.Attributes,
|
||||
}
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) APIUserToLDAPGroup(u api.User) LDAPGroup {
|
||||
return LDAPGroup{
|
||||
dn: pi.GetVirtualGroupDN(u.Username),
|
||||
cn: u.Username,
|
||||
uid: u.Uid,
|
||||
gidNumber: pi.GetUidNumber(u),
|
||||
member: []string{pi.GetUserDN(u.Username)},
|
||||
isVirtualGroup: true,
|
||||
isSuperuser: false,
|
||||
akAttributes: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetUserDN(user string) string {
|
||||
return fmt.Sprintf("cn=%s,%s", user, pi.UserDN)
|
||||
}
|
||||
@ -155,26 +71,3 @@ func (pi *ProviderInstance) GetRIDForGroup(uid string) int32 {
|
||||
|
||||
return int32(gid)
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) ensureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute {
|
||||
for name, values := range shouldHave {
|
||||
attrs = pi.mustHaveAttribute(attrs, name, values)
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) mustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute {
|
||||
shouldSet := true
|
||||
for _, attr := range attrs {
|
||||
if attr.Name == name {
|
||||
shouldSet = false
|
||||
}
|
||||
}
|
||||
if shouldSet {
|
||||
return append(attrs, &ldap.EntryAttribute{
|
||||
Name: name,
|
||||
Values: value,
|
||||
})
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
86
internal/outpost/ldap/utils/utils.go
Normal file
86
internal/outpost/ldap/utils/utils.go
Normal file
@ -0,0 +1,86 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/nmcclain/ldap"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func BoolToString(in bool) string {
|
||||
if in {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
func ldapResolveTypeSingle(in interface{}) *string {
|
||||
switch t := in.(type) {
|
||||
case string:
|
||||
return &t
|
||||
case *string:
|
||||
return t
|
||||
case bool:
|
||||
s := BoolToString(t)
|
||||
return &s
|
||||
case *bool:
|
||||
s := BoolToString(*t)
|
||||
return &s
|
||||
default:
|
||||
log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute {
|
||||
attrList := []*ldap.EntryAttribute{}
|
||||
if attrs == nil {
|
||||
return attrList
|
||||
}
|
||||
a := attrs.(*map[string]interface{})
|
||||
for attrKey, attrValue := range *a {
|
||||
entry := &ldap.EntryAttribute{Name: attrKey}
|
||||
switch t := attrValue.(type) {
|
||||
case []string:
|
||||
entry.Values = t
|
||||
case *[]string:
|
||||
entry.Values = *t
|
||||
case []interface{}:
|
||||
entry.Values = make([]string, len(t))
|
||||
for idx, v := range t {
|
||||
v := ldapResolveTypeSingle(v)
|
||||
entry.Values[idx] = *v
|
||||
}
|
||||
default:
|
||||
v := ldapResolveTypeSingle(t)
|
||||
if v != nil {
|
||||
entry.Values = []string{*v}
|
||||
}
|
||||
}
|
||||
attrList = append(attrList, entry)
|
||||
}
|
||||
return attrList
|
||||
}
|
||||
|
||||
func EnsureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute {
|
||||
for name, values := range shouldHave {
|
||||
attrs = MustHaveAttribute(attrs, name, values)
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
func MustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute {
|
||||
shouldSet := true
|
||||
for _, attr := range attrs {
|
||||
if attr.Name == name {
|
||||
shouldSet = false
|
||||
}
|
||||
}
|
||||
if shouldSet {
|
||||
return append(attrs, &ldap.EntryAttribute{
|
||||
Name: name,
|
||||
Values: value,
|
||||
})
|
||||
}
|
||||
return attrs
|
||||
}
|
@ -1,19 +1,20 @@
|
||||
package ldap
|
||||
package utils
|
||||
|
||||
import (
|
||||
goldap "github.com/go-ldap/ldap/v3"
|
||||
ber "github.com/nmcclain/asn1-ber"
|
||||
"github.com/nmcclain/ldap"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
)
|
||||
|
||||
func parseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) {
|
||||
func ParseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) {
|
||||
switch f.Tag {
|
||||
case ldap.FilterEqualityMatch:
|
||||
return parseFilterForGroupSingle(req, f)
|
||||
case ldap.FilterAnd:
|
||||
for _, child := range f.Children {
|
||||
r, s := parseFilterForGroup(req, child, skip)
|
||||
r, s := ParseFilterForGroup(req, child, skip)
|
||||
skip = skip || s
|
||||
req = r
|
||||
}
|
||||
@ -53,7 +54,7 @@ func parseFilterForGroupSingle(req api.ApiCoreGroupsListRequest, f *ber.Packet)
|
||||
username := userDN.RDNs[0].Attributes[0].Value
|
||||
// If the DN's first ou is virtual-groups, ignore this filter
|
||||
if len(userDN.RDNs) > 1 {
|
||||
if userDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU || userDN.RDNs[1].Attributes[0].Value == GroupsOU {
|
||||
if userDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups || userDN.RDNs[1].Attributes[0].Value == constants.OUGroups {
|
||||
// Since we know we're not filtering anything, skip this request
|
||||
return req, true
|
||||
}
|
@ -1,19 +1,20 @@
|
||||
package ldap
|
||||
package utils
|
||||
|
||||
import (
|
||||
goldap "github.com/go-ldap/ldap/v3"
|
||||
ber "github.com/nmcclain/asn1-ber"
|
||||
"github.com/nmcclain/ldap"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
)
|
||||
|
||||
func parseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) {
|
||||
func ParseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) {
|
||||
switch f.Tag {
|
||||
case ldap.FilterEqualityMatch:
|
||||
return parseFilterForUserSingle(req, f)
|
||||
case ldap.FilterAnd:
|
||||
for _, child := range f.Children {
|
||||
r, s := parseFilterForUser(req, child, skip)
|
||||
r, s := ParseFilterForUser(req, child, skip)
|
||||
skip = skip || s
|
||||
req = r
|
||||
}
|
||||
@ -58,7 +59,7 @@ func parseFilterForUserSingle(req api.ApiCoreUsersListRequest, f *ber.Packet) (a
|
||||
name := groupDN.RDNs[0].Attributes[0].Value
|
||||
// If the DN's first ou is virtual-groups, ignore this filter
|
||||
if len(groupDN.RDNs) > 1 {
|
||||
if groupDN.RDNs[1].Attributes[0].Value == UsersOU || groupDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU {
|
||||
if groupDN.RDNs[1].Attributes[0].Value == constants.OUUsers || groupDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups {
|
||||
// Since we know we're not filtering anything, skip this request
|
||||
return req, true
|
||||
}
|
52
internal/web/sentry_proxy.go
Normal file
52
internal/web/sentry_proxy.go
Normal file
@ -0,0 +1,52 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"goauthentik.io/internal/config"
|
||||
)
|
||||
|
||||
type SentryRequest struct {
|
||||
DSN string `json:"dsn"`
|
||||
}
|
||||
|
||||
func (ws *WebServer) APISentryProxy(rw http.ResponseWriter, r *http.Request) {
|
||||
if !config.G.ErrorReporting.Enabled {
|
||||
ws.log.Debug("error reporting disabled")
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fullBody, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
ws.log.Debug("failed to read body")
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
lines := strings.Split(string(fullBody), "\n")
|
||||
if len(lines) < 1 {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sd := SentryRequest{}
|
||||
err = json.Unmarshal([]byte(lines[0]), &sd)
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to parse sentry request")
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if sd.DSN != config.G.ErrorReporting.DSN {
|
||||
ws.log.WithField("have", sd.DSN).WithField("expected", config.G.ErrorReporting.DSN).Debug("invalid DSN")
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
res, err := http.DefaultClient.Post("https://sentry.beryju.org/api/8/envelope/", "application/octet-stream", strings.NewReader(string(fullBody)))
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Warning("failed to proxy sentry")
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
}
|
@ -26,7 +26,7 @@ func (ws *WebServer) listenTLS() {
|
||||
ws.log.WithError(err).Fatalf("failed to listen")
|
||||
return
|
||||
}
|
||||
ws.log.WithField("addr", config.G.Web.ListenTLS).Info("Running")
|
||||
ws.log.WithField("addr", config.G.Web.ListenTLS).Info("Listening (TLS)")
|
||||
|
||||
proxyListener := &proxyproto.Listener{Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}}
|
||||
defer proxyListener.Close()
|
||||
|
@ -51,10 +51,15 @@ func NewWebServer(g *gounicorn.GoUnicorn) *WebServer {
|
||||
p: g,
|
||||
}
|
||||
ws.configureStatic()
|
||||
ws.configureRoutes()
|
||||
ws.configureProxy()
|
||||
return ws
|
||||
}
|
||||
|
||||
func (ws *WebServer) configureRoutes() {
|
||||
ws.m.Path("/api/v3/sentry/").HandlerFunc(ws.APISentryProxy)
|
||||
}
|
||||
|
||||
func (ws *WebServer) Start() {
|
||||
go ws.listenPlain()
|
||||
go ws.listenTLS()
|
||||
@ -69,14 +74,13 @@ func (ws *WebServer) listenPlain() {
|
||||
if err != nil {
|
||||
ws.log.WithError(err).Fatalf("failed to listen")
|
||||
}
|
||||
ws.log.WithField("addr", config.G.Web.Listen).Info("Running")
|
||||
ws.log.WithField("addr", config.G.Web.Listen).Info("Listening")
|
||||
|
||||
proxyListener := &proxyproto.Listener{Listener: ln}
|
||||
defer proxyListener.Close()
|
||||
|
||||
ws.serve(proxyListener)
|
||||
|
||||
ws.log.WithField("addr", config.G.Web.Listen).Info("Running")
|
||||
err = http.ListenAndServe(config.G.Web.Listen, ws.m)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
ws.log.Errorf("ERROR: http.Serve() - %s", err)
|
||||
|
@ -1,14 +1,14 @@
|
||||
# Stage 1: Build
|
||||
FROM docker.io/golang:1.17.2 AS builder
|
||||
FROM docker.io/golang:1.17.3-bullseye AS builder
|
||||
|
||||
WORKDIR /go/src/goauthentik.io
|
||||
|
||||
COPY . .
|
||||
|
||||
ENV CGO_ENABLED=0
|
||||
RUN go build -o /go/ldap ./cmd/ldap
|
||||
|
||||
# Stage 2: Run
|
||||
FROM gcr.io/distroless/base-debian10:debug
|
||||
FROM gcr.io/distroless/static-debian11:debug
|
||||
|
||||
ARG GIT_BUILD_HASH
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user